分布式通信#
MLX 使用 MPI 来提供分布式通信操作,这些操作允许将训练或推理的计算成本分摊到多台物理机器上。您可以在 API 文档 中查看支持的操作列表。
注意
许多操作可能不受支持或速度不如预期。我们正在添加更多功能并调整现有功能,以找到在Mac上使用MLX进行分布式计算的最佳方法。
入门指南#
如果机器上安装了MPI,MLX已经具备了与MPI“对话”的能力。MLX中最简单的分布式程序如下:
import mlx.core as mx
world = mx.distributed.init()
x = mx.distributed.all_sum(mx.ones(10))
print(world.rank(), x)
上面的程序在所有分布式进程中求和数组 mx.ones(10)
。然而,如果仅使用 python
运行,则只会启动一个进程,不会进行任何分布式通信。
要在分布式模式下启动程序,我们需要使用mpirun
或mpiexec
,具体取决于MPI的安装。最简单的方式如下:
$ mpirun -np 2 python test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
上述内容在同一台(本地)机器上启动了两个进程,我们可以看到两个标准输出流。这些进程将1的数组发送给对方并计算总和,然后打印出来。使用mpirun -np 4 ...
启动会打印4等。
安装MPI#
MPI 可以通过 Homebrew 安装,使用 Anaconda 包管理器或从源代码编译。我们的大部分测试是使用通过 Anaconda 包管理器安装的 openmpi
完成的,如下所示:
$ conda install conda-forge::openmpi
使用Homebrew安装时可能需要指定libmpi.dyld
的位置,以便MLX能够在运行时找到并加载它。这可以通过将DYLD_LIBRARY_PATH
环境变量传递给mpirun
来实现。
$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py
设置远程主机#
MPI 可以自动连接到远程主机并在网络上建立通信,前提是可以通过 ssh 访问远程主机。以下是一个用于调试连接问题的良好检查清单:
ssh hostname
在所有机器之间都可以工作,无需输入密码或确认主机mpirun
在所有机器上都可以访问。您可以使用其完整路径调用mpirun
,以强制所有机器使用特定路径。确保MPI使用的
hostname
是你在所有机器的.ssh/config
文件中配置的那个。
注意
例如,对于主机名 foo.bar.com
,如果当前主机名匹配 *.bar.com
,MPI 只能使用 foo
作为传递给 ssh 的主机名。
将主机名传递给MPI的一种简单方法是使用主机文件。主机文件如下所示,其中host1
和host2
应该是这些主机的完全限定域名或IP地址。
host1 slots=1
host2 slots=1
使用MLX时,您很可能希望每个主机使用1个插槽,即每个主机一个进程。如果您想在本地主机上运行,主机文件还需要包含当前主机。将主机文件传递给mpirun
只需使用--hostfile
命令行参数即可完成。
训练示例#
在本节中,我们将调整MLX训练循环以支持数据并行分布式训练。具体来说,我们将在将梯度应用于模型之前,在一组主机上对梯度进行平均。
如果我们省略模型、数据集和优化器的初始化,我们的训练循环看起来像以下代码片段。
model = ...
optimizer = ...
dataset = ...
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
optimizer.update(model, grads)
return loss
for x, y in dataset:
loss = step(model, x, y)
mx.eval(loss, model.parameters())
我们只需要在所有机器上对梯度进行平均,执行一个all_sum()
并除以Group
的大小。也就是说,我们必须使用以下函数对梯度进行mlx.utils.tree_map()
。
def all_avg(x):
return mx.distributed.all_sum(x) / mx.distributed.init().size()
将所有内容整合在一起,我们的训练循环步骤如下所示,其他所有内容保持不变。
from mlx.utils import tree_map
def all_reduce_grads(grads):
N = mx.distributed.init().size()
if N == 1:
return grads
return tree_map(
lambda x: mx.distributed.all_sum(x) / N,
grads
)
def step(model, x, y):
loss, grads = loss_grad_fn(model, x, y)
grads = all_reduce_grads(grads) # <--- This line was added
optimizer.update(model, grads)
return loss
调优All Reduce#
我们正在努力提高MLX上所有reduce操作的性能,但目前为了充分利用MLX进行分布式训练,可以做的两件主要事情是:
执行几次大的减少操作,而不是多次小的减少操作,以提高带宽和延迟
传递
--mca btl_tcp_links 4
给mpirun
以配置它在每台主机之间使用4个TCP连接来提高带宽