Shortcuts

torch.distributed.tensor.parallel.loss 的源代码

# 版权所有 (c) Meta Platforms, Inc. 及其附属公司
import contextlib
from typing import cast, Dict, Optional, Tuple

import torch
import torch._prims_common as utils
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
from torch import Tensor
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed._tensor.ops.embedding_ops import _MaskPartial
from torch.distributed._tensor.ops.math_ops import (
    _skip_dim,
    Reduction,
    replicate_reduction_dims,
)
from torch.distributed._tensor.placement_types import Placement, TensorMeta
from torch.distributed.device_mesh import DeviceMesh

aten = torch.ops.aten


__all__ = ["loss_parallel"]


[docs]@contextlib.contextmanager def loss_parallel(): """ 一个上下文管理器,启用损失并行计算,当输入在类别维度上被分片时,可以执行高效的并行损失计算。 目前仅支持交叉熵损失。 在此上下文管理器中,可以像往常一样使用 :func:`~torch.nn.functional.cross_entropy` 或 :class:`~torch.nn.CrossEntropyLoss`,对输入参数有以下假设。 相应的 ``backward()`` 调用(如果有)也需要在此上下文管理器下进行。 参数: input (:class:`DTensor`): 输入的 logits。假设在类别维度上被分片。 target (Union[:class:`torch.Tensor`, :class:`DTensor`]): 必须是真实的类别索引(目前不支持类别概率)。 假设在 ``DeviceMesh`` 上被复制。 weight (Union[:class:`torch.Tensor`, :class:`DTensor`], 可选): 如果给出,假设在 ``DeviceMesh`` 上被复制。 label_smoothing: 目前不支持。 返回: 一个被复制的 :class:`DTensor`。 示例: 这里手动创建了一个分片的 DTensor 来展示用法。 在实践中,它通常是一个 TP 模块的输出。 >>> # xdoctest: +SKIP("distributed") >>> from torch.distributed.tensor.parallel import loss_parallel >>> from torch.distributed.device_mesh import init_device_mesh >>> ... >>> device_mesh = init_device_mesh("cuda", (8,)) >>> input = torch.randn(4, 16, device="cuda", requires_grad=True) >>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)]) >>> target = torch.randint(16, (4,), device="cuda") >>> with loss_parallel(): >>> loss = F.cross_entropy(dist_input, target, reduction="mean") >>> loss.backward() >>> ... """ _enable_custom_loss_ops() yield _disable_custom_loss_ops()
# 目前只需要支持一维的 DeviceMesh;通常返回 # 具有 placements[mesh_dim].is_shard(dim) 的 mesh_dim def _find_all_reduce_mesh_dim(placements: Tuple[Placement, ...], dim: int) -> int: if not len(placements) == 1: raise ValueError( "目前 loss_parallel() 仅支持在一维 DeviceMesh 上的输入。" ) if not placements[0].is_shard(dim): raise ValueError( f"loss_parallel() 应该仅在输入张量在维度 {dim} 上被分片时启用。" ) return 0 def _cast_to_dtensor( tensor, placements: Tuple[Placement, ...], mesh: DeviceMesh ) -> DTensor: if isinstance(tensor, DTensor): if tensor.placements == placements: return tensor else: raise RuntimeError(f"期望 {placements} 但得到 {tensor.placements}。") elif isinstance(tensor, torch.Tensor): return DTensor.from_local( tensor, device_mesh=mesh, placements=placements, run_check=False ) else: raise TypeError(f"不支持的类型 {type(tensor)}") def _propagate_tensor_meta( op_call: torch._ops.OpOverload, args: Tuple[object, ...], kwargs: Dict[str, object], ) -> TensorMeta: op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) tensor_meta = DTensor._op_dispatcher.sharding_propagator._propagate_tensor_meta( op_info.schema ) if isinstance(tensor_meta, TensorMeta): return tensor_meta elif isinstance(<span class="n
优云智算