通用连接上下文管理器¶
通用的join上下文管理器有助于在输入不均匀的情况下进行分布式训练。本页概述了相关类的API:Join
、
Joinable
和JoinHook
。有关教程,请参见
使用Join上下文管理器进行不均匀输入的分布式训练。
- class torch.distributed.algorithms.Join(joinables, enable=True, throw_on_early_termination=False, **kwargs)[源代码]¶
此类定义了通用的连接上下文管理器,它允许在进程连接后调用自定义钩子。
这些钩子应该屏蔽非加入进程的集体通信,以防止挂起和错误,并确保算法的正确性。有关钩子定义的详细信息,请参阅
JoinHook
。警告
上下文管理器要求每个参与的
Joinable
在每次迭代集体通信之前调用方法notify_join_context()
以确保正确性。警告
上下文管理器要求所有
process_group
属性在JoinHook
对象中是相同的。如果有多个JoinHook
对象,则使用第一个对象的device
。 进程组和设备信息用于检查未加入的进程,并在启用throw_on_early_termination
时通知进程抛出异常,这两者都使用全归约(all-reduce)。- Parameters
示例:
>>> import os >>> import torch >>> import torch.distributed as dist >>> import torch.multiprocessing as mp >>> import torch.nn.parallel.DistributedDataParallel as DDP >>> import torch.distributed.optim.ZeroRedundancyOptimizer as ZeRO >>> from torch.distributed.algorithms.join import Join >>> >>> # 在每个启动的工作进程上 >>> def worker(rank): >>> dist.init_process_group("nccl", rank=rank, world_size=2) >>> model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank]) >>> optim = ZeRO(model.parameters(), torch.optim.Adam, lr=0.01) >>> # 等级1比等级0多获得一个输入 >>> inputs = [torch.tensor([1.]).to(rank) for _ in range(10 + rank)] >>> with Join([model, optim]): >>> for input in inputs: >>> loss = model(input).sum() >>> loss.backward() >>> optim.step() >>> # 所有等级都能到达这里而不会挂起/出错
- class torch.distributed.algorithms.Joinable[源代码]¶
这定义了一个可连接类的抽象基类。
一个可连接的类(继承自
Joinable
)应该实现join_hook()
, 该方法返回一个JoinHook
实例,此外还应实现join_device()
和join_process_group()
,分别返回设备和 进程组信息。