Shortcuts

torch.distributions.exp_family 的源代码

import torch
from torch.distributions.distribution import Distribution

__all__ = ["ExponentialFamily"]


[docs]class ExponentialFamily(Distribution): r""" ExponentialFamily 是指数族概率分布的抽象基类,其概率质量/密度函数的形式定义如下 .. math:: p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x)) 其中 :math:`\theta` 表示自然参数,:math:`t(x)` 表示充分统计量, :math:`F(\theta)` 是给定族的对数归一化函数,:math:`k(x)` 是载体测度。 注意: 该类是 `Distribution` 类与属于指数族的分布之间的中间层,主要用于检查 `.entropy()` 和解析 KL 散度方法的正确性。 我们使用该类通过 AD 框架和 Bregman 散度(感谢:Frank Nielsen 和 Richard Nock,指数族的熵和交叉熵)来计算熵和 KL 散度。 """ @property def _natural_params(self): """ 自然参数的抽象方法。返回基于分布的张量元组 """ raise NotImplementedError def _log_normalizer(self, *natural_params): """ 对数归一化函数的抽象方法。返回基于分布和输入的对数归一化 """ raise NotImplementedError @property def _mean_carrier_measure(self): """ 期望载体测度的抽象方法,这是计算熵所必需的。 """ raise NotImplementedError
[docs] def entropy(self): """ 使用对数归一化的 Bregman 散度计算熵的方法。 """ result = -self._mean_carrier_measure nparams = [p.detach().requires_grad_() for p in self._natural_params] lg_normal = self._log_normalizer(*nparams) gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True) result += lg_normal for np, g in zip(nparams, gradients): result -= (np * g).reshape(self._batch_shape + (-1,)).sum(-1) return result
优云智算