torch.distributions.categorical 的源代码
```html
import torch from torch import nan from torch.distributions import constraints from torch.distributions.distribution import Distribution from torch.distributions.utils import lazy_property, logits_to_probs, probs_to_logits __all__ = ["Categorical"][docs]class Categorical(Distribution): r""" 创建一个由 :attr:`probs` 或 :attr:`logits` 参数化的分类分布(但不能同时使用两者)。 .. 注意:: 它等同于 :func:`torch.multinomial` 从中采样的分布。 样本是从 :math:`\{0, \ldots, K-1\}` 中抽取的整数,其中 `K` 是 ``probs.size(-1)``。 如果 `probs` 是一维的且长度为 `K`,则每个元素是采样该索引类别的相对概率。 如果 `probs` 是 N 维的,则前 N-1 维被视为相对概率向量的批次。 .. 注意:: `probs` 参数必须是非负的、有限的且具有非零和,它将被归一化以沿最后一个维度总和为 1。:attr:`probs` 将返回此归一化值。 `logits` 参数将被解释为未归一化的对数概率,因此可以是任何实数。它同样将被归一化,使得沿最后一个维度的结果概率总和为 1。:attr:`logits` 将返回此归一化值。 另请参阅: :func:`torch.multinomial` 示例:: >>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> m = Categorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ])) >>> m.sample() # 0, 1, 2, 3 的概率相等 tensor(3) 参数: probs (Tensor): 事件概率 logits (Tensor): 事件对数概率(未归一化) """ arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} has_enumerate_support = True def __init__(self, probs=None, logits=None, validate_args=None): if (probs is None) == (logits is None): raise ValueError( "Either `probs` or `logits` must be specified, but not both." ) if probs is not None: if probs.dim() < 1: raise ValueError("`probs` parameter must be at least one-dimensional.") self.probs = probs / probs.sum(-1, keepdim=True) else: if logits.dim() < 1: raise ValueError("`logits` parameter must be at least one-dimensional.") # Normalize self.logits = logits - logits.logsumexp(dim=-1, keepdim=True) self._param = self.probs if probs is not None else self.logits self._num_events = self._param.size()[-1] batch_shape = ( self._param.size()[:-1] if self._param.ndimension() > 1 else torch.Size() ) super().__init__(batch_shape, validate_args=validate_args)[docs] def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(Categorical, _instance) batch_shape = torch.Size(batch_shape) param_shape = batch_shape + torch.Size((self._num_events,)) if "probs" in self.__dict__: new.probs = self.probs.expand(param_shape) new._param = new.probs if "logits" in self.__dict__: new.logits = self.logits.expand(param_shape) new._param = new.logits new._num_events = self._num_events super(Categorical, new).__init__(batch_shape, validate_args=False) new._validate_args = self._validate_args return newdef _new(self, *args, **kwargs</span