• Docs >
  • Generic Join Context Manager
Shortcuts

通用连接上下文管理器

通用的join上下文管理器有助于在输入不均匀的情况下进行分布式训练。本页概述了相关类的API:JoinJoinableJoinHook。有关教程,请参见 使用Join上下文管理器进行不均匀输入的分布式训练

class torch.distributed.algorithms.Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)[源代码]

此类定义了通用的连接上下文管理器,它允许在进程连接后调用自定义钩子。

这些钩子应该屏蔽非加入进程的集体通信,以防止挂起和错误,并确保算法的正确性。有关钩子定义的详细信息,请参阅 JoinHook

警告

上下文管理器要求每个参与的 Joinable 在每次迭代集体通信之前调用方法 notify_join_context() 以确保正确性。

警告

上下文管理器要求所有 process_group 属性在 JoinHook 对象中是相同的。如果有多个 JoinHook 对象,则使用第一个对象的 device。 进程组和设备信息用于检查未加入的进程,并在启用 throw_on_early_termination 时通知进程抛出异常,这两者都使用全归约(all-reduce)。

Parameters
  • 可连接对象 (列表[可连接对象]) – 参与的 可连接对象 的列表;它们的钩子按照给定的顺序进行迭代。

  • 启用 (布尔值) – 一个启用不均匀输入检测的标志;设置为 False 将禁用上下文管理器的功能,并且仅应在用户知道输入不会不均匀时设置 (默认值:True)。

  • throw_on_early_termination (bool) – 一个控制是否在检测到不均匀输入时抛出异常的标志(默认值:False)。

示例:

>>> import os
>>> import torch
>>> import torch.distributed as dist
>>> import torch.multiprocessing as mp
>>> import torch.nn.parallel.DistributedDataParallel as DDP
>>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO
>>> from torch.distributed.algorithms.join import Join
>>>
>>> # 在每个启动的工作进程上
>>> def worker(rank):
>>>     dist.init_process_group("nccl", rank=rank, world_size=2)
>>>     model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
>>>     optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01)
>>>     # 等级1比等级0多获得一个输入
>>>     inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)]
>>>     with Join([model, optim]):
>>>         for input in inputs:
>>>             loss = model(input).sum()
>>>             loss.backward()
>>>             optim.step()
>>>     # 所有等级都能到达这里而不会挂起/出错
static notify_join_context(joinable)[源代码]

通知加入上下文管理器,调用进程尚未加入。

然后,如果 throw_on_early_termination=True,检查是否检测到不均匀的输入(即如果一个进程已经加入),如果是,则抛出异常。

此方法应在每次迭代的集体通信之前从Joinable对象调用。例如,这应该在DistributedDataParallel的前向传播开始时调用。

只有第一个传入上下文管理器的 Joinable 对象在此方法中执行集体通信,其他对象则不执行任何操作。

Parameters

可连接的 (可连接的) – 调用此方法的 可连接的 对象。

Returns

一个异步工作句柄,用于通知上下文管理器,如果 joinable 是传递给上下文管理器的第一个参数,则表示进程尚未加入;否则为 None

class torch.distributed.algorithms.Joinable[源代码]

这定义了一个可连接类的抽象基类。

一个可连接的类(继承自Joinable)应该实现join_hook(), 该方法返回一个JoinHook实例,此外还应实现 join_device()join_process_group(),分别返回设备和 进程组信息。

abstract property join_device: 设备

返回用于执行连接上下文管理器所需的集体通信的设备。

abstract join_hook(**kwargs)[源代码]

返回给定 JoinableJoinHook 实例。

Parameters

kwargs (字典) – 一个包含任何关键字参数的字典,用于在运行时修改连接钩子的行为;所有共享相同连接上下文管理器的Joinable实例都会接收到相同的kwargs值。

Return type

JoinHook

abstract property join_process_group: Any

返回由连接上下文管理器本身所需的集体通信的进程组。

class torch.distributed.algorithms.JoinHook[源代码]

这定义了一个连接钩子,它在连接上下文管理器中提供了两个入口点。

入口点:一个主钩子,在存在未加入的进程时会重复调用,以及一个后钩子,在所有进程都加入后调用一次。

要为通用连接上下文管理器实现一个连接钩子,定义一个继承自JoinHook的类,并根据需要重写main_hook()post_hook()

main_hook()[源代码]

在训练迭代中存在未加入的进程时调用此钩子以跟踪集体通信。

训练迭代,即在一次前向传播、反向传播和优化器步骤中。

post_hook(is_last_joiner)[源代码]

所有进程加入后调用钩子。

它传递了一个额外的 bool 参数 is_last_joiner,该参数指示排名是否是最后加入的之一。

Parameters

is_last_joiner (bool) – True 如果排名是最后加入的之一;False 否则。