Source code for torch.distributions.continuous_bernoulli
import math
from numbers import Number
import torch
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import (
broadcast_all,
clamp_probs,
lazy_property,
logits_to_probs,
probs_to_logits,
)
from torch.nn.functional import binary_cross_entropy_with_logits
__all__ = ["ContinuousBernoulli"]
class ContinuousBernoulli(ExponentialFamily):
r"""
创建一个由 :attr:`probs` 或 :attr:`logits` 参数化的连续伯努利分布(但不能同时使用两者)。
该分布支持在 [0, 1] 区间内,并由 'probs'(在 (0,1) 内)或 'logits'(实值)参数化。注意,与伯努利分布不同,'probs' 不对应于概率,'logits' 也不对应于对数几率,但由于与伯努利分布的相似性,使用了相同的名称。详见 [1]。
示例::
>>> # xdoctest: +IGNORE_WANT("非确定性")
>>> m = ContinuousBernoulli(torch.tensor([0.3]))
>>> m.sample()
tensor([ 0.2538])
参数:
probs (Number, Tensor): (0,1) 值的参数
logits (Number, Tensor): 实值参数,其 sigmoid 与 'probs' 匹配
[1] 连续伯努利分布:修正变分自编码器中的一个普遍错误,Loaiza-Ganem G 和 Cunningham JP,NeurIPS 2019。
https://arxiv.org/abs/1907.06845
"""
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.unit_interval
_mean_carrier_measure = 0
has_rsample = True
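    # Descriptive note (not in the upstream source): ``_mean_carrier_measure`` is
    # consumed by ``ExponentialFamily.entropy()``, and ``has_rsample = True``
    # advertises that ``rsample`` produces reparameterized (differentiable) samples.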
def __init__(
self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None
):
if (probs is None) == (logits is None):
raise ValueError(
"Either `probs` or `logits` must be specified, but not both."
)
if probs is not None:
is_scalar = isinstance(probs, Number)
(self.probs,) = broadcast_all(probs)
            # validate 'probs' here if necessary, as it is later clamped close to 0 and 1
            # for numerical stability; otherwise the clamped 'probs' would always pass validation
if validate_args is not None:
if not self.arg_constraints["probs"].check(self.probs).all():
raise ValueError("The parameter probs has invalid values")
self.probs = clamp_probs(self.probs)
else:
is_scalar = isinstance(logits, Number)
(self.logits,) = broadcast_all(logits)
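        # Keep a handle on whichever parameterization was supplied; its shape
        # (when it is a tensor) determines the distribution's batch shape below.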
self._param = self.probs if probs is not None else self.logits
if is_scalar:
batch_shape = torch.Size()
else:
batch_shape = self._param.size()
self._lims = lims
super().__init__(batch_shape, validate_args=validate_args)
    def expand(self, batch_shape, _instance=None):
new = self._get_checked_instance(ContinuousBernoulli, _instance)
new._lims = self._lims
batch_shape = torch.Size(batch_shape)
if "probs" in self.__dict__:
new.probs = self.probs.expand(batch_shape)
new._param = new.probs
if "logits" in self.__dict__:
new.logits = self.logits.expand(batch_shape)
new._param = new.logits
super(ContinuousBernoulli, new).__init__(batch_shape, validate_args=False)
new._validate_args = self._validate_args
return new
def _new<span class="p