torch.distributions.constraints 的源代码
r"""
以下约束已实现:
- ``constraints.boolean``
- ``constraints.cat``
- ``constraints.corr_cholesky``
- ``constraints.dependent``
- ``constraints.greater_than(lower_bound)``
- ``constraints.greater_than_eq(lower_bound)``
- ``constraints.independent(constraint, reinterpreted_batch_ndims)``
- ``constraints.integer_interval(lower_bound, upper_bound)``
- ``constraints.interval(lower_bound, upper_bound)``
- ``constraints.less_than(upper_bound)``
- ``constraints.lower_cholesky``
- ``constraints.lower_triangular``
- ``constraints.multinomial``
- ``constraints.nonnegative``
- ``constraints.nonnegative_integer``
- ``constraints.one_hot``
- ``constraints.positive_integer``
- ``constraints.positive``
- ``constraints.positive_semidefinite``
- ``constraints.positive_definite``
- ``constraints.real_vector``
- ``constraints.real``
- ``constraints.simplex``
- ``constraints.symmetric``
- ``constraints.stack``
- ``constraints.square``
- ``constraints.symmetric``
- ``constraints.unit_interval``
"""
import torch
__all__ = [
"Constraint",
"boolean",
"cat",
"corr_cholesky",
"dependent",
"dependent_property",
"greater_than",
"greater_than_eq",
"independent",
"integer_interval",
"interval",
"half_open_interval",
"is_dependent",
"less_than",
"lower_cholesky",
"lower_triangular",
"multinomial",
"nonnegative",
"nonnegative_integer",
"one_hot",
"positive",
"positive_semidefinite",
"positive_definite",
"positive_integer",
"real",
"real_vector",
"simplex",
"square",
"stack",
"symmetric",
"unit_interval",
]
[docs]class Constraint:
"""
约束的抽象基类。
约束对象表示变量有效的区域,
例如,变量可以在其中进行优化。
属性:
is_discrete (bool): 约束空间是否为离散的。
默认为 False。
event_dim (int): 定义事件的右端维度数量。
:meth:`check` 方法将删除这些维度
以计算有效性。
"""
is_discrete = False # 默认为连续。
event_dim = 0 # 默认为单变量。
[docs] def check(self, value):
"""
返回一个字节张量 ``sample_shape + batch_shape``,指示
值中的每个事件是否满足此约束。
"""
raise NotImplementedError
def __repr__(self):
return self.__class__.__name__[1:] + "()"
class _Dependent(Constraint):
"""
占位符,用于支持依赖于其他变量的变量。
这些变量不遵循简单的坐标约束。
参数:
is_discrete (bool): 可选的 ``.is_discrete`` 值,如果可以静态计算。
如果未提供,访问 ``.is_discrete`` 属性将引发 NotImplementedError。
event_dim (int): 可选的 ``.event_dim`` 值,如果可以静态计算。
如果未提供,访问 ``.event_dim`` 属性将引发 NotImplementedError。
"""
def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
self._is_discrete = is_discrete
self._event_dim = event_dim
super().__init__()
@property
def is_discrete(self):
if self._is_discrete is NotImplemented:
raise NotImplementedError(".is_discrete 无法静态确定")
return self._is_discrete
@property
def event_dim(self):
if self._event_dim is NotImplemented:
raise NotImplementedError(".event_dim 无法静态确定")
return self._event_dim
def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
"""
支持语法以自定义静态属性::
constraints.dependent(is_discrete=True, event_dim=1)
"""
if is_discrete is NotImplemented:
is_discrete = self._is_discrete
if event_dim is NotImplemented:
event_dim = self._event_dim
return _Dependent(is_discrete=is_discrete, event_dim=event_dim)
def check(self, x):
raise ValueError("无法确定依赖约束的有效性")
def is_dependent(constraint):
return isinstance(constraint, _Dependent)
class _DependentProperty(property, _Dependent):
"""
装饰器,扩展 @property 以在类上调用时像 `Dependent` 约束一样,
在对象上调用时像属性一样。
示例::
class Uniform(Distribution):
def __init__(self, low, high):
self.low = low
self.high = high
@constraints.dependent_property(is_discrete=False, event_dim=0)
def support(self):
return constraints.interval(self.low, self