jax.random.多元正态分布#
- jax.random.multivariate_normal(key, mean, cov, shape=None, dtype=None, method='cholesky')[源代码][源代码]#
使用给定的均值和协方差生成多元正态随机值的示例。
这些值根据概率密度函数返回:
\[f(x;\mu, \Sigma) = (2\pi)^{-k/2} \det(\Sigma)^{-1}e^{-\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu)}\]其中 \(k\) 是维度,\(\mu\) 是均值(由
mean给出),\(\Sigma\) 是协方差矩阵(由cov给出)。- 参数:
key (KeyArrayLike) – 一个用作随机密钥的 PRNG 密钥。
mean (RealArray) – 形状为
(..., n)的均值向量。cov (RealArray) – 一个形状为
(..., n, n)的正定协方差矩阵。批量形状...必须与mean的形状广播兼容。shape (Shape | None) – 可选,一个非负整数元组,指定结果批次形状;即,结果形状中排除最后一个轴的前缀。必须与
mean.shape[:-1]和cov.shape[:-2]广播兼容。默认值(None)通过将mean和cov的批次形状一起广播来生成结果批次形状。dtype (DTypeLikeFloat | None) – 可选,返回值的浮点数据类型(如果 jax_enable_x64 为 true,则默认为 float64,否则为 float32)。
method (str) – 可选,计算
cov因子的方法。必须是 ‘svd’、’eigh’ 和 ‘cholesky’ 之一。默认 ‘cholesky’。对于奇异协方差矩阵,使用 ‘svd’ 或 ‘eigh’。
- 返回:
一个随机数组,其数据类型为指定的
dtype,形状由shape + mean.shape[-1:]给出,如果shape不是 None,否则为broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]。- 返回类型: