【李沐】十分钟从 PyTorch 转 MXNet

简介: PyTorch 是一个纯命令式的深度学习框架。它因为提供简单易懂的编程接口而广受欢迎,而且正在快速的流行开来。MXNet通过ndarray和 gluon模块提供了非常类似 PyTorch 的编程接口。本文将简单对比如何用这两个框架来实现同样的算法。

PyTorch 是一个纯命令式的深度学习框架。它因为提供简单易懂的编程接口而广受欢迎,而且正在快速的流行开来。例如 Caffe2 最近就并入了 PyTorch。

可能大家不是特别知道的是,MXNet 通过 ndarray 和 gluon 模块提供了非常类似 PyTorch 的编程接口。本文将简单对比如何用这两个框架来实现同样的算法

89e0e8d5de21311740959c69f9ae2fe0258d52af

安装

PyTorch 默认使用 conda 来进行安装,例如

03192dec910f50e049d5fecb3109e8b09f6cdf9b

而 MXNet 更常用的是使用 pip。我们这里使用了 --pre 来安装 nightly 版本

83d9786fd8b0f46bc693765c60c2e0544ec118a7

多维矩阵

对于多维矩阵,PyTorch 沿用了 Torch 的风格称之为 tensor,MXNet 则追随了 NumPy 的称呼 ndarray。下面我们创建一个两维矩阵,其中每个元素初始化成 1。然后每个元素加 1 后打印。

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch:

b472e4f3dada3709d53edf6608ab47f322089ca2

8481c8f592b7f349aa84a1de5c171db681516edfMXNet:

28436e80f4887a6ef0ba21b7dedafad23e823410

忽略包名的不一样的话,这里主要的区别是 MXNet 的形状传入参数跟 NumPy 一样需要用括号括起来。

模型训练

下面我们看一个稍微复杂点的例子。这里我们使用一个多层感知机(MLP)来在 MINST 这个数据集上训练一个模型。我们将其分成 4 小块来方便对比。

读取数据

这里我们下载 MNIST 数据集并载入到内存,这样我们之后可以一个一个读取批量。

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch:

a3657b3dbcca68c3b62521edd3f0dd3082a15389

8481c8f592b7f349aa84a1de5c171db681516edfMXNet:

d50f77b5a334e2bd2072180a29c72c9a74ed18dc

这里的主要区别是 MXNet 使用 transform_first 来表明数据变化是作用在读到的批量的第一个元素,既 MNIST 图片,而不是第二个标号元素。

定义模型

下面我们定义一个只有一个单隐层的 MLP 。

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch:

01818faab6a9ae66f6daae74be169b64c894f344

8481c8f592b7f349aa84a1de5c171db681516edfMXNet:

8e95eac5ee43682f1ca882e782de98079be13d9e

我们使用了 Sequential 容器来把层串起来构造神经网络。这里 MXNet 跟 PyTorch 的主要区别是:

8481c8f592b7f349aa84a1de5c171db681516edf 不需要指定输入大小,这个系统会在后面自动推理得到
8481c8f592b7f349aa84a1de5c171db681516edf 全连接和卷积层可以指定激活函数
8481c8f592b7f349aa84a1de5c171db681516edf需要创建一个  name_scope  的域来给每一层附上一个独一无二的名字,这个在之后读写模型时需要
8481c8f592b7f349aa84a1de5c171db681516edf 我们需要显示调用模型初始化函数。


大家知道 Sequential 下只能神经网络只能逐一执行每个层。PyTorch 可以继承 nn.Module 来自定义 forward 如何执行。同样,MXNet 可以继承 nn.Block 来达到类似的效果。

损失函数和优化算法

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch:

483451f3193b8143e4fe7c180da0a03baff4fc71

8481c8f592b7f349aa84a1de5c171db681516edfMXNet:

d126effd55a80c7df82aa0e96cb0f5cf7f1c5785

这里我们使用交叉熵函数和最简单随机梯度下降并使用固定学习率 0.1

训练

最后我们实现训练算法,并附上了输出结果。注意到每次我们会使用不同的权重和数据读取顺序,所以每次结果可能不一样。

8481c8f592b7f349aa84a1de5c171db681516edfPyTorch

37274d74bd5215f00a2a585afaea92d1eb809284

8481c8f592b7f349aa84a1de5c171db681516edfMXNet

fa0addb6b0a7d824307feb70fcd9eae4ea9e209a

MXNet 跟 PyTorch 的不同主要在下面这几点:

8481c8f592b7f349aa84a1de5c171db681516edf不需要将输入放进  Variable , 但需要将计算放在  mx.autograd.record()  里使得后面可以对其求导
8481c8f592b7f349aa84a1de5c171db681516edf 不需要每次梯度清 0,因为新梯度是写进去,而不是累加
8481c8f592b7f349aa84a1de5c171db681516edf step  的时候 MXNet 需要给定批量大小
8481c8f592b7f349aa84a1de5c171db681516edf需要调用  asscalar()  来将多维数组变成标量。
8481c8f592b7f349aa84a1de5c171db681516edf 这个样例里 MXNet 比 PyTorch 快两倍。当然大家对待这样的比较要谨慎。

下一步

8481c8f592b7f349aa84a1de5c171db681516edf 更详细的 MXNet 的教程:http://zh.gluon.ai/

8481c8f592b7f349aa84a1de5c171db681516edf欢迎给我们留言哪些 PyTorch 的方便之处你希望 MXNet 应该也可以有



原文发布时间为:2018-04-3

本文作者:李沐

本文来自云栖社区合作伙伴新智元,了解相关信息可以关注“AI_era”微信公众号

原文链接:【李沐】十分钟从 PyTorch 转 MXNet

相关文章
|
机器学习/深度学习 人工智能 PyTorch
李沐动手学深度学习pytorch :问题:找不到d2l包,No module named ‘d2l’
李沐动手学深度学习pytorch :问题:找不到d2l包,No module named ‘d2l’
1133 0
|
机器学习/深度学习 存储 算法
【李沐:动手学深度学习pytorch版】第3章:线性神经网络(下)
【李沐:动手学深度学习pytorch版】第3章:线性神经网络
598 0
【李沐:动手学深度学习pytorch版】第3章:线性神经网络(下)
|
机器学习/深度学习 算法 数据可视化
【李沐:动手学深度学习pytorch版】第3章:线性神经网络(上)
【李沐:动手学深度学习pytorch版】第3章:线性神经网络
342 0
【李沐:动手学深度学习pytorch版】第3章:线性神经网络(上)
|
机器学习/深度学习 存储 数据可视化
【李沐:动手学深度学习pytorch版】第2章:预备知识(下)
【李沐:动手学深度学习pytorch版】第2章:预备知识
602 0
【李沐:动手学深度学习pytorch版】第2章:预备知识(下)
|
机器学习/深度学习 数据采集 算法
【李沐:动手学深度学习pytorch版】第2章:预备知识(上)
【李沐:动手学深度学习pytorch版】第2章:预备知识
426 0
|
机器学习/深度学习 人工智能 数据挖掘
李沐《动手学深度学习》PyTorch 实现版开源,瞬间登上 GitHub 热榜!
李沐《动手学深度学习》PyTorch 实现版开源,瞬间登上 GitHub 热榜!
2440 0
李沐《动手学深度学习》PyTorch 实现版开源,瞬间登上 GitHub 热榜!
|
7月前
|
机器学习/深度学习 PyTorch API
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
本文深入探讨神经网络模型量化技术,重点讲解训练后量化(PTQ)与量化感知训练(QAT)两种主流方法。PTQ通过校准数据集确定量化参数,快速实现模型压缩,但精度损失较大;QAT在训练中引入伪量化操作,使模型适应低精度环境,显著提升量化后性能。文章结合PyTorch实现细节,介绍Eager模式、FX图模式及PyTorch 2导出量化等工具,并分享大语言模型Int4/Int8混合精度实践。最后总结量化最佳策略,包括逐通道量化、混合精度设置及目标硬件适配,助力高效部署深度学习模型。
1182 21
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
|
3月前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
204 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
|
2月前
|
边缘计算 人工智能 PyTorch
130_知识蒸馏技术:温度参数与损失函数设计 - 教师-学生模型的优化策略与PyTorch实现
随着大型语言模型(LLM)的规模不断增长,部署这些模型面临着巨大的计算和资源挑战。以DeepSeek-R1为例,其671B参数的规模即使经过INT4量化后,仍需要至少6张高端GPU才能运行,这对于大多数中小型企业和研究机构来说成本过高。知识蒸馏作为一种有效的模型压缩技术,通过将大型教师模型的知识迁移到小型学生模型中,在显著降低模型复杂度的同时保留核心性能,成为解决这一问题的关键技术之一。
|
9月前
|
机器学习/深度学习 JavaScript PyTorch
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
生成对抗网络(GAN)的训练效果高度依赖于损失函数的选择。本文介绍了经典GAN损失函数理论,并用PyTorch实现多种变体,包括原始GAN、LS-GAN、WGAN及WGAN-GP等。通过分析其原理与优劣,如LS-GAN提升训练稳定性、WGAN-GP改善图像质量,展示了不同场景下损失函数的设计思路。代码实现覆盖生成器与判别器的核心逻辑,为实际应用提供了重要参考。未来可探索组合优化与自适应设计以提升性能。
772 7
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体

热门文章

最新文章

推荐镜像

更多