torch.distributions.lowrank_multivariate_normal 的源代码
import math
import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv
from torch.distributions.utils import _standard_normal, lazy_property
__all__ = ["LowRankMultivariateNormal"]
def _batch_capacitance_tril(W, D):
r"""
计算一批矩阵 :math:`W` 和一批向量 :math:`D` 的 :math:`I + W.T @ inv(D) @ W` 的 Cholesky 分解。
"""
m = W.size(-1)
Wt_Dinv = W.mT / D.unsqueeze(-2)
K = torch.matmul(Wt_Dinv, W).contiguous()
K.view(-1, m * m)[:, :: m + 1] += 1 # 将单位矩阵添加到 K
return torch.linalg.cholesky(K)
def _batch_lowrank_logdet(W, D, capacitance_tril):
r"""
使用“矩阵行列式引理”::
log|W @ W.T + D| = log|C| + log|D|,
其中 :math:`C` 是电容矩阵 :math:`I + W.T @ inv(D) @ W`,用于计算对数行列式。
"""
return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum(
-1
)
def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
r"""
使用“Woodbury 矩阵恒等式”::
inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),
其中 :math:`C` 是电容矩阵 :math:`I + W.T @ inv(D) @ W`,用于计算平方 Mahalanobis 距离 :math:`x.T @ inv(W @ W.T + D) @ x`。
"""
Wt_Dinv = W.mT / D.unsqueeze(-2)
Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
mahalanobis_term1 = (x.pow(2) / D).sum(-1)
mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
return mahalanobis_term1 - mahalanobis_term2
[docs]class LowRankMultivariateNormal(Distribution):
r"""
创建一个协方差矩阵具有低秩形式的多元正态分布,由 :attr:`cov_factor` 和 :attr:`cov_diag` 参数化::
covariance_matrix = cov_factor @ cov_factor.T + cov_diag
示例:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2))
>>> m.sample() # 均值为 `[0,0]`,cov_factor 为 `[[1],[0]]`,cov_diag 为 `[1,1]` 的正态分布
tensor([-0.2102, -0.5429])
参数:
loc (Tensor): 分布的均值,形状为 `batch_shape + event_shape`
cov_factor (Tensor): 协方差矩阵的低秩形式的因子部分,形状为 `batch_shape + event_shape + (rank,)`
cov_diag (Tensor): 协方差矩阵的低秩形式的对角部分,形状为 `batch_shape + event_shape`
注意:
当 `cov_factor.shape[1] << cov_factor.shape[0]` 时,由于 `Woodbury 矩阵恒等式
`_ 和
`矩阵行列式引理 `_,避免了协方差矩阵的行列式和逆的计算。
感谢这些公式,我们只需要计算小尺寸的“电容”矩阵的行列式和逆::
capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
"""
arg_constraints = {
"loc": constraints.real_vector,
"cov_factor": constraints.independent(constraints.real, 2),
"cov_diag": constraints.independent(constraints.positive, 1),
}
support = constraints.real_vector
has_rsample = True
def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
if loc.dim() < 1:
raise ValueError("loc 必须至少为一维。")
event_shape = loc.shape[-1:]
if cov_factor.dim() < 2:
raise ValueError(
"cov_factor 必须至少为二维,"
"可选的前导批次维度"
<