Shortcuts

模型并行

DistributedModelParallel 是用于分布式训练的主要API,带有TorchRec优化。

class torchrec.distributed.model_parallel.DistributedModelParallel(module: Module, env: Optional[ShardingEnv] = None, device: Optional[device] = None, plan: Optional[ShardingPlan] = None, sharders: Optional[List[ModuleSharder[Module]]] = None, init_data_parallel: bool = True, init_parameters: bool = True, data_parallel_wrapper: Optional[DataParallelWrapper] = None)

模型并行的入口点。

Parameters:
  • 模块 (nn.Module) – 要包装的模块。

  • env (可选[ShardingEnv]) – 包含进程组的共享环境。

  • device (可选[torch.device]) – 计算设备,默认为cpu。

  • plan (可选[ShardingPlan]) – 分片时使用的计划,默认为 EmbeddingShardingPlanner.collective_plan()

  • sharders (可选[列表[ModuleSharder[nn.Module]]]) – ModuleSharders 可用于分片,默认为 EmbeddingBagCollectionSharder()

  • init_data_parallel (bool) – 数据并行模块可以是惰性的,即它们会延迟参数初始化直到第一次前向传递。传递True以延迟数据并行模块的初始化。先进行第一次前向传递,然后调用DistributedModelParallel.init_data_parallel()。

  • init_parameters (bool) – 为仍在元设备上的模块初始化参数。

  • data_parallel_wrapper (可选[DataParallelWrapper]) – 用于数据并行模块的自定义包装器。

示例:

@torch.no_grad()
def init_weights(m):
    if isinstance(m, nn.Linear):
        m.weight.fill_(1.0)
    elif isinstance(m, EmbeddingBagCollection):
        for param in m.parameters():
            init.kaiming_normal_(param)

m = MyModel(device='meta')
m = DistributedModelParallel(m)
m.apply(init_weights)
copy(device: device) 分布式模型并行

通过调用每个模块定制的复制过程,递归地将子模块复制到新设备,因为某些模块需要使用原始引用(如用于推理的ShardedModule)。

forward(*args, **kwargs) Any

定义每次调用时执行的计算。

应该由所有子类覆盖。

注意

尽管前向传递的配方需要在此函数内定义,但之后应该调用Module实例而不是这个,因为前者负责运行已注册的钩子,而后者则默默地忽略它们。

init_data_parallel() None

请参阅init_data_parallel c-tor参数以了解用法。 多次调用此方法是安全的。

load_state_dict(state_dict: OrderedDict[str, Tensor], prefix: str = '', strict: bool = True) _IncompatibleKeys

state_dict复制参数和缓冲区到这个模块及其子模块。

如果 strictTrue,那么 state_dict 的键必须与此模块的 state_dict() 函数返回的键完全匹配。

警告

如果 assignTrue,则必须在调用 load_state_dict 之后创建优化器,除非 get_swap_module_params_on_conversion()True

Parameters:
  • state_dict (dict) – 一个包含参数和持久缓冲区的字典。

  • strict (bool, optional) – 是否严格强制执行state_dict中的键与此模块的state_dict()函数返回的键匹配。默认值:True

  • assign (bool, optional) – 当 False 时,当前模块中的张量属性会被保留,而当 True 时,状态字典中的张量属性会被保留。唯一的例外是 requires_grad 字段,其默认值为 Default: ``False`

Returns:

  • missing_keys 是一个包含任何预期键的字符串列表

    由该模块提供但在提供的 state_dict 中缺失。

  • unexpected_keys 是一个包含不预期键的字符串列表

    由该模块提供但在提供的 state_dict 中存在。

Return type:

NamedTuple 包含 missing_keysunexpected_keys 字段

注意

如果一个参数或缓冲区被注册为None,并且其对应的键存在于state_dict中,load_state_dict()将会抛出一个RuntimeError

property module: Module

属性直接访问分片模块,该模块不会被DDP、FSDP、DMP或任何其他并行化包装器包装。

named_buffers(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Tensor]]

返回一个模块缓冲区的迭代器,生成缓冲区的名称以及缓冲区本身。

Parameters:
  • prefix (str) – 在所有缓冲区名称前添加的前缀。

  • recurse (bool, optional) – 如果为True,则生成此模块及其所有子模块的缓冲区。否则,仅生成直接属于此模块的缓冲区。默认为True。

  • remove_duplicate (bool, optional) – 是否移除结果中的重复缓冲区。默认为 True。

Yields:

(str, torch.Tensor) – 包含名称和缓冲区的元组

示例:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, buf in self.named_buffers():
>>>     if name in ['running_var']:
>>>         print(buf.size())
named_parameters(prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) Iterator[Tuple[str, Parameter]]

返回一个模块参数的迭代器,生成参数的名称以及参数本身。

Parameters:
  • prefix (str) – 在所有参数名称前添加的前缀。

  • recurse (bool) – 如果为True,则生成此模块的参数和所有子模块的参数。否则,仅生成此模块的直接成员参数。

  • remove_duplicate (bool, 可选) – 是否移除结果中重复的参数。默认为 True。

Yields:

(str, Parameter) – 包含名称和参数的元组

示例:

>>> # xdoctest: +SKIP("undefined vars")
>>> for name, param in self.named_parameters():
>>>     if name in ['bias']:
>>>         print(param.size())
state_dict(destination: Optional[Dict[str, Any]] = None, prefix: str = '', keep_vars: bool = False) Dict[str, Any]

返回一个包含对模块整个状态的引用的字典。

参数和持久缓冲区(例如运行平均值)都包括在内。键是对应的参数和缓冲区名称。设置为None的参数和缓冲区不包括在内。

注意

返回的对象是一个浅拷贝。它包含对模块参数和缓冲区的引用。

警告

目前state_dict()也按顺序接受destinationprefixkeep_vars的位置参数。然而,这种做法已被弃用,未来的版本将强制使用关键字参数。

警告

请避免使用参数destination,因为它不是为最终用户设计的。

Parameters:
  • destination (dict, optional) – 如果提供了,模块的状态将被更新到字典中,并返回相同的对象。否则,将创建并返回一个OrderedDict。默认值:None

  • prefix (str, optional) – 一个前缀,添加到参数和缓冲区的名称中,以构成state_dict中的键。默认值:''

  • keep_vars (bool, optional) – 默认情况下,状态字典中返回的 Tensor 会从自动求导中分离。如果设置为 True,则不会执行分离操作。 默认值:False

Returns:

包含模块完整状态的字典

Return type:

字典

示例:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']