# Source code for torch.distributions.transformed_distribution
from typing import Dict
import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.independent import Independent
from torch.distributions.transforms import ComposeTransform, Transform
from torch.distributions.utils import _sum_rightmost
__all__ = ["TransformedDistribution"]
class TransformedDistribution(Distribution):
r"""
分布类的扩展,它将一系列变换应用于基础分布。设f为应用的变换的组合::
X ~ BaseDistribution
Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
log p(Y) = log p(X) + log |det (dX/dY)|
请注意,:class:`TransformedDistribution`的``.event_shape``是其基础分布和变换的最大形状,因为变换可以引入事件之间的相关性。
:class:`TransformedDistribution`的使用示例如下::
# 构建一个Logistic分布
# X ~ Uniform(0, 1)
# f = a + b * logit(X)
# Y ~ f(X) ~ Logistic(a, b)
base_distribution = Uniform(0, 1)
transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
logistic = TransformedDistribution(base_distribution, transforms)
有关更多示例,请查看以下分布的实现:
:class:`~torch.distributions.gumbel.Gumbel`,
:class:`~torch.distributions.half_cauchy.HalfCauchy`,
:class:`~torch.distributions.half_normal.HalfNormal`,
:class:`~torch.distributions.log_normal.LogNormal`,
:class:`~torch.distributions.pareto.Pareto`,
:class:`~torch.distributions.weibull.Weibull`,
:class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` 和
:class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
"""
# No directly-constrained parameters of its own; parameter constraints
# live on the wrapped base distribution and the transforms.
arg_constraints: Dict[str, constraints.Constraint] = {}
def __init__(self, base_distribution, transforms, validate_args=None):
if isinstance(transforms, Transform):
self.transforms = [
transforms,
]
elif isinstance(transforms, list):
if not all(isinstance(t, Transform) for t in transforms):
raise ValueError(
"transforms must be a Transform or a list of Transforms"
)
self.transforms = transforms
else:
raise ValueError(
f"transforms must be a Transform or list, but was {transforms}"
)
# 根据变换调整基础分布的形状。
base_shape = base_distribution.batch_shape + base_distribution.event_shape
base_event_dim = len(base_distribution.event_shape)
transform = ComposeTransform(self.transforms)
if len(base_shape) < transform.domain.event_dim:
raise ValueError(
"base_distribution needs to have shape with size at least {}, but got {}.".format(
transform.domain.event_dim, base_shape
)
)
forward_shape = transform.forward_shape(base_shape)
expanded_base_shape = transform.inverse_shape(forward_shape)
if base_shape != expanded_base_shape:
base_batch_shape = expanded_base_shape[
: len(expanded_base_shape) - base_event_dim
]
base_distribution = base_distribution.expand(base_batch_shape)
reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim
if reinterpreted_batch_ndims > 0:
base_distribution = Independent(
base_distribution, reinterpreted_batch_ndims
)
self.base_dist = base_distribution
# 计算形状。
transform_change_in_event_dim = (
transform.codomain.event_dim - transform.domain.event_dim
)
event_dim = max(
transform.codomain.event_dim, # 变换是耦合的
base_event_dim + transform_change_in_event_dim, # 基础分布是耦合的
)
assert len(forward_shape) >= event_dim
cut = len(forward_shape) - event_dim
batch_shape = forward_shape[:cut]
event_shape = forward_shape[cut:]
super().__init__(batch_shape, event_shape, validate_args=validate_args)