Source code for torch.distributed.algorithms.ddp_comm_hooks.default_hooks
from typing import Any, Callable, cast, Tuple
import torch
import torch.distributed as dist
__all__ = [
    "allreduce_hook",
    "fp16_compress_hook",
    "bf16_compress_hook",
    "fp16_compress_wrapper",
    "bf16_compress_wrapper",
]
def _allreduce_fut(
    process_group: dist.ProcessGroup, tensor: torch.Tensor
) -> torch.futures.Future[torch.Tensor]:
    """Average the input gradient tensor by allreduce and return a future."""
    group_to_use = process_group if process_group is not None else dist.group.WORLD

    # Apply the division first to avoid overflow, especially for FP16.
    tensor.div_(group_to_use.size())

    return (
        dist.all_reduce(tensor, group=group_to_use, async_op=True)
        .get_future()
        .then(lambda fut: fut.value()[0])
    )
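# A minimal usage sketch of the helper above, assuming the default process group is
# already initialized (e.g. via dist.init_process_group); the function name
# _example_average_gradient is hypothetical and not part of this module.
def _example_average_gradient(tensor: torch.Tensor) -> torch.Tensor:
    # _allreduce_fut pre-divides by the world size, launches an async allreduce,
    # and returns a Future that resolves to the averaged tensor.
    fut = _allreduce_fut(dist.group.WORLD, tensor)
    # wait() blocks until the collective finishes and returns the Future's value.
    return fut.wait()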
def allreduce_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    Call ``allreduce`` using ``GradBucket`` tensors.

    Once gradient tensors are aggregated across all workers, its ``then``
    callback takes the mean and returns the result.

    If a user registers this DDP communication hook, the DDP results are
    expected to be the same as the case where no hook was registered. Hence,
    this won't change the behavior of DDP, and a user can use it as a
    reference or modify this hook to log useful information or for other
    purposes without affecting DDP behavior.

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, allreduce_hook)
    """
    return _allreduce_fut(process_group, bucket.buffer())
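# Registration sketch for the hook above, assuming ddp_model is a
# torch.nn.parallel.DistributedDataParallel instance created after
# dist.init_process_group; the wrapper name _example_register_allreduce_hook is
# hypothetical. Passing None as the state falls back to dist.group.WORLD inside
# _allreduce_fut.
def _example_register_allreduce_hook(ddp_model) -> None:
    # With this hook registered, gradients should match default DDP allreduce.
    ddp_model.register_comm_hook(None, allreduce_hook)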
def fp16_compress_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket,
) -> torch.futures.Future[torch.Tensor]:
    """
    Compress by casting ``GradBucket`` to ``torch.float16`` divided by process group size.

    This DDP communication hook implements a simple gradient compression
    approach that casts the ``GradBucket`` tensor to half-precision
    floating-point format (``torch.float16``) and then divides it by the
    process group size. It allreduces those ``float16`` gradient tensors. Once
    the compressed gradient tensors are allreduced, the chained callback
    ``decompress`` casts them back to the input data type (such as ``float32``).

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = group_to_use.size()

    buffer = (
        cast(Tuple[torch.Tensor, ...], bucket)[0]
        if isinstance(bucket, tuple)
        else bucket.buffer()
    )
    compressed_tensor = buffer.to(torch.float16).div_(world_size)

    def decompress(fut):
        decompressed_tensor = buffer
        # Decompress in place to reduce the peak memory.
        # See: https://github.com/pytorch/pytorch/issues/45968
        value = fut if isinstance(fut, torch.Tensor) else fut.value()[0]
        decompressed_tensor.copy_(value)
        return decompressed_tensor

    if torch._utils.is_compiling():
        grad = dist._functional_collectives.all_reduce(
            compressed_tensor, "sum", group_to_use
        )
        return decompress(grad)
    else:
        fut = dist.all_reduce(
            compressed_tensor, group=group_to_use, async_op=True
        ).get_future()
        return fut.then(decompress)
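# Local round-trip sketch of the compression performed by fp16_compress_hook
# (no collective call here; world_size stands in for the process group size and
# the function name _example_fp16_round_trip is hypothetical). Casting to
# float16 halves the bytes sent per element; decompress copies the reduced
# result back into a buffer of the original dtype.
def _example_fp16_round_trip(grad: torch.Tensor, world_size: int = 2) -> torch.Tensor:
    compressed = grad.to(torch.float16).div_(world_size)  # cast + pre-divide
    # ...an allreduce over `compressed` would run across ranks here...
    restored = torch.empty_like(grad)   # same dtype as the input gradient
    restored.copy_(compressed)          # copy_ casts float16 back to grad.dtype
    return restored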
# TODO: Create an internal helper function and extract the duplicate code in FP16_compress and BF16_compress.
def bf16_compress_hook(
    process_group: dist.ProcessGroup