
Source code for torch.distributed.algorithms.ddp_comm_hooks.default_hooks

from typing import Any, Callable, cast, Tuple

import torch
import torch.distributed as dist

__all__ = [
    "allreduce_hook",
    "fp16_compress_hook",
    "bf16_compress_hook",
    "fp16_compress_wrapper",
    "bf16_compress_wrapper",
]


def _allreduce_fut(
    process_group: dist.ProcessGroup, tensor: torch.Tensor
) -> torch.futures.Future[torch.Tensor]:
    """通过allreduce平均输入梯度张量并返回一个future。"""
    group_to_use = process_group if process_group is not None else dist.group.WORLD

    # Apply the division first to avoid overflow, especially for FP16.
    tensor.div_(group_to_use.size())

    return (
        dist.all_reduce(tensor, group=group_to_use, async_op=True)
        .get_future()
        .then(lambda fut: fut.value()[0])
    )
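

# --- Illustrative sketch (not part of this module) -------------------------
# The synchronous equivalent of what _allreduce_fut builds asynchronously,
# assuming a process group has already been initialized via
# dist.init_process_group; the function name `averaged_allreduce` is
# hypothetical. It shows why the division happens before the allreduce:
# pre-divided values keep the sum in range, which matters most for FP16.
def averaged_allreduce(tensor: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    tensor.div_(world_size)                        # pre-divide to avoid overflow
    work = dist.all_reduce(tensor, async_op=True)  # returns a Work handle
    fut = work.get_future()                        # Future wrapping [tensor]
    return fut.then(lambda f: f.value()[0]).wait()
# ----------------------------------------------------------------------------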


def allreduce_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    Call ``allreduce`` using the ``GradBucket`` tensors.

    Once the gradient tensors are aggregated across all workers, its ``then``
    callback takes the mean and returns the result.

    If a user registers this DDP communication hook, the DDP results are
    expected to be the same as the case where no hook was registered.
    Hence, this will not change the behavior of DDP, and a user can use it
    as a reference, or modify this hook to log useful information or for
    other purposes, without affecting DDP behavior.

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, allreduce_hook)
    """
    return _allreduce_fut(process_group, bucket.buffer())
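

# --- Illustrative sketch (not part of this module) -------------------------
# The docstring's one-liner assumes an existing ``ddp_model``. The hypothetical
# build_ddp_model helper below shows where that registration typically sits,
# assuming dist.init_process_group was already called for this rank.
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def build_ddp_model(rank: int) -> DDP:
    model = nn.Linear(16, 16).to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    # Passing None as the hook state selects the default (WORLD) process group,
    # matching the fallback in _allreduce_fut above.
    ddp_model.register_comm_hook(None, allreduce_hook)
    return ddp_model
# ----------------------------------------------------------------------------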
def fp16_compress_hook(
    process_group: dist.ProcessGroup,
    bucket: dist.GradBucket,
) -> torch.futures.Future[torch.Tensor]:
    """
    Compress by casting the ``GradBucket`` to ``torch.float16`` and dividing by the process group size.

    This DDP communication hook implements a simple gradient compression
    approach that casts the ``GradBucket`` tensor to half-precision
    floating-point format (``torch.float16``) and then divides it by the
    process group size.
    It allreduces those ``float16`` gradient tensors. Once the compressed
    gradient tensors are allreduced, the chained callback ``decompress``
    casts them back to the input data type (such as ``float32``).

    Example::
        >>> # xdoctest: +SKIP
        >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = group_to_use.size()

    buffer = (
        cast(Tuple[torch.Tensor, ...], bucket)[0]
        if isinstance(bucket, tuple)
        else bucket.buffer()
    )
    compressed_tensor = buffer.to(torch.float16).div_(world_size)

    def decompress(fut):
        decompressed_tensor = buffer
        # Decompress in place to reduce the peak memory.
        # See: https://github.com/pytorch/pytorch/issues/45968
        value = fut if isinstance(fut, torch.Tensor) else fut.value()[0]
        decompressed_tensor.copy_(value)
        return decompressed_tensor

    if torch._utils.is_compiling():
        grad = dist._functional_collectives.all_reduce(
            compressed_tensor, "sum", group_to_use
        )
        return decompress(grad)
    else:
        fut = dist.all_reduce(
            compressed_tensor, group=group_to_use, async_op=True
        ).get_future()
        return fut.then(decompress)
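

# --- Illustrative sketch (not part of this module) -------------------------
# A standalone round trip through the same cast/divide/copy_ steps used by
# fp16_compress_hook, without any collective, to show the precision cost of
# the float16 detour. The fp16_roundtrip name and the world_size value are
# hypothetical; with identical gradients on every rank, the allreduced sum
# would equal compressed * world_size.
def fp16_roundtrip(grad: torch.Tensor, world_size: int = 2) -> torch.Tensor:
    compressed = grad.to(torch.float16).div_(world_size)  # compress + pre-divide
    summed = compressed * world_size                      # stand-in for the allreduce
    decompressed = torch.empty_like(grad)                 # float32 output buffer
    decompressed.copy_(summed)                            # cast back in place
    return decompressed

# Example: the round trip matches the original up to float16 rounding error.
# >>> g = torch.randn(8)
# >>> torch.allclose(g, fp16_roundtrip(g), atol=1e-2)
# True
# ----------------------------------------------------------------------------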
# TODO: Create an internal helper function and extract the duplicated code in FP16_compress and BF16_compress.
def bf16_compress_hook(
    process_group: dist.ProcessGroup