Shortcuts

torch.distributions.lowrank_multivariate_normal 的源代码

import math

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv
from torch.distributions.utils import _standard_normal, lazy_property

__all__ = ["LowRankMultivariateNormal"]


def _batch_capacitance_tril(W, D):
    r"""
    计算一批矩阵 :math:`W` 和一批向量 :math:`D` 的 :math:`I + W.T @ inv(D) @ W` 的 Cholesky 分解。
    """
    m = W.size(-1)
    Wt_Dinv = W.mT / D.unsqueeze(-2)
    K = torch.matmul(Wt_Dinv, W).contiguous()
    K.view(-1, m * m)[:, :: m + 1] += 1  # 将单位矩阵添加到 K
    return torch.linalg.cholesky(K)


def _batch_lowrank_logdet(W, D, capacitance_tril):
    r"""
    使用“矩阵行列式引理”::
        log|W @ W.T + D| = log|C| + log|D|,
    其中 :math:`C` 是电容矩阵 :math:`I + W.T @ inv(D) @ W`,用于计算对数行列式。
    """
    return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum(
        -1
    )


def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
    r"""
    使用“Woodbury 矩阵恒等式”::
        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),
    其中 :math:`C` 是电容矩阵 :math:`I + W.T @ inv(D) @ W`,用于计算平方 Mahalanobis 距离 :math:`x.T @ inv(W @ W.T + D) @ x`。
    """
    Wt_Dinv = W.mT / D.unsqueeze(-2)
    Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
    mahalanobis_term1 = (x.pow(2) / D).sum(-1)
    mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
    return mahalanobis_term1 - mahalanobis_term2


[docs]class LowRankMultivariateNormal(Distribution): r""" 创建一个协方差矩阵具有低秩形式的多元正态分布,由 :attr:`cov_factor` 和 :attr:`cov_diag` 参数化:: covariance_matrix = cov_factor @ cov_factor.T + cov_diag 示例: >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) >>> # xdoctest: +IGNORE_WANT("non-deterministic") >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2)) >>> m.sample() # 均值为 `[0,0]`,cov_factor 为 `[[1],[0]]`,cov_diag 为 `[1,1]` 的正态分布 tensor([-0.2102, -0.5429]) 参数: loc (Tensor): 分布的均值,形状为 `batch_shape + event_shape` cov_factor (Tensor): 协方差矩阵的低秩形式的因子部分,形状为 `batch_shape + event_shape + (rank,)` cov_diag (Tensor): 协方差矩阵的低秩形式的对角部分,形状为 `batch_shape + event_shape` 注意: 当 `cov_factor.shape[1] << cov_factor.shape[0]` 时,由于 `Woodbury 矩阵恒等式 `_ 和 `矩阵行列式引理 `_,避免了协方差矩阵的行列式和逆的计算。 感谢这些公式,我们只需要计算小尺寸的“电容”矩阵的行列式和逆:: capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor """ arg_constraints = { "loc": constraints.real_vector, "cov_factor": constraints.independent(constraints.real, 2), "cov_diag": constraints.independent(constraints.positive, 1), } support = constraints.real_vector has_rsample = True def __init__(self, loc, cov_factor, cov_diag, validate_args=None): if loc.dim() < 1: raise ValueError("loc 必须至少为一维。") event_shape = loc.shape[-1:] if cov_factor.dim() < 2: raise ValueError( "cov_factor 必须至少为二维," "可选的前导批次维度" <
优云智算