Shortcuts

torch.distributed.optim.zero_redundancy_optimizer 的源代码

# 版权所有 (c) Facebook, Inc. 及其附属公司。保留所有权利。
#
# 此源代码根据在
# LICENSE 文件中找到的 BSD 许可证授权。

r"""零冗余优化器。"""
import collections
import copy
import enum
import inspect
import io
import logging
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union

import torch
import torch.distributed as dist
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.distributed.optim.utils import functional_optim_map
from torch.optim import Optimizer


logger = logging.getLogger(__name__)

__all__ = ["ZeroRedundancyOptimizer"]


# 鸣谢:  classy_vision/generic/distributed_util.py
def _recursive_copy_to_device(
    value: Any,
    non_blocking: bool,
    device: torch.device,
) -> Any:
    r"""
    递归搜索列表、元组、字典并将张量复制到设备(如果可能)。

    非张量值将按原样传递。

    .. 注意: 这些都是副本,因此如果有两个对象引用
    同一个对象,那么在此调用之后,将有两个不同的对象
    在设备上引用。
    """
    if isinstance(value, torch.Tensor):
        return value.to(device, non_blocking=non_blocking)

    if isinstance(value, (list, tuple)):
        values = [
            _recursive_copy_to_device(val, non_blocking=non_blocking, device=device)
            for val in value
        ]
        return values if isinstance(value, list) else tuple(values)

    if isinstance(value, collections.abc.Mapping):
        return {
            key: _recursive_copy_to_device(
                val, non_blocking=non_blocking, device=device
            )
            for key, val in value.items()
        }

    return value


def _is_trainable(param: torch.Tensor) -> bool:
    r"""返回一个参数是否可训练,其中可训练性等同于需要梯度。"""
    return param.requires_grad


def _broadcast_object(
    obj: Any,
    src_rank: int,
    group: object = dist.group.WORLD,
    device: torch.device = torch.device("cpu"),
) -> Any:
    r"""
    将对象广播到给定的组。

    如果从源排名调用,它将发送对象,否则将接收
    对象。

    参数:
        obj: 要广播的对象;仅在源排名上调用时使用。
        src_rank (int): 源排名。
        group (``ProcessGroup``, 可选): 用于广播的组
            (默认: ``dist.group.WORLD``)。
        device (``torch.device``, 可选): 发送或接收的设备
            (默认: ``torch.device("cpu")``)。

    返回:
        广播的对象。
    """
    if dist.get_rank() == src_rank:
        # 发送对象
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data = bytearray(buffer.getbuffer())
        length_tensor = torch.LongTensor([len(data)]).to(device)
        data_send_tensor = torch.ByteTensor(data).to(device)
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
    else:
        # 接收对象
        length_tensor = torch.LongTensor([0]).to(device)
        dist.broadcast(length_tensor, src=src_rank, group=group,</span