torch.distributions.constraint_registry 的源代码
r"""
PyTorch 提供了两个全局的 :class:`ConstraintRegistry` 对象,它们将
:class:`~torch.distributions.constraints.Constraint` 对象链接到
:class:`~torch.distributions.transforms.Transform` 对象。这两个对象都
接受约束并返回变换,但它们在双射性上有不同的保证。
1. ``biject_to(constraint)`` 查找一个从 ``constraints.real`` 到给定 ``constraint`` 的双射
:class:`~torch.distributions.transforms.Transform`。返回的变换保证具有
``.bijective = True`` 并且应该实现 ``.log_abs_det_jacobian()``。
2. ``transform_to(constraint)`` 查找一个不一定双射的
:class:`~torch.distributions.transforms.Transform` 从 ``constraints.real``
到给定的 ``constraint``。返回的变换不保证实现 ``.log_abs_det_jacobian()``。
``transform_to()`` 注册表对于在概率分布的约束参数上执行无约束优化非常有用,这些参数由每个分布的 ``.arg_constraints`` 字典指示。这些变换通常
过度参数化一个空间以避免旋转;因此,它们更适合像 Adam 这样的坐标优化算法::
loc = torch.zeros(100, requires_grad=True)
unconstrained = torch.zeros(100, requires_grad=True)
scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
loss = -Normal(loc, scale).log_prob(data).sum()
``biject_to()`` 注册表对于哈密顿蒙特卡罗非常有用,其中来自具有约束 ``.support`` 的概率分布的样本在无约束空间中传播,并且算法通常是旋转
不变的。::
dist = Exponential(rate)
unconstrained = torch.zeros(100, requires_grad=True)
sample = biject_to(dist.support)(unconstrained)
potential_energy = -dist.log_prob(sample).sum()
.. 注意::
一个 ``transform_to`` 和 ``biject_to`` 不同的例子是
``constraints.simplex``:``transform_to(constraints.simplex)`` 返回一个
:class:`~torch.distributions.transforms.SoftmaxTransform`,它简单地
对其输入进行指数化和归一化;这是一个廉价且主要是
坐标操作,适用于像 SVI 这样的算法。相比之下,``biject_to(constraints.simplex)`` 返回一个
:class:`~torch.distributions.transforms.StickBreakingTransform`,它
将其输入双射到一个维度更少的空间;这是一个更
昂贵且数值上不太稳定的变换,但对于像 HMC 这样的算法是必需的。
``biject_to`` 和 ``transform_to`` 对象可以通过用户定义的
约束和变换使用它们的 ``.register()`` 方法进行扩展,无论是作为
单例约束的函数::
transform_to.register(my_constraint, my_transform)
或作为参数化约束的装饰器::
@transform_to.register(MyConstraintClass)
def my_factory(constraint):
assert isinstance(constraint, MyConstraintClass)
return MyTransform(constraint.param1, constraint.param2)
您可以通过创建一个新的 :class:`ConstraintRegistry` 对象来创建自己的注册表。
"""
import numbers
from torch.distributions import constraints, transforms
__all__ = [
"ConstraintRegistry",
"biject_to",
"transform_to",
]
[docs]class ConstraintRegistry:
"""
注册表,用于将约束链接到变换。
"""
def __init__(self):
self._registry = {}
super().__init__()
[docs] def register(self, constraint, factory=None):
"""
在此注册表中注册一个 :class:`~torch.distributions.constraints.Constraint`
子类。用法::
@my_registry.register(MyConstraintClass)
def construct_transform(constraint):
assert isinstance(constraint, MyConstraint)
return MyTransform(constraint.arg_constraints)
参数:
constraint (子类或实例):
一个 :class:`~torch.distributions.constraints.Constraint` 的子类,或
所需类的单例对象。
factory (Callable): 一个可调用对象,输入一个约束对象并返回
:class:`~torch.distributions.transforms.Transform` 对象。
"""
# 支持作为装饰器使用。
if factory is None:
return lambda factory: self.register(constraint, factory)
# 支持对单例实例调用。
if isinstance(constraint, constraints.Constraint):
constraint = type(constraint)
if not isinstance(constraint, type) or not issubclass(
constraint, constraints.Constraint
):
raise TypeError(
f"Expected constraint to be either a Constraint subclass or instance, but got {constraint}"
)
self._registry[constraint] = factory
return factory
def __call__(self, constraint):
"""
查找给定约束对象的变换到约束空间。用法::
constraint = Normal.arg_constraints['scale']
scale = transform_to(constraint)(torch.zeros(1)) # 约束
u = transform_to(constraint).inv(scale) # 无约束
参数:
constraint (:class:`~torch.distributions.constraints.Constraint`):
一个约束对象。
返回:
一个 :class:`~torch.distributions.transforms.Transform` 对象。
引发:
`NotImplementedError` 如果没有注册变换。
"""
# 按 Constraint 子类查找。
try:
factory = self._registry[type(constraint)]
except KeyError:
raise NotImplementedError(
f"Cannot transform {type(constraint).__name__} constraints"
) from None
return factory(constraint)
biject_to = ConstraintRegistry()
transform_to = ConstraintRegistry()
################################################################################
# 注册表
################################################################################
@biject_to.register(constraints.real)
@transform_to.register(constraints.real)
def _transform_to_real(constraint):
return transforms.identity_transform
@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
base_transform = biject_to(constraint.base_constraint)
return transforms.IndependentTransform(
base_transform<span class="p