Shortcuts

torch.distributions.distribution 的源代码

```html
import warnings
from typing import Any, Dict, Optional, Tuple

import torch
from torch.distributions import constraints
from torch.distributions.utils import lazy_property
from torch.types import _size

__all__ = ["Distribution"]


[docs]class Distribution: r""" Distribution 是概率分布的抽象基类。 """ has_rsample = False has_enumerate_support = False _validate_args = __debug__
[docs] @staticmethod def set_default_validate_args(value: bool) -> None: """ 设置是否启用验证。 默认行为模仿 Python 的 ``assert`` 语句:默认情况下启用验证,但如果 Python 在优化模式下运行(通过 ``python -O``),则禁用验证。验证可能会很昂贵,因此一旦模型正常工作,您可能希望禁用它。 参数: value (bool): 是否启用验证。 """ if value not in [True, False]: raise ValueError Distribution._validate_args = value
def __init__( self, batch_shape: torch.Size = torch.Size(), event_shape: torch.Size = torch.Size(), validate_args: Optional[bool] = None, ): self._batch_shape = batch_shape self._event_shape = event_shape if validate_args is not None: self._validate_args = validate_args if self._validate_args: try: arg_constraints = self.arg_constraints except NotImplementedError: arg_constraints = {} warnings.warn( f"{self.__class__} 未定义 `arg_constraints`。 " + "请设置 `arg_constraints = {}` 或在初始化分布时 " + "使用 `validate_args=False` 来关闭验证。" ) for param, constraint in arg_constraints.items(): if constraints.is_dependent(constraint): continue # 跳过无法检查的约束 if param not in self.__dict__ and isinstance( getattr(type(self), param), lazy_property ): continue # 跳过延迟构造的参数 value = getattr(self, param) valid = constraint.check(value) if not valid.all(): raise ValueError( f"预期的参数 {param} " f"({type(value).__name__} 的形状 {tuple(value.shape)}) " f"的分布 {repr(self)} " f"满足约束 {repr(constraint)}, " f"但发现无效值:\n{value}" ) super().__init__()
[docs] def expand(self, batch_shape: torch.Size, _instance= None): """ 返回一个新的分布实例(或由派生类提供的现有实例),其批次维度扩展到 `batch_shape`。此方法调用 :class:`~torch.Tensor.expand` 在分布的参数上。因此,这不会为扩展的分布实例分配新的内存。此外,这不会在 `__init__.py` 中重复任何参数检查或参数广播,当实例首次创建时。 参数: batch_shape (torch.Size): 所需的扩展大小。 _instance: 由需要覆盖 `.expand` 的子类提供的新实例。 返回: 新的分布实例,批次维度扩展到 `batch_size`。 """ raise NotImplementedError
@property def batch_shape(self) -> torch.Size: """ 返回参数批次的形状。 """ return self._batch_shape @property def event_shape(self) -> torch.Size: """ 返回单个样本的形状(不包括批次)。 """ return self._event_shape @property def arg_constraints(self) -> Dict[str, constraints.Constraint]: <span class="
优云智算