Shortcuts

torch.nn.parallel.data_parallel 的源代码

```html
import operator
import torch
import warnings
from itertools import chain
from typing import Any, Dict, Generic, List, Optional, Sequence, Tuple, TypeVar, Union
from ..modules import Module
from .scatter_gather import scatter_kwargs, gather
from .replicate import replicate
from .parallel_apply import parallel_apply
from torch._utils import (
    _get_all_device_indices,
    _get_available_device_type,
    _get_device_index,
    _get_devices_properties
)

__all__ = ['DataParallel', 'data_parallel']

def _check_balance(device_ids: Sequence[Union[int, torch.device]]) -> None:
    imbalance_warn = """
    你的GPU之间存在不平衡。你可能想要排除GPU {},它
    的内存或核心数少于GPU {} 的75% of。你可以通过设置
    DataParallel的device_ids参数,或者通过设置CUDA_VISIBLE_DEVICES
    环境变量来实现。"""
    device_ids = [_get_device_index(x, True) for x in device_ids]
    dev_props = _get_devices_properties(device_ids)

    def warn_imbalance(get_prop):
        values = [get_prop(props) for props in dev_props]
        min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
        max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
        if min_val / max_val < 0.75:
            warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))
            return True
        return False

    if warn_imbalance(lambda props: props.total_memory):
        return
    if warn_imbalance(lambda props: props.multi_processor_count):
        return


T = TypeVar("T", bound=Module)


[docs]class DataParallel(Module, Generic[T]): r"""在模块级别实现数据并行。 此容器通过在批次维度上对输入进行分块,将给定的 :attr:`module` 并行化 (其他对象将在每个设备上复制一次)。在前向传递中,模块在每个设备上复制, 并且每个副本处理输入的一部分。在反向传递中,来自每个副本的梯度被汇总到原始模块中。 批量大小应大于使用的GPU数量。 .. 警告:: 建议使用 :class:`~torch.nn.parallel.DistributedDataParallel`, 而不是此类,来进行多GPU训练,即使只有一个节点。参见::ref:`cuda-nn-ddp-instead` 和 :ref:`ddp`。 允许传递任意位置和关键字输入,但某些类型会特别处理。张量将在指定的维度上**分散**(默认0)。 元组、列表和字典类型将被浅拷贝。其他类型将在不同线程之间共享, 如果在模型的前向传递中写入,可能会被破坏。 并行化的 :attr:`module` 必须在运行此 :class:`~torch.nn.DataParallel` 模块之前, 将其参数和缓冲区放在 ``device_ids[0]`` 上。 .. 警告:: 在每次前向中,:attr:`module` 在每个设备上**复制**,因此任何 在 ``forward`` 中对运行模块的更新都将丢失。例如, 如果 :attr:`module` 有一个在每次 ``forward`` 中递增的计数器属性, 它将始终保持在初始值,因为更新是在副本上完成的,副本在 ``forward`` 后被销毁。 然而,:class:`~torch.nn.DataParallel` 保证 ``device[0]`` 上的副本 将与其参数和缓冲区共享存储与基础并行化的 :attr:`module`。因此,**就地**更新 在 ``device[0]`` 上的参数或缓冲区将被记录。例如, :class:`~torch.nn.BatchNorm2d` 和 :func:`~torch.nn.utils.spectral_norm` 依赖于此行为来更新缓冲区。 .. 警告:: 在 :attr:`module` 及其子模块上定义的前向和后向钩子将被调用 ``len(device_ids)`` 次, 每次输入位于特定设备上。特别地,钩子仅保证相对于相应设备上的操作按正确顺序执行。 例如,不能保证通过 :meth:`~torch.nn.Module.register_forward_pre_hook` 设置的钩子 在所有 ``len(device_ids)`` 个 :meth:`~torch.nn.Module.forward` 调用之前执行, 但可以保证每个这样的钩子在相应设备的 :meth:`~torch.nn.Module.forward` 调用之前执行。 .. 警告:: 当 :attr:`module` 在 :func:`forward` 中返回标量(即0维张量)时, 此包装器将返回一个长度等于数据并行中使用的设备数量的向量, 包含每个设备的结果。 .. 注意:: 在使用 :class:`~torch.nn.DataParallel` 包装的 :class:`~torch.nn.Module` 中使用 ``pack sequence -> recurrent network -> unpack sequence`` 模式时, 有一个微妙之处。参见FAQ中的 :ref:`pack-rnn-unpack-with-data-parallelism` 部分以获取详细信息。 参数: module (Module): 要并行化的模块 device_ids (list of int or torch.device): CUDA设备(默认: 所有设备) output_device (int or torch.device): 输出设备位置(默认: device_ids[0]) 属性: module (Module): 要并行化的模块 示例:: >>> # xdoctest: +SKIP >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) >>> output = net(input_var) # input_var 可以是任何设备,包括CPU """ # TODO: 当此类的处理能力达到8个以上GPU时,更新notes/cuda.rst def __init__( self, module: T, device_ids: Optional[Sequence[Union[int, torch.<span class