torch.distributions.exp_family 的源代码
import torch
from torch.distributions.distribution import Distribution
__all__ = ["ExponentialFamily"]
[docs]class ExponentialFamily(Distribution):
r"""
ExponentialFamily 是指数族概率分布的抽象基类,其概率质量/密度函数的形式定义如下
.. math::
p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x))
其中 :math:`\theta` 表示自然参数,:math:`t(x)` 表示充分统计量,
:math:`F(\theta)` 是给定族的对数归一化函数,:math:`k(x)` 是载体测度。
注意:
该类是 `Distribution` 类与属于指数族的分布之间的中间层,主要用于检查 `.entropy()` 和解析 KL 散度方法的正确性。
我们使用该类通过 AD 框架和 Bregman 散度(感谢:Frank Nielsen 和 Richard Nock,指数族的熵和交叉熵)来计算熵和 KL 散度。
"""
@property
def _natural_params(self):
"""
自然参数的抽象方法。返回基于分布的张量元组
"""
raise NotImplementedError
def _log_normalizer(self, *natural_params):
"""
对数归一化函数的抽象方法。返回基于分布和输入的对数归一化
"""
raise NotImplementedError
@property
def _mean_carrier_measure(self):
"""
期望载体测度的抽象方法,这是计算熵所必需的。
"""
raise NotImplementedError
[docs] def entropy(self):
"""
使用对数归一化的 Bregman 散度计算熵的方法。
"""
result = -self._mean_carrier_measure
nparams = [p.detach().requires_grad_() for p in self._natural_params]
lg_normal = self._log_normalizer(*nparams)
gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True)
result += lg_normal
for np, g in zip(nparams, gradients):
result -= (np * g).reshape(self._batch_shape + (-1,)).sum(-1)
return result