DataParallel¶
- class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)[源代码]¶
在模块级别实现数据并行性。
此容器通过在指定的设备上分割输入来并行化给定的
模块
的应用,方法是按批次维度进行分块(其他对象将在每个设备上复制一次)。在前向传递中,模块在每个设备上复制,每个副本处理输入的一部分。在反向传递期间,来自每个副本的梯度被汇总到原始模块中。批量大小应大于使用的GPU数量。
警告
建议使用
DistributedDataParallel
, 而不是这个类来进行多GPU训练,即使只有一个节点。参见:使用 nn.parallel.DistributedDataParallel 代替 multiprocessing 或 nn.DataParallel 和 分布式数据并行。允许传递任意位置和关键字输入到 DataParallel,但某些类型会特别处理。张量将在指定的维度上(默认0)分散。元组、列表和字典类型将被浅拷贝。其他类型将在不同线程之间共享,如果在模型的前向传播中写入,可能会被破坏。
并行化的
module
必须在运行此DataParallel
模块之前将其参数和缓冲区放在device_ids[0]
上。警告
在每次前向传播中,
module
会在每个设备上被复制,因此在forward
中对运行中的模块的任何更新都将丢失。例如,如果module
有一个在每次forward
中递增的计数器属性,它将始终保持在初始值,因为更新是在副本上进行的,这些副本在forward
之后会被销毁。然而,DataParallel
保证device[0]
上的副本将与基础并行化的module
共享存储参数和缓冲区。因此,在device[0]
上的参数或缓冲区的就地更新将被记录。例如,BatchNorm2d
和spectral_norm()
依赖于此行为来更新缓冲区。警告
在
module
及其子模块上定义的前向和后向钩子将被调用len(device_ids)
次,每次输入位于特定设备上。特别地,钩子仅保证在与相应设备上的操作顺序一致的情况下执行。例如,不能保证通过register_forward_pre_hook()
设置的钩子在所有len(device_ids)
次forward()
调用之前执行,但可以保证每个这样的钩子在该设备的相应forward()
调用之前执行。警告
当
module
返回一个标量(即,0维张量)时,forward()
,此包装器将返回一个长度等于 数据并行性中使用的设备数量的向量,其中包含来自 每个设备的结果。注意
在使用
pack sequence -> recurrent network -> unpack sequence
模式时,有一个细微之处需要注意,特别是在Module
被DataParallel
包装的情况下。详情请参见FAQ中的我的循环网络在数据并行性下无法工作部分。- Parameters
模块 (模块) – 要并行化的模块
device_ids (列表 的 整数 或 torch.device) – CUDA 设备(默认值:所有设备)
output_device (int 或 torch.device) – 输出设备位置(默认值:device_ids[0])
- Variables
模块 (模块) – 要并行化的模块
示例:
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) >>> output = net(input_var) # input_var 可以是任何设备,包括 CPU