MLX vs MPS vs CUDA:苹果新机器学习框架的基准测试

本文涉及的产品
实时数仓Hologres,5000CU*H 100GB 3个月
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
实时计算 Flink 版,5000CU*H 3个月
简介: 如果你是一个Mac用户和一个深度学习爱好者,你可能希望在某些时候Mac可以处理一些重型模型。苹果刚刚发布了MLX,一个在苹果芯片上高效运行机器学习模型的框架。

最近在PyTorch 1.12中引入MPS后端已经是一个大胆的步骤,但随着MLX的宣布,苹果还想在开源深度学习方面有更大的发展。

在本文中,我们将对这些新方法进行测试,在三种不同的Apple Silicon芯片和两个支持cuda的gpu上和传统CPU后端进行基准测试。

这里把基准测试集中在图卷积网络(GCN)模型上。这个模型主要由线性层组成,所以对于其他的模型也应该得到类似的结果。

创造环境

要为MLX构建环境,我们必须指定是使用i386还是arm架构。使用conda,可以使用:

 CONDA_SUBDIR=osx-arm64 conda create -n mlx python=3.10 numpy pytorch scipy requests -c conda-forge
 conda activate mlx

如果检查你的env是否实际使用了arm,下面命令的输出应该是arm,而不是i386(因为我们用的Apple Silicon):

 python -c "import platform; print(platform.processor())"

然后就是使用pip安装MLX:

 pip install mlx

GCN模型

GCN模型是图神经网络(GNN)的一种,它使用邻接矩阵(表示图结构)和节点特征。它通过收集邻近节点的信息来计算节点嵌入。每个节点获得其邻居特征的平均值。这种平均是通过将节点特征与标准化邻接矩阵相乘来完成的,并根据节点度进行调整。为了学习这个过程,特征首先通过线性层投射到嵌入空间中。

我们将使用MLX实现一个GCN层和一个GCN模型:

 import mlx.nn as nn

 class GCNLayer(nn.Module):
     def __init__(self, in_features, out_features, bias=True):
         super(GCNLayer, self).__init__()
         self.linear = nn.Linear(in_features, out_features, bias)

     def __call__(self, x, adj):
         x = self.linear(x)
         return adj @ x

 class GCN(nn.Module):
     def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):
         super(GCN, self).__init__()

         layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim]
         self.gcn_layers = [
             GCNLayer(in_dim, out_dim, bias)
             for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:])
         ]
         self.dropout = nn.Dropout(p=dropout)

     def __call__(self, x, adj):
         for layer in self.gcn_layers[:-1]:
             x = nn.relu(layer(x, adj))
             x = self.dropout(x)

         x = self.gcn_layers[-1](x, adj)
         return x

可以看到,mlx的模型开发方式与tf2基本一样,都是调用

__call__

进行前向传播,其实torch也一样,只不过它自定义了一个forward函数。

下面就是训练

 gcn = GCN(
     x_dim=x.shape[-1],
     h_dim=args.hidden_dim,
     out_dim=args.nb_classes,
     nb_layers=args.nb_layers,
     dropout=args.dropout,
     bias=args.bias,
 )
 mx.eval(gcn.parameters())

 optimizer = optim.Adam(learning_rate=args.lr)
 loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)

 # Training loop
 for epoch in range(args.epochs):

     # Loss
     (loss, y_hat), grads = loss_and_grad_fn(
         gcn, x, adj, y, train_mask, args.weight_decay
     )
     optimizer.update(gcn, grads)
     mx.eval(gcn.parameters(), optimizer.state)

     # Validation
     val_loss = loss_fn(y_hat[val_mask], y[val_mask])
     val_acc = eval_fn(y_hat[val_mask], y[val_mask])

在MLX中,计算是惰性的,这意味着eval()通常用于在更新后实际计算新的模型参数。而另一个关键函数是nn.value_and_grad(),它生成一个计算参数损失的函数。第一个参数是保存当前参数的模型,第二个参数是用于前向传递和损失计算的可调用函数。它返回的函数接受与forward函数相同的参数(在本例中为forward_fn)。我们可以这样定义这个函数:

 def forward_fn(gcn, x, adj, y, train_mask, weight_decay):
     y_hat = gcn(x, adj)
     loss = loss_fn(y_hat[train_mask], y[train_mask], weight_decay, gcn.parameters())
     return loss, y_hat

它仅仅包括计算前向传递和计算损失。Loss_fn()和eval_fn()定义如下:

 def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):
     l = mx.mean(nn.losses.cross_entropy(y_hat, y))

     if weight_decay != 0.0:
         assert parameters != None, "Model parameters missing for L2 reg."

         l2_reg = sum(mx.sum(p[1] ** 2) for p in tree_flatten(parameters)).sqrt()
         return l + weight_decay * l2_reg

     return l

 def eval_fn(x, y):
     return mx.mean(mx.argmax(x, axis=1) == y)

损失函数是计算预测和标签之间的交叉熵,并包括L2正则化。由于L2正则化还不是内置特性,需要手动实现。

本文的完整代码:https://github.com/TristanBilot/mlx-GCN

可以看到除了一些细节函数调用的差别,基本的训练流程与pytorch和tf都很类似,但是这里的一个很好的事情是消除了显式地将对象分配给特定设备的需要,就像我们在PyTorch中经常使用.cuda()和.to(device)那样。这是因为苹果硅芯片的统一内存架构,所有变量共存于同一空间,也就是说消除了CPU和GPU之间缓慢的数据传输,这样也可以保证不会再出现与设备不匹配相关的烦人的运行时错误。

基准测试

我们将使用MLX与MPS, CPU和GPU设备进行比较。我们的测试平台是一个2层GCN模型,应用于Cora数据集,其中包括2708个节点和5429条边。

对于MLX, MPS和CPU测试,我们对M1 Pro, M2 Ultra和M3 Max进行基准测试。在两款NVIDIA V100 PCIe和V100 NVLINK上进行测试

MPS:比M1 Pro的CPU快2倍以上,在其他两个芯片上,与CPU相比有30-50%的改进。

MLX:比M1 Pro上的MPS快2.34倍。与MPS相比,M2 Ultra的性能提高了24%。在M3 Pro上MPS和MLX之间没有真正的改进。

CUDA V100 PCIe & NVLINK:只有23%和34%的速度比M3 Max与MLX,这里的原因可能是因为我们的模型比较小,所以发挥不出V100和NVLINK的优势(NVLINK主要GPU之间的数据传输大的情况下会有提高)。这也说明了苹果的统一内存架构的确可以消除CPU和GPU之间缓慢的数据传输。

总结

与CPU和MPS相比,MLX可以说是非常大的进步,在小数据量的情况下它甚至接近特斯拉V100的性能。也就是说我们可以使用MLX跑一些不是那么大的模型,比如一些表格数据。

从上面的基准测试也可以看到,现在可以利用苹果芯片的全部力量在本地运行深度学习模型(我一直认为MPS还没发挥苹果的优势,这回MPS已经证明了这一点)。

MLX刚刚发布就已经取得了惊人的影响力,并展示了巨大的潜力。相信未来几年开源社区的进一步增强,可以期待在不久的将来更强大的苹果芯片,将MLX的性能提升到一个全新的水平。

另外也说明了MPS(虽然也发布不久)还是有巨大的发展空间的,毕竟切换框架是一件很麻烦的事情,如果MPS能达到MLX 80%或者90%的速度,我想不会有人去换框架的。

最后说到框架,现在已经有了Pytorch,TF,JAX,现在又多了一个MLX。各种设备、各种后端包括:TPU(pytorch使用的XLA),CUDA,ROCM,现在又多了一个MPS。

https://avoid.overfit.cn/post/eb87d12f29eb4665adb43ad59fd3d64f

目录
相关文章
|
8天前
|
数据可视化 数据管理 测试技术
聊聊自动化测试框架
关于自动化测试框架的一些理解和思考总结,就是上面这些内容,提到的一些框架组件可能存在不合理的地方,仅供参考,如有更好的建议,请指出,不胜感激
20 4
聊聊自动化测试框架
|
3天前
|
敏捷开发 IDE 测试技术
自动化测试框架的选择与应用
【9月更文挑战第16天】在软件开发周期中,测试环节扮演着至关重要的角色。随着敏捷开发和持续集成的流行,自动化测试成为提升软件质量和效率的关键手段。本文将探讨如何根据项目需求选择合适的自动化测试框架,并通过实际案例分析展示其在软件开发过程中的应用。我们将从单元测试、集成测试到端到端测试等多个层面,讨论自动化测试的最佳实践和常见问题解决策略。
|
4天前
|
机器学习/深度学习 人工智能 测试技术
自动化测试的未来:AI与机器学习的融合之路
【9月更文挑战第15天】在软件测试领域,自动化一直被视为提高效率和精确度的关键。随着人工智能(AI)和机器学习(ML)技术的不断进步,它们已经开始改变自动化测试的面貌。本文将探讨AI和ML如何赋能自动化测试,提升测试用例的智能生成、优化测试流程,并预测未来趋势。我们将通过实际代码示例来揭示这些技术如何被集成到现有的测试框架中,以及开发人员如何利用它们来提高软件质量。
32 15
|
10天前
|
机器学习/深度学习 Python
训练集、测试集与验证集:机器学习模型评估的基石
在机器学习中,数据集通常被划分为训练集、验证集和测试集,以评估模型性能并调整参数。训练集用于拟合模型,验证集用于调整超参数和防止过拟合,测试集则用于评估最终模型性能。本文详细介绍了这三个集合的作用,并通过代码示例展示了如何进行数据集的划分。合理的划分有助于提升模型的泛化能力。
|
14天前
|
存储 Java 关系型数据库
“代码界的魔法师:揭秘Micronaut框架下如何用测试驱动开发将简单图书管理系统变成性能怪兽!
【9月更文挑战第6天】Micronaut框架凭借其轻量级和高性能特性,在Java应用开发中备受青睐。本文通过一个图书管理系统的案例,介绍了在Micronaut下从单元测试到集成测试的全流程。首先,我们使用`@MicronautTest`注解编写了一个简单的`BookService`单元测试,验证添加图书功能;接着,通过集成测试验证了`BookService`与数据库的交互。整个过程展示了Micronaut强大的依赖注入和测试支持,使测试编写变得更加高效和简单。
32 4
|
17天前
|
IDE 测试技术 持续交付
Python自动化测试与单元测试框架:提升代码质量与效率
【9月更文挑战第3天】随着软件行业的迅速发展,代码质量和开发效率变得至关重要。本文探讨了Python在自动化及单元测试中的应用,介绍了Selenium、Appium、pytest等自动化测试框架,以及Python标准库中的unittest单元测试框架。通过详细阐述各框架的特点与使用方法,本文旨在帮助开发者掌握编写高效测试用例的技巧,提升代码质量与开发效率。同时,文章还提出了制定测试计划、持续集成与测试等实践建议,助力项目成功。
43 5
|
19天前
|
测试技术 C# 图形学
掌握Unity调试与测试的终极指南:从内置调试工具到自动化测试框架,全方位保障游戏品质不踩坑,打造流畅游戏体验的必备技能大揭秘!
【9月更文挑战第1天】在开发游戏时,Unity 引擎让创意变为现实。但软件开发中难免遇到 Bug,若不解决,将严重影响用户体验。调试与测试成为确保游戏质量的最后一道防线。本文介绍如何利用 Unity 的调试工具高效排查问题,并通过 Profiler 分析性能瓶颈。此外,Unity Test Framework 支持自动化测试,提高开发效率。结合单元测试与集成测试,确保游戏逻辑正确无误。对于在线游戏,还需进行压力测试以验证服务器稳定性。总之,调试与测试贯穿游戏开发全流程,确保最终作品既好玩又稳定。
40 4
|
20天前
|
测试技术 API 开发者
Python 魔法:打造你的第一个天气查询小工具自动化测试框架的构建与实践
【8月更文挑战第31天】在这篇文章中,我们将一起踏上编程的奇妙旅程。想象一下,只需几行代码,就能让计算机告诉你明天是否要带伞。是的,你没有听错,我们将用Python这把钥匙,解锁天气预报的秘密。不论你是编程新手还是想拓展技能的老手,这篇文章都会为你带来新的视角和灵感。所以,拿起你的键盘,让我们一起创造属于自己的天气小工具吧!
|
19天前
|
测试技术 C# 开发者
“代码守护者:详解WPF开发中的单元测试策略与实践——从选择测试框架到编写模拟对象,全方位保障你的应用程序质量”
【8月更文挑战第31天】单元测试是确保软件质量的关键实践,尤其在复杂的WPF应用中更为重要。通过为每个小模块编写独立测试用例,可以验证代码的功能正确性并在早期发现错误。本文将介绍如何在WPF项目中引入单元测试,并通过具体示例演示其实施过程。首先选择合适的测试框架如NUnit或xUnit.net,并利用Moq模拟框架隔离外部依赖。接着,通过一个简单的WPF应用程序示例,展示如何模拟`IUserRepository`接口并验证`MainViewModel`加载用户数据的正确性。这有助于确保代码质量和未来的重构与扩展。
28 0
|
19天前
|
测试技术 Java Spring
Spring 框架中的测试之道:揭秘单元测试与集成测试的双重保障,你的应用真的安全了吗?
【8月更文挑战第31天】本文以问答形式深入探讨了Spring框架中的测试策略,包括单元测试与集成测试的有效编写方法,及其对提升代码质量和可靠性的重要性。通过具体示例,展示了如何使用`@MockBean`、`@SpringBootTest`等注解来进行服务和控制器的测试,同时介绍了Spring Boot提供的测试工具,如`@DataJpaTest`,以简化数据库测试流程。合理运用这些测试策略和工具,将助力开发者构建更为稳健的软件系统。
27 0