torch.distributions.utils 的源代码
```html
from functools import update_wrapper from numbers import Number from typing import Any, Dict import torch import torch.nn.functional as F from torch.overrides import is_tensor_like euler_constant = 0.57721566490153286060 # 欧拉-马歇罗尼常数 __all__ = [ "broadcast_all", "logits_to_probs", "clamp_probs", "probs_to_logits", "lazy_property", "tril_matrix_to_vec", "vec_to_tril_matrix", ] def broadcast_all(*values): r""" 给定一个包含数字的值列表(可能包含数字),返回一个列表,其中每个值根据以下规则进行广播: - `torch.*Tensor` 实例按照 :ref:`_broadcasting-semantics` 进行广播。 - numbers.Number 实例(标量)被上转为与传递给 `values` 的第一个张量具有相同大小和类型的张量。如果所有值都是标量,则它们被上转为标量张量。 参数: values (list of `numbers.Number`, `torch.*Tensor` 或实现 __torch_function__ 的对象) 抛出: ValueError: 如果任何值不是 `numbers.Number` 实例、`torch.*Tensor` 实例或实现 __torch_function__ 的实例 """ if not all(is_tensor_like(v) or isinstance(v, Number) for v in values): raise ValueError( "输入参数必须是 numbers.Number 实例、torch.Tensor 实例或实现 __torch_function__ 的对象。" ) if not all(is_tensor_like(v) for v in values): options: Dict[str, Any] = dict(dtype=torch.get_default_dtype()) for value in values: if isinstance(value, torch.Tensor): options = dict(dtype=value.dtype, device=value.device) break new_values = [ v if is_tensor_like(v) else torch.tensor(v, **options) for v in values ] return torch.broadcast_tensors(*new_values) return torch.broadcast_tensors(*values) def _standard_normal(shape, dtype, device): if torch._C._get_tracing_state(): # [JIT WORKAROUND] 缺少对 .normal_() 的支持 return torch.normal( torch.zeros(shape, dtype=dtype, device=device), torch.ones(shape, dtype=dtype, device=device), ) return torch.empty(shape, dtype=dtype, device=device).normal_() def _sum_rightmost(value, dim): r""" 对给定张量的最右边的 ``dim`` 个维度进行求和。 参数: value (Tensor): 一个维度至少为 ``dim`` 的张量。 dim (int): 要求和的最右边的维度数量。 """ if dim == 0: return value required_shape = value.shape[:-dim] + (-1,) return value.reshape(required_shape).sum(-1) def logits_to_probs(logits, is_binary=False): r""" 将 logits 张量转换为概率。请注意,对于二元情况,每个值表示对数几率,而对于多维情况,最后一个维度上的值表示事件的对数概率(可能未归一化)。 """ if is_binary: return torch.sigmoid(logits) return F.softmax(logits, dim=-1) def clamp_probs(probs): eps = torch.finfo(probs.dtype).eps return probs.clamp(min=eps, max=1 - eps) def probs_to_logits(probs, is_binary=False): r""" 将概率张量转换为 logits。对于二元情况,这表示索引为 `1` 的事件发生的概率。对于多维情况,最后一个维度上的值表示每个事件发生的概率。 """ ps_clamped = clamp_probs</