Shortcuts

torch.distributions.constraint_registry 的源代码

r"""
PyTorch 提供了两个全局的 :class:`ConstraintRegistry` 对象,它们将
:class:`~torch.distributions.constraints.Constraint` 对象链接到
:class:`~torch.distributions.transforms.Transform` 对象。这两个对象都
接受约束并返回变换,但它们在双射性上有不同的保证。

1. ``biject_to(constraint)`` 查找一个从 ``constraints.real`` 到给定 ``constraint`` 的双射
   :class:`~torch.distributions.transforms.Transform`。返回的变换保证具有
   ``.bijective = True`` 并且应该实现 ``.log_abs_det_jacobian()``。
2. ``transform_to(constraint)`` 查找一个不一定双射的
   :class:`~torch.distributions.transforms.Transform` 从 ``constraints.real``
   到给定的 ``constraint``。返回的变换不保证实现 ``.log_abs_det_jacobian()``。

``transform_to()`` 注册表对于在概率分布的约束参数上执行无约束优化非常有用,这些参数由每个分布的 ``.arg_constraints`` 字典指示。这些变换通常
过度参数化一个空间以避免旋转;因此,它们更适合像 Adam 这样的坐标优化算法::

    loc = torch.zeros(100, requires_grad=True)
    unconstrained = torch.zeros(100, requires_grad=True)
    scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
    loss = -Normal(loc, scale).log_prob(data).sum()

``biject_to()`` 注册表对于哈密顿蒙特卡罗非常有用,其中来自具有约束 ``.support`` 的概率分布的样本在无约束空间中传播,并且算法通常是旋转
不变的。::

    dist = Exponential(rate)
    unconstrained = torch.zeros(100, requires_grad=True)
    sample = biject_to(dist.support)(unconstrained)
    potential_energy = -dist.log_prob(sample).sum()

.. 注意::

    一个 ``transform_to`` 和 ``biject_to`` 不同的例子是
    ``constraints.simplex``:``transform_to(constraints.simplex)`` 返回一个
    :class:`~torch.distributions.transforms.SoftmaxTransform`,它简单地
    对其输入进行指数化和归一化;这是一个廉价且主要是
    坐标操作,适用于像 SVI 这样的算法。相比之下,``biject_to(constraints.simplex)`` 返回一个
    :class:`~torch.distributions.transforms.StickBreakingTransform`,它
    将其输入双射到一个维度更少的空间;这是一个更
    昂贵且数值上不太稳定的变换,但对于像 HMC 这样的算法是必需的。

``biject_to`` 和 ``transform_to`` 对象可以通过用户定义的
约束和变换使用它们的 ``.register()`` 方法进行扩展,无论是作为
单例约束的函数::

    transform_to.register(my_constraint, my_transform)

或作为参数化约束的装饰器::

    @transform_to.register(MyConstraintClass)
    def my_factory(constraint):
        assert isinstance(constraint, MyConstraintClass)
        return MyTransform(constraint.param1, constraint.param2)

您可以通过创建一个新的 :class:`ConstraintRegistry` 对象来创建自己的注册表。
"""

import numbers

from torch.distributions import constraints, transforms

__all__ = [
    "ConstraintRegistry",
    "biject_to",
    "transform_to",
]


[docs]class ConstraintRegistry: """ 注册表,用于将约束链接到变换。 """ def __init__(self): self._registry = {} super().__init__()
[docs] def register(self, constraint, factory=None): """ 在此注册表中注册一个 :class:`~torch.distributions.constraints.Constraint` 子类。用法:: @my_registry.register(MyConstraintClass) def construct_transform(constraint): assert isinstance(constraint, MyConstraint) return MyTransform(constraint.arg_constraints) 参数: constraint (子类或实例): 一个 :class:`~torch.distributions.constraints.Constraint` 的子类,或 所需类的单例对象。 factory (Callable): 一个可调用对象,输入一个约束对象并返回 :class:`~torch.distributions.transforms.Transform` 对象。 """ # 支持作为装饰器使用。 if factory is None: return lambda factory: self.register(constraint, factory) # 支持对单例实例调用。 if isinstance(constraint, constraints.Constraint): constraint = type(constraint) if not isinstance(constraint, type) or not issubclass( constraint, constraints.Constraint ): raise TypeError( f"Expected constraint to be either a Constraint subclass or instance, but got {constraint}" ) self._registry[constraint] = factory return factory
def __call__(self, constraint): """ 查找给定约束对象的变换到约束空间。用法:: constraint = Normal.arg_constraints['scale'] scale = transform_to(constraint)(torch.zeros(1)) # 约束 u = transform_to(constraint).inv(scale) # 无约束 参数: constraint (:class:`~torch.distributions.constraints.Constraint`): 一个约束对象。 返回: 一个 :class:`~torch.distributions.transforms.Transform` 对象。 引发: `NotImplementedError` 如果没有注册变换。 """ # 按 Constraint 子类查找。 try: factory = self._registry[type(constraint)] except KeyError: raise NotImplementedError( f"Cannot transform {type(constraint).__name__} constraints" ) from None return factory(constraint)
biject_to = ConstraintRegistry() transform_to = ConstraintRegistry() ################################################################################ # 注册表 ################################################################################ @biject_to.register(constraints.real) @transform_to.register(constraints.real) def _transform_to_real(constraint): return transforms.identity_transform @biject_to.register(constraints.independent) def _biject_to_independent(constraint): base_transform = biject_to(constraint.base_constraint) return transforms.IndependentTransform( base_transform<span class="p
优云智算