Shortcuts

torch.distributions.utils 的源代码

```html
from functools import update_wrapper
from numbers import Number
from typing import Any, Dict

import torch
import torch.nn.functional as F
from torch.overrides import is_tensor_like

euler_constant = 0.57721566490153286060  # 欧拉-马歇罗尼常数

__all__ = [
    "broadcast_all",
    "logits_to_probs",
    "clamp_probs",
    "probs_to_logits",
    "lazy_property",
    "tril_matrix_to_vec",
    "vec_to_tril_matrix",
]


def broadcast_all(*values):
    r"""
    给定一个包含数字的值列表(可能包含数字),返回一个列表,其中每个值根据以下规则进行广播:
      - `torch.*Tensor` 实例按照 :ref:`_broadcasting-semantics` 进行广播。
      - numbers.Number 实例(标量)被上转为与传递给 `values` 的第一个张量具有相同大小和类型的张量。如果所有值都是标量,则它们被上转为标量张量。

    参数:
        values (list of `numbers.Number`, `torch.*Tensor` 或实现 __torch_function__ 的对象)

    抛出:
        ValueError: 如果任何值不是 `numbers.Number` 实例、`torch.*Tensor` 实例或实现 __torch_function__ 的实例
    """
    if not all(is_tensor_like(v) or isinstance(v, Number) for v in values):
        raise ValueError(
            "输入参数必须是 numbers.Number 实例、torch.Tensor 实例或实现 __torch_function__ 的对象。"
        )
    if not all(is_tensor_like(v) for v in values):
        options: Dict[str, Any] = dict(dtype=torch.get_default_dtype())
        for value in values:
            if isinstance(value, torch.Tensor):
                options = dict(dtype=value.dtype, device=value.device)
                break
        new_values = [
            v if is_tensor_like(v) else torch.tensor(v, **options) for v in values
        ]
        return torch.broadcast_tensors(*new_values)
    return torch.broadcast_tensors(*values)


def _standard_normal(shape, dtype, device):
    if torch._C._get_tracing_state():
        # [JIT WORKAROUND] 缺少对 .normal_() 的支持
        return torch.normal(
            torch.zeros(shape, dtype=dtype, device=device),
            torch.ones(shape, dtype=dtype, device=device),
        )
    return torch.empty(shape, dtype=dtype, device=device).normal_()


def _sum_rightmost(value, dim):
    r"""
    对给定张量的最右边的 ``dim`` 个维度进行求和。

    参数:
        value (Tensor): 一个维度至少为 ``dim`` 的张量。
        dim (int): 要求和的最右边的维度数量。
    """
    if dim == 0:
        return value
    required_shape = value.shape[:-dim] + (-1,)
    return value.reshape(required_shape).sum(-1)


def logits_to_probs(logits, is_binary=False):
    r"""
    将 logits 张量转换为概率。请注意,对于二元情况,每个值表示对数几率,而对于多维情况,最后一个维度上的值表示事件的对数概率(可能未归一化)。
    """
    if is_binary:
        return torch.sigmoid(logits)
    return F.softmax(logits, dim=-1)


def clamp_probs(probs):
    eps = torch.finfo(probs.dtype).eps
    return probs.clamp(min=eps, max=1 - eps)


def probs_to_logits(probs, is_binary=False):
    r"""
    将概率张量转换为 logits。对于二元情况,这表示索引为 `1` 的事件发生的概率。对于多维情况,最后一个维度上的值表示每个事件发生的概率。
    """
    ps_clamped = clamp_probs</