torch.distributions.multivariate_normal 的源代码
import math
import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _standard_normal, lazy_property
__all__ = ["MultivariateNormal"]
def _batch_mv(bmat, bvec):
r"""
执行批量矩阵-向量乘积,具有兼容但不同的批量形状。
该函数接受输入 `bmat`,包含 :math:`n \times n` 矩阵,以及
`bvec`,包含长度为 :math:`n` 的向量。
`bmat` 和 `bvec` 可能具有任意数量的前导维度,这些维度对应
于批量形状。它们不一定被假定为具有相同的批量形状,
只是可以广播的形状。
"""
return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)
def _batch_mahalanobis(bL, bx):
r"""
计算平方马哈拉诺比斯距离 :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
对于一个分解的 :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`。
接受 `bL` 和 `bx` 的批量。它们不一定被假定为具有相同的批量
形状,但 `bL` 应该能够广播到 `bx` 的形状。
"""
n = bx.size(-1)
bx_batch_shape = bx.shape[:-1]
# 假设 bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
# 我们将使 bx 具有形状 (..., 1, j, i, 1, n) 以应用批量 tri.solve
bx_batch_dims = len(bx_batch_shape)
bL_batch_dims = bL.dim() - 2
outer_batch_dims = bx_batch_dims - bL_batch_dims
old_batch_dims = outer_batch_dims + bL_batch_dims
new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
# 将 bx 重塑为形状 (..., 1, i, j, 1, n)
bx_new_shape = bx.shape[:outer_batch_dims]
for sL, sx in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
bx_new_shape += (sx // sL, sL)
bx_new_shape += (n,)
bx = bx.reshape(bx_new_shape)
# 排列 bx 使其具有形状 (..., 1, j, i, 1, n)
permute_dims = (
list(range(outer_batch_dims))
+ list(range(outer_batch_dims, new_batch_dims, 2))
+ list(range(outer_batch_dims + 1, new_batch_dims, 2))
+ [new_batch_dims]
)
bx = bx.permute(permute_dims)
flat_L = bL.reshape(-1, n, n) # 形状 = b x n x n
flat_x = bx.reshape(-1, flat_L.size(0), n) # 形状 = c x b x n
flat_x_swap = flat_x.permute(1, 2, 0) # 形状 = b x n x c
M_swap = (
torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2)
) # 形状 = b x c
M = M_swap.t() # 形状 = c x b
# 现在我们恢复上述的重塑和排列操作。
permuted_M = M.reshape(bx.shape[:-1]) # 形状 = (..., 1, j, i, 1)
permute_inv_dims = list(range(outer_batch_dims))
for i in range(bL_batch_dims):
permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
reshaped_M = permuted_M.permute(permute_inv_dims) # 形状 = (..., 1, i, j, 1)
return reshaped_M.reshape(bx_batch_shape)
def _precision_to_scale_tril(P):
# 参考: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
Lf = torch<