分布式通信#

MLX 使用 MPI 来提供分布式通信操作,这些操作允许将训练或推理的计算成本分摊到多台物理机器上。您可以在 API 文档 中查看支持的操作列表。

注意

许多操作可能不受支持或速度不如预期。我们正在添加更多功能并调整现有功能,以找到在Mac上使用MLX进行分布式计算的最佳方法。

入门指南#

如果机器上安装了MPI,MLX已经具备了与MPI“对话”的能力。MLX中最简单的分布式程序如下:

import mlx.core as mx

world = mx.distributed.init()
x = mx.distributed.all_sum(mx.ones(10))
print(world.rank(), x)

上面的程序在所有分布式进程中求和数组 mx.ones(10)。然而,如果仅使用 python 运行,则只会启动一个进程,不会进行任何分布式通信。

要在分布式模式下启动程序,我们需要使用mpirunmpiexec,具体取决于MPI的安装。最简单的方式如下:

$ mpirun -np 2 python test.py
1 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)
0 array([2, 2, 2, ..., 2, 2, 2], dtype=float32)

上述内容在同一台(本地)机器上启动了两个进程,我们可以看到两个标准输出流。这些进程将1的数组发送给对方并计算总和,然后打印出来。使用mpirun -np 4 ...启动会打印4等。

安装MPI#

MPI 可以通过 Homebrew 安装,使用 Anaconda 包管理器或从源代码编译。我们的大部分测试是使用通过 Anaconda 包管理器安装的 openmpi 完成的,如下所示:

$ conda install conda-forge::openmpi

使用Homebrew安装时可能需要指定libmpi.dyld的位置,以便MLX能够在运行时找到并加载它。这可以通过将DYLD_LIBRARY_PATH环境变量传递给mpirun来实现。

$ mpirun -np 2 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python test.py

设置远程主机#

MPI 可以自动连接到远程主机并在网络上建立通信,前提是可以通过 ssh 访问远程主机。以下是一个用于调试连接问题的良好检查清单:

  • ssh hostname 在所有机器之间都可以工作,无需输入密码或确认主机

  • mpirun 在所有机器上都可以访问。您可以使用其完整路径调用 mpirun,以强制所有机器使用特定路径。

  • 确保MPI使用的hostname是你在所有机器的.ssh/config文件中配置的那个。

注意

例如,对于主机名 foo.bar.com,如果当前主机名匹配 *.bar.com,MPI 只能使用 foo 作为传递给 ssh 的主机名。

将主机名传递给MPI的一种简单方法是使用主机文件。主机文件如下所示,其中host1host2应该是这些主机的完全限定域名或IP地址。

host1 slots=1
host2 slots=1

使用MLX时,您很可能希望每个主机使用1个插槽,即每个主机一个进程。如果您想在本地主机上运行,主机文件还需要包含当前主机。将主机文件传递给mpirun只需使用--hostfile命令行参数即可完成。

训练示例#

在本节中,我们将调整MLX训练循环以支持数据并行分布式训练。具体来说,我们将在将梯度应用于模型之前,在一组主机上对梯度进行平均。

如果我们省略模型、数据集和优化器的初始化,我们的训练循环看起来像以下代码片段。

model = ...
optimizer = ...
dataset = ...

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    optimizer.update(model, grads)
    return loss

for x, y in dataset:
    loss = step(model, x, y)
    mx.eval(loss, model.parameters())

我们只需要在所有机器上对梯度进行平均,执行一个all_sum()并除以Group的大小。也就是说,我们必须使用以下函数对梯度进行mlx.utils.tree_map()

def all_avg(x):
    return mx.distributed.all_sum(x) / mx.distributed.init().size()

将所有内容整合在一起,我们的训练循环步骤如下所示,其他所有内容保持不变。

from mlx.utils import tree_map

def all_reduce_grads(grads):
    N = mx.distributed.init().size()
    if N == 1:
        return grads
    return tree_map(
        lambda x: mx.distributed.all_sum(x) / N,
        grads
    )

def step(model, x, y):
    loss, grads = loss_grad_fn(model, x, y)
    grads = all_reduce_grads(grads)  # <--- This line was added
    optimizer.update(model, grads)
    return loss

调优All Reduce#

我们正在努力提高MLX上所有reduce操作的性能,但目前为了充分利用MLX进行分布式训练,可以做的两件主要事情是:

  1. 执行几次大的减少操作,而不是多次小的减少操作,以提高带宽和延迟

  2. 传递 --mca btl_tcp_links 4mpirun 以配置它在每台主机之间使用4个TCP连接来提高带宽