
Source code for torch.distributions.multivariate_normal

import math

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _standard_normal, lazy_property

__all__ = ["MultivariateNormal"]


def _batch_mv(bmat, bvec):
    r"""
    Performs a batched matrix-vector product, with compatible but different batch shapes.

    This function takes as input `bmat`, containing :math:`n \times n` matrices, and
    `bvec`, containing length :math:`n` vectors.

    Both `bmat` and `bvec` may have any number of leading dimensions, which correspond
    to a batch shape. They are not necessarily assumed to have the same batch shape,
    just ones which can be broadcast.
    """
    return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)
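

# Illustrative sketch (not part of the original module; the helper name
# `_example_batch_mv` is hypothetical): the batch shapes of `bmat` and `bvec`
# only need to be broadcastable, not identical, so a batch of matrices can
# act on a differently-batched set of vectors.
def _example_batch_mv():
    bmat = torch.randn(4, 1, 3, 3)  # batch shape (4, 1), 3 x 3 matrices
    bvec = torch.randn(5, 3)  # batch shape (5,), length-3 vectors
    out = _batch_mv(bmat, bvec)  # batch shapes broadcast to (4, 5)
    assert out.shape == (4, 5, 3)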


def _batch_mahalanobis(bL, bx):
    r"""
    Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
    for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.

    Accepts batches for both `bL` and `bx`. They are not necessarily assumed to have
    the same batch shape, but `bL` should be broadcastable to the shape of `bx`.
    """
    n = bx.size(-1)
    bx_batch_shape = bx.shape[:-1]

    # Assume that bL.shape = (i, 1, n, n) and bx.shape = (..., i, j, n); we are
    # going to make bx have shape (..., 1, j, i, 1, n) to apply a batched tri. solve
    bx_batch_dims = len(bx_batch_shape)
    bL_batch_dims = bL.dim() - 2
    outer_batch_dims = bx_batch_dims - bL_batch_dims
    old_batch_dims = outer_batch_dims + bL_batch_dims
    new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
    # Reshape bx to shape (..., 1, i, j, 1, n)
    bx_new_shape = bx.shape[:outer_batch_dims]
    for sL, sx in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
        bx_new_shape += (sx // sL, sL)
    bx_new_shape += (n,)
    bx = bx.reshape(bx_new_shape)
    # Permute bx so that it has shape (..., 1, j, i, 1, n)
    permute_dims = (
        list(range(outer_batch_dims))
        + list(range(outer_batch_dims, new_batch_dims, 2))
        + list(range(outer_batch_dims + 1, new_batch_dims, 2))
        + [new_batch_dims]
    )
    bx = bx.permute(permute_dims)

    flat_L = bL.reshape(-1, n, n)  # shape = b x n x n
    flat_x = bx.reshape(-1, flat_L.size(0), n)  # shape = c x b x n
    flat_x_swap = flat_x.permute(1, 2, 0)  # shape = b x n x c
    M_swap = (
        torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2)
    )  # shape = b x c
    M = M_swap.t()  # shape = c x b

    # Now we revert the above reshape and permute operations.
    permuted_M = M.reshape(bx.shape[:-1])  # shape = (..., 1, j, i, 1)
    permute_inv_dims = list(range(outer_batch_dims))
    for i in range(bL_batch_dims):
        permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
    reshaped_M = permuted_M.permute(permute_inv_dims)  # shape = (..., 1, i, j, 1)
    return reshaped_M.reshape(bx_batch_shape)
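

# Illustrative sketch (not part of the original module; the helper name
# `_example_batch_mahalanobis` is hypothetical): for a single unbatched pair,
# the result should match the direct computation x^T M^{-1} x with M = L L^T.
def _example_batch_mahalanobis():
    torch.manual_seed(0)
    L = torch.randn(3, 3).tril(-1) + torch.eye(3)  # well-conditioned lower-triangular factor
    x = torch.randn(3)
    M = L @ L.mT
    direct = x @ torch.linalg.solve(M, x)  # x^T M^{-1} x computed directly
    assert torch.allclose(_batch_mahalanobis(L, x), direct, atol=1e-5)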


def _precision_to_scale_tril(P):
    # 参考: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
    Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1)))
    L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1)
    Id = torch.eye(P.shape[-1], dtype=P.dtype, device=P.device)
    L = torch.linalg.solve_triangular(L_inv, Id, upper=False)
    return L
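

# Illustrative sketch (not part of the original module; the helper name
# `_example_precision_to_scale_tril` is hypothetical): the returned lower-
# triangular L satisfies L @ L.T ~= inverse(P), i.e. it is a Cholesky factor
# of the covariance matrix corresponding to the precision matrix P.
def _example_precision_to_scale_tril():
    torch.manual_seed(0)
    A = torch.randn(3, 3)
    P = A @ A.mT + 3 * torch.eye(3)  # a symmetric positive-definite precision matrix
    L = _precision_to_scale_tril(P)
    assert torch.allclose(L @ L.mT, torch.linalg.inv(P), atol=1e-5)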