torch.distributions.transforms 的源代码
import functools
import math
import numbers
import operator
import weakref
from typing import List
import torch
import torch.nn.functional as F
from torch.distributions import constraints
from torch.distributions.utils import (
_sum_rightmost,
broadcast_all,
lazy_property,
tril_matrix_to_vec,
vec_to_tril_matrix,
)
from torch.nn.functional import pad, softplus
__all__ = [
"AbsTransform",
"AffineTransform",
"CatTransform",
"ComposeTransform",
"CorrCholeskyTransform",
"CumulativeDistributionTransform",
"ExpTransform",
"IndependentTransform",
"LowerCholeskyTransform",
"PositiveDefiniteTransform",
"PowerTransform",
"ReshapeTransform",
"SigmoidTransform",
"SoftplusTransform",
"TanhTransform",
"SoftmaxTransform",
"StackTransform",
"StickBreakingTransform",
"Transform",
"identity_transform",
]
[docs]class Transform:
"""
可逆变换的抽象类,具有可计算的对数行列式雅可比矩阵。它们主要用于
:class:`torch.distributions.TransformedDistribution`。
缓存对于逆变换代价高昂或数值不稳定的情况很有用。请注意,必须小心处理缓存值
因为自动求导图可能会反转。例如,以下代码在有无缓存的情况下都可以工作::
y = t(x)
t.log_abs_det_jacobian(x, y).backward() # x 将接收梯度。
然而,以下代码在缓存时会因依赖反转而报错::
y = t(x)
z = t.inv(y)
grad(z.sum(), [y]) # 错误,因为 z 是 x
派生类应实现 :meth:`_call` 或 :meth:`_inverse` 中的一个或两个。设置 `bijective=True` 的派生类还应
实现 :meth:`log_abs_det_jacobian`。
参数:
cache_size (int): 缓存大小。如果为零,则不进行缓存。如果为一,
则缓存最新的单个值。仅支持 0 和 1。
属性:
domain (:class:`~torch.distributions.constraints.Constraint`):
表示此变换的有效输入的约束。
codomain (:class:`~torch.distributions.constraints.Constraint`):
表示此变换的有效输出的约束,这些输出是逆变换的输入。
bijective (bool): 此变换是否是双射的。变换 ``t`` 是双射的当且仅当 ``t.inv(t(x)) == x`` 和
``t(t.inv(y)) == y`` 对于域中的每个 ``x`` 和 codomain 中的每个 ``y``。
非双射变换至少应保持较弱的伪逆属性
``t(t.inv(t(x)) == t(x)`` 和 ``t.inv(t(t.inv(y))) == t.inv(y)``。
sign (int or Tensor): 对于双射的一元变换,这应为 +1 或 -1,取决于变换是单调递增还是递减。
"""
bijective = False
domain: constraints.Constraint
codomain: constraints.Constraint
def __init__(self, cache_size=0):
self._cache_size = cache_size
self._inv = None
if cache_size == 0:
pass # 默认行为
elif cache_size == 1:
self._cached_x_y = None, None
else:
raise ValueError("cache_size 必须是 0 或 1")
super().__init__()
def __getstate__(self):
state = self.__dict__.copy()
state["_inv"] = None
return state
@property
def event_dim(self):
if self.domain.event_dim == self.codomain.event_dim:
return self.domain.event_dim
raise ValueError("请使用 .domain.event_dim 或 .codomain.event_dim")
@property
def inv(self):
"""
返回此变换的逆 :class:`Transform`。
这应满足 ``t.inv.inv is t``。
"""
inv = None
if self._inv is not None:
inv = self._inv()
if inv is None:
inv = _InverseTransform(self)
self._inv = weakref.ref(inv)
return inv
@property
def sign(self):
"""
返回雅可比行列式的符号(如果适用)。
通常这仅对双射变换有意义。
"""
raise NotImplementedError
def with_cache(self, cache_size=1):
if self._cache_size == cache_size:
return self
if type(self).__init__ is Transform.__init__:
return type(self)(cache_size=cache_size)
raise NotImplementedError(f<span class="