Shortcuts

torch.distributions.transforms 的源代码

import functools
import math
import numbers
import operator
import weakref
from typing import List

import torch
import torch.nn.functional as F
from torch.distributions import constraints
from torch.distributions.utils import (
    _sum_rightmost,
    broadcast_all,
    lazy_property,
    tril_matrix_to_vec,
    vec_to_tril_matrix,
)
from torch.nn.functional import pad, softplus

__all__ = [
    "AbsTransform",
    "AffineTransform",
    "CatTransform",
    "ComposeTransform",
    "CorrCholeskyTransform",
    "CumulativeDistributionTransform",
    "ExpTransform",
    "IndependentTransform",
    "LowerCholeskyTransform",
    "PositiveDefiniteTransform",
    "PowerTransform",
    "ReshapeTransform",
    "SigmoidTransform",
    "SoftplusTransform",
    "TanhTransform",
    "SoftmaxTransform",
    "StackTransform",
    "StickBreakingTransform",
    "Transform",
    "identity_transform",
]


[docs]class Transform: """ 可逆变换的抽象类,具有可计算的对数行列式雅可比矩阵。它们主要用于 :class:`torch.distributions.TransformedDistribution`。 缓存对于逆变换代价高昂或数值不稳定的情况很有用。请注意,必须小心处理缓存值 因为自动求导图可能会反转。例如,以下代码在有无缓存的情况下都可以工作:: y = t(x) t.log_abs_det_jacobian(x, y).backward() # x 将接收梯度。 然而,以下代码在缓存时会因依赖反转而报错:: y = t(x) z = t.inv(y) grad(z.sum(), [y]) # 错误,因为 z 是 x 派生类应实现 :meth:`_call` 或 :meth:`_inverse` 中的一个或两个。设置 `bijective=True` 的派生类还应 实现 :meth:`log_abs_det_jacobian`。 参数: cache_size (int): 缓存大小。如果为零,则不进行缓存。如果为一, 则缓存最新的单个值。仅支持 0 和 1。 属性: domain (:class:`~torch.distributions.constraints.Constraint`): 表示此变换的有效输入的约束。 codomain (:class:`~torch.distributions.constraints.Constraint`): 表示此变换的有效输出的约束,这些输出是逆变换的输入。 bijective (bool): 此变换是否是双射的。变换 ``t`` 是双射的当且仅当 ``t.inv(t(x)) == x`` 和 ``t(t.inv(y)) == y`` 对于域中的每个 ``x`` 和 codomain 中的每个 ``y``。 非双射变换至少应保持较弱的伪逆属性 ``t(t.inv(t(x)) == t(x)`` 和 ``t.inv(t(t.inv(y))) == t.inv(y)``。 sign (int or Tensor): 对于双射的一元变换,这应为 +1 或 -1,取决于变换是单调递增还是递减。 """ bijective = False domain: constraints.Constraint codomain: constraints.Constraint def __init__(self, cache_size=0): self._cache_size = cache_size self._inv = None if cache_size == 0: pass # 默认行为 elif cache_size == 1: self._cached_x_y = None, None else: raise ValueError("cache_size 必须是 0 或 1") super().__init__() def __getstate__(self): state = self.__dict__.copy() state["_inv"] = None return state @property def event_dim(self): if self.domain.event_dim == self.codomain.event_dim: return self.domain.event_dim raise ValueError("请使用 .domain.event_dim 或 .codomain.event_dim") @property def inv(self): """ 返回此变换的逆 :class:`Transform`。 这应满足 ``t.inv.inv is t``。 """ inv = None if self._inv is not None: inv = self._inv() if inv is None: inv = _InverseTransform(self) self._inv = weakref.ref(inv) return inv @property def sign(self): """ 返回雅可比行列式的符号(如果适用)。 通常这仅对双射变换有意义。 """ raise NotImplementedError def with_cache(self, cache_size=1): if self._cache_size == cache_size: return self if type(self).__init__ is Transform.__init__: return type(self)(cache_size=cache_size) raise NotImplementedError(f<span class="
优云智算