Shortcuts

torch.distributions.continuous_bernoulli 的源代码

import math
from numbers import Number

import torch
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import (
    broadcast_all,
    clamp_probs,
    lazy_property,
    logits_to_probs,
    probs_to_logits,
)
from torch.nn.functional import binary_cross_entropy_with_logits

__all__ = ["ContinuousBernoulli"]


[docs]class ContinuousBernoulli(ExponentialFamily): r""" 创建一个由 :attr:`probs` 或 :attr:`logits` 参数化的连续伯努利分布(但不能同时使用两者)。 该分布支持在 [0, 1] 区间内,并由 'probs'(在 (0,1) 内)或 'logits'(实值)参数化。注意,与伯努利分布不同,'probs' 不对应于概率,'logits' 也不对应于对数几率,但由于与伯努利分布的相似性,使用了相同的名称。详见 [1]。 示例:: >>> # xdoctest: +IGNORE_WANT("非确定性") >>> m = ContinuousBernoulli(torch.tensor([0.3])) >>> m.sample() tensor([ 0.2538]) 参数: probs (Number, Tensor): (0,1) 值的参数 logits (Number, Tensor): 实值参数,其 sigmoid 与 'probs' 匹配 [1] 连续伯努利分布:修正变分自编码器中的一个普遍错误,Loaiza-Ganem G 和 Cunningham JP,NeurIPS 2019。 https://arxiv.org/abs/1907.06845 """ arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} support = constraints.unit_interval _mean_carrier_measure = 0 has_rsample = True def __init__( self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None ): if (probs is None) == (logits is None): raise ValueError( "Either `probs` or `logits` must be specified, but not both." ) if probs is not None: is_scalar = isinstance(probs, Number) (self.probs,) = broadcast_all(probs) # 如果必要,在此验证 'probs',因为之后会为数值稳定性将其钳制在接近 0 和 1 的位置;否则钳制后的 'probs' 总是会通过验证 if validate_args is not None: if not self.arg_constraints["probs"].check(self.probs).all(): raise ValueError("The parameter probs has invalid values") self.probs = clamp_probs(self.probs) else: is_scalar = isinstance(logits, Number) (self.logits,) = broadcast_all(logits) self._param = self.probs if probs is not None else self.logits if is_scalar: batch_shape = torch.Size() else: batch_shape = self._param.size() self._lims = lims super().__init__(batch_shape, validate_args=validate_args)
[docs] def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(ContinuousBernoulli, _instance) new._lims = self._lims batch_shape = torch.Size(batch_shape) if "probs" in self.__dict__: new.probs = self.probs.expand(batch_shape) new._param = new.probs if "logits" in self.__dict__: new.logits = self.logits.expand(batch_shape) new._param = new.logits super(ContinuousBernoulli, new).__init__(batch_shape, validate_args=False) new._validate_args = self._validate_args return new
def _new<span class="p
优云智算