Shortcuts

torch.distributions.transformed_distribution 的源代码

from typing import Dict

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.independent import Independent
from torch.distributions.transforms import ComposeTransform, Transform
from torch.distributions.utils import _sum_rightmost

# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ["TransformedDistribution"]


class TransformedDistribution(Distribution):
    r"""
    Extension of the Distribution class, which applies a sequence of Transforms
    to a base distribution.  Let f be the composition of transforms applied::

        X ~ BaseDistribution
        Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
        log p(Y) = log p(X) + log |det (dX/dY)|

    Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the
    maximum shape of its base distribution and its transforms, since transforms
    can introduce correlations among events.

    An example for the usage of :class:`TransformedDistribution` would be::

        # Building a Logistic Distribution
        # X ~ Uniform(0, 1)
        # f = a + b * logit(X)
        # Y ~ f(X) ~ Logistic(a, b)
        base_distribution = Uniform(0, 1)
        transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
        logistic = TransformedDistribution(base_distribution, transforms)

    For more examples, please look at the implementations of
    :class:`~torch.distributions.gumbel.Gumbel`,
    :class:`~torch.distributions.half_cauchy.HalfCauchy`,
    :class:`~torch.distributions.half_normal.HalfNormal`,
    :class:`~torch.distributions.log_normal.LogNormal`,
    :class:`~torch.distributions.pareto.Pareto`,
    :class:`~torch.distributions.weibull.Weibull`,
    :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
    :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`

    Args:
        base_distribution (Distribution): the distribution whose samples are
            transformed. It may be expanded, and/or wrapped in
            :class:`~torch.distributions.independent.Independent`, so that its
            shape matches what the transforms require.
        transforms (Transform or list/tuple of Transform): a single transform
            or a sequence of transforms, composed in the given order via
            :class:`~torch.distributions.transforms.ComposeTransform`. A
            sequence is copied into a new list, so later changes to the
            caller's sequence do not affect this distribution.
        validate_args (bool, optional): forwarded to :class:`Distribution`.

    Raises:
        ValueError: if ``transforms`` is not a Transform or a list/tuple of
            Transforms, or if the base distribution has fewer dimensions than
            the composed transform's domain ``event_dim``.
    """

    arg_constraints: Dict[str, constraints.Constraint] = {}

    def __init__(self, base_distribution, transforms, validate_args=None):
        if isinstance(transforms, Transform):
            self.transforms = [transforms]
        elif isinstance(transforms, (list, tuple)):
            if not all(isinstance(t, Transform) for t in transforms):
                raise ValueError(
                    "transforms must be a Transform or a list/tuple of Transforms"
                )
            # Copy: the shapes computed below are derived from these exact
            # transforms, so the caller's sequence must not alias our state.
            self.transforms = list(transforms)
        else:
            raise ValueError(
                "transforms must be a Transform or a list/tuple of Transforms, "
                f"but was {transforms}"
            )

        # Reshape base_distribution according to the transforms.
        base_shape = base_distribution.batch_shape + base_distribution.event_shape
        base_event_dim = len(base_distribution.event_shape)
        transform = ComposeTransform(self.transforms)
        if len(base_shape) < transform.domain.event_dim:
            raise ValueError(
                "base_distribution needs to have shape with size at least "
                f"{transform.domain.event_dim}, but got {base_shape}."
            )
        forward_shape = transform.forward_shape(base_shape)
        # Round-tripping the shape tells us whether the transforms broadcast
        # the base (e.g. tensor-valued parameters); if so, expand the base.
        expanded_base_shape = transform.inverse_shape(forward_shape)
        if base_shape != expanded_base_shape:
            base_batch_shape = expanded_base_shape[
                : len(expanded_base_shape) - base_event_dim
            ]
            base_distribution = base_distribution.expand(base_batch_shape)
        # If the transform's domain needs more event dims than the base has,
        # reinterpret the missing rightmost batch dims as event dims.
        reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim
        if reinterpreted_batch_ndims > 0:
            base_distribution = Independent(
                base_distribution, reinterpreted_batch_ndims
            )
        self.base_dist = base_distribution

        # Compute shapes.
        transform_change_in_event_dim = (
            transform.codomain.event_dim - transform.domain.event_dim
        )
        event_dim = max(
            transform.codomain.event_dim,  # the transform is coupled
            base_event_dim + transform_change_in_event_dim,  # the base dist is coupled
        )
        # Internal invariant: guaranteed by the domain event_dim check above
        # for transforms whose forward_shape is consistent with their event dims.
        assert len(forward_shape) >= event_dim
        cut = len(forward_shape) - event_dim
        batch_shape = forward_shape[:cut]
        event_shape = forward_shape[cut:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)
[docs] def expand(self, batch_shape, _instance=None): new = self.<span class="
优云智算