torch.distributed.tensor.parallel.loss 的源代码
# 版权所有 (c) Meta Platforms, Inc. 及其附属公司
import contextlib
from typing import cast, Dict, Optional, Tuple
import torch
import torch._prims_common as utils
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
from torch import Tensor
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed._tensor.ops.embedding_ops import _MaskPartial
from torch.distributed._tensor.ops.math_ops import (
_skip_dim,
Reduction,
replicate_reduction_dims,
)
from torch.distributed._tensor.placement_types import Placement, TensorMeta
from torch.distributed.device_mesh import DeviceMesh
aten = torch.ops.aten
__all__ = ["loss_parallel"]
[docs]@contextlib.contextmanager
def loss_parallel():
"""
一个上下文管理器,启用损失并行计算,当输入在类别维度上被分片时,可以执行高效的并行损失计算。
目前仅支持交叉熵损失。
在此上下文管理器中,可以像往常一样使用 :func:`~torch.nn.functional.cross_entropy` 或
:class:`~torch.nn.CrossEntropyLoss`,对输入参数有以下假设。
相应的 ``backward()`` 调用(如果有)也需要在此上下文管理器下进行。
参数:
input (:class:`DTensor`):
输入的 logits。假设在类别维度上被分片。
target (Union[:class:`torch.Tensor`, :class:`DTensor`]):
必须是真实的类别索引(目前不支持类别概率)。
假设在 ``DeviceMesh`` 上被复制。
weight (Union[:class:`torch.Tensor`, :class:`DTensor`], 可选):
如果给出,假设在 ``DeviceMesh`` 上被复制。
label_smoothing:
目前不支持。
返回:
一个被复制的 :class:`DTensor`。
示例:
这里手动创建了一个分片的 DTensor 来展示用法。
在实践中,它通常是一个 TP 模块的输出。
>>> # xdoctest: +SKIP("distributed")
>>> from torch.distributed.tensor.parallel import loss_parallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> device_mesh = init_device_mesh("cuda", (8,))
>>> input = torch.randn(4, 16, device="cuda", requires_grad=True)
>>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)])
>>> target = torch.randint(16, (4,), device="cuda")
>>> with loss_parallel():
>>> loss = F.cross_entropy(dist_input, target, reduction="mean")
>>> loss.backward()
>>> ...
"""
_enable_custom_loss_ops()
yield
_disable_custom_loss_ops()
# 目前只需要支持一维的 DeviceMesh;通常返回
# 具有 placements[mesh_dim].is_shard(dim) 的 mesh_dim
def _find_all_reduce_mesh_dim(placements: Tuple[Placement, ...], dim: int) -> int:
if not len(placements) == 1:
raise ValueError(
"目前 loss_parallel() 仅支持在一维 DeviceMesh 上的输入。"
)
if not placements[0].is_shard(dim):
raise ValueError(
f"loss_parallel() 应该仅在输入张量在维度 {dim} 上被分片时启用。"
)
return 0
def _cast_to_dtensor(
tensor, placements: Tuple[Placement, ...], mesh: DeviceMesh
) -> DTensor:
if isinstance(tensor, DTensor):
if tensor.placements == placements:
return tensor
else:
raise RuntimeError(f"期望 {placements} 但得到 {tensor.placements}。")
elif isinstance(tensor, torch.Tensor):
return DTensor.from_local(
tensor, device_mesh=mesh, placements=placements, run_check=False
)
else:
raise TypeError(f"不支持的类型 {type(tensor)}")
def _propagate_tensor_meta(
op_call: torch._ops.OpOverload,
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> TensorMeta:
op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
tensor_meta = DTensor._op_dispatcher.sharding_propagator._propagate_tensor_meta(
op_info.schema
)
if isinstance(tensor_meta, TensorMeta):
return tensor_meta
elif isinstance(<span class="n