Shortcuts

torch.distributions.constraints 的源代码

r"""
以下约束已实现:

- ``constraints.boolean``
- ``constraints.cat``
- ``constraints.corr_cholesky``
- ``constraints.dependent``
- ``constraints.greater_than(lower_bound)``
- ``constraints.greater_than_eq(lower_bound)``
- ``constraints.independent(constraint, reinterpreted_batch_ndims)``
- ``constraints.integer_interval(lower_bound, upper_bound)``
- ``constraints.interval(lower_bound, upper_bound)``
- ``constraints.less_than(upper_bound)``
- ``constraints.lower_cholesky``
- ``constraints.lower_triangular``
- ``constraints.multinomial``
- ``constraints.nonnegative``
- ``constraints.nonnegative_integer``
- ``constraints.one_hot``
- ``constraints.positive_integer``
- ``constraints.positive``
- ``constraints.positive_semidefinite``
- ``constraints.positive_definite``
- ``constraints.real_vector``
- ``constraints.real``
- ``constraints.simplex``
- ``constraints.symmetric``
- ``constraints.stack``
- ``constraints.square``
- ``constraints.symmetric``
- ``constraints.unit_interval``
"""

import torch

__all__ = [
    "Constraint",
    "boolean",
    "cat",
    "corr_cholesky",
    "dependent",
    "dependent_property",
    "greater_than",
    "greater_than_eq",
    "independent",
    "integer_interval",
    "interval",
    "half_open_interval",
    "is_dependent",
    "less_than",
    "lower_cholesky",
    "lower_triangular",
    "multinomial",
    "nonnegative",
    "nonnegative_integer",
    "one_hot",
    "positive",
    "positive_semidefinite",
    "positive_definite",
    "positive_integer",
    "real",
    "real_vector",
    "simplex",
    "square",
    "stack",
    "symmetric",
    "unit_interval",
]


[docs]class Constraint: """ 约束的抽象基类。 约束对象表示变量有效的区域, 例如,变量可以在其中进行优化。 属性: is_discrete (bool): 约束空间是否为离散的。 默认为 False。 event_dim (int): 定义事件的右端维度数量。 :meth:`check` 方法将删除这些维度 以计算有效性。 """ is_discrete = False # 默认为连续。 event_dim = 0 # 默认为单变量。
[docs] def check(self, value): """ 返回一个字节张量 ``sample_shape + batch_shape``,指示 值中的每个事件是否满足此约束。 """ raise NotImplementedError
def __repr__(self): return self.__class__.__name__[1:] + "()"
class _Dependent(Constraint): """ 占位符,用于支持依赖于其他变量的变量。 这些变量不遵循简单的坐标约束。 参数: is_discrete (bool): 可选的 ``.is_discrete`` 值,如果可以静态计算。 如果未提供,访问 ``.is_discrete`` 属性将引发 NotImplementedError。 event_dim (int): 可选的 ``.event_dim`` 值,如果可以静态计算。 如果未提供,访问 ``.event_dim`` 属性将引发 NotImplementedError。 """ def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented): self._is_discrete = is_discrete self._event_dim = event_dim super().__init__() @property def is_discrete(self): if self._is_discrete is NotImplemented: raise NotImplementedError(".is_discrete 无法静态确定") return self._is_discrete @property def event_dim(self): if self._event_dim is NotImplemented: raise NotImplementedError(".event_dim 无法静态确定") return self._event_dim def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented): """ 支持语法以自定义静态属性:: constraints.dependent(is_discrete=True, event_dim=1) """ if is_discrete is NotImplemented: is_discrete = self._is_discrete if event_dim is NotImplemented: event_dim = self._event_dim return _Dependent(is_discrete=is_discrete, event_dim=event_dim) def check(self, x): raise ValueError("无法确定依赖约束的有效性") def is_dependent(constraint): return isinstance(constraint, _Dependent) class _DependentProperty(property, _Dependent): """ 装饰器,扩展 @property 以在类上调用时像 `Dependent` 约束一样, 在对象上调用时像属性一样。 示例:: class Uniform(Distribution): def __init__(self, low, high): self.low = low self.high = high @constraints.dependent_property(is_discrete=False, event_dim=0) def support(self): return constraints.interval(self.low, self
优云智算