随机梯度马尔可夫链蒙特卡罗(SG-MCMC)诊断¶
马尔可夫链蒙特卡罗(MCMC)方法是近似后验分布的有力工具。随机过程,如随机梯度哈密顿蒙特卡罗,能够以更多偏差推断为代价实现快速采样。然而,已经证明标准的MCMC诊断方法无法检测到这些偏差。最近提出的逆多重二次(IMQ)核的核斯坦差异方法(KSD)[Gorham and Mackey, 2017]旨在比较有偏差、精确和确定性的样本序列,这也特别适合并行计算。
在本笔记本中,我们展示了如何评估SG-MCMC样本的质量。
我们创建一个二维多元正态分布的玩具示例。该分布由零均值和协方差矩阵\(\Sigma = P^{T} D P\)参数化,其中\(D\)是对角比例矩阵,\(P\)是某个角度\(r\)的旋转矩阵。
[1]:
from jax import vmap, value_and_grad
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
mu = jnp.zeros(
[
2,
]
)
r = np.pi / 4
D = jnp.array([2.0, 1.0])
P = jnp.array([[jnp.cos(r), jnp.sin(r)], [-jnp.sin(r), jnp.cos(r)]])
sigma = P.T @ jnp.diag(D) @ P
我们创建了一个真实数据集,以及另外两个数据集,分别包含欠分散\(\mathcal{N}(0, \sqrt[5]{\Sigma})\)和过分散\(\mathcal{N}(0, \Sigma^{3})\)的样本。
[2]:
N = 1_000
disp = [1 / 5, 1, 3]
rng = np.random.default_rng(0)
samples = np.array([rng.multivariate_normal(mu, sigma**d, size=N) for d in disp])
来自目标分布的样本数据集(在中间)明显与置信椭圆对齐。
[3]:
titles = ["$\sqrt[5]{\Sigma}$", "$\Sigma$", "$\Sigma^{3}$"]
_, axs = plt.subplots(1, len(samples), sharey=True, figsize=(12, 4))
for i, ax in enumerate(axs.flatten()):
ax.axis("equal")
ax.grid()
ax.scatter(samples[i, :, 0], samples[i, :, 1], alpha=0.3)
for std in range(1, 4):
conf_ell = Ellipse(
xy=mu,
width=D[0] * std,
height=D[1] * std,
angle=np.rad2deg(r),
edgecolor="black",
linestyle="--",
facecolor="none",
)
ax.add_artist(conf_ell)
ax.set_title(titles[i])
plt.show()
使用逆多重二次核的Kernel Stein差异是在一组样本和相应的梯度上计算的。请注意,它具有二次时间复杂度,这使得扩展到大型序列具有挑战性。
[4]:
from fortuna.prob_model.posterior.sgmcmc.sgmcmc_diagnostic import (
kernel_stein_discrepancy_imq,
)
logpdf = lambda params: stats.multivariate_normal.logpdf(params, mu, sigma)
_, grads = vmap(vmap(value_and_grad(logpdf), 0, 0), 1, 1)(samples)
ksd = vmap(kernel_stein_discrepancy_imq, 0, 0)(samples, grads)
log_ksd = jnp.log10(ksd)
正如预期的那样,从真实分布中采样的数据集中获得了最低的(对数)KSD值。
[5]:
fig, ax = plt.subplots(1, 1, figsize=(6, 3))
ax.grid()
ax.plot(disp, log_ksd)
ax.set_ylabel("log KSD")
ax.set_xlabel("$\Sigma$")
plt.show()
估计有效样本大小¶
有效样本量(ESS)是一个量化序列中自相关性的指标。直观上,ESS是一个与输入样本具有相同方差的独立同分布样本的大小。典型用途包括计算MCMC估计器的标准误差:
[6]:
from fortuna.prob_model.posterior.sgmcmc.sgmcmc_diagnostic import effective_sample_size
ess = effective_sample_size(samples[0])
variance = jnp.var(samples[0], axis=0)
standard_error = jnp.sqrt(variance / ess)
standard_error
[6]:
Array([0.03653547, 0.03869004], dtype=float32)
请注意,一系列强自相关样本会导致非常低的ESS:
[7]:
print("ESS for no auto-correlation:", effective_sample_size(rng.normal(size=200)))
print(
"ESS for strong auto-correlation:",
effective_sample_size(jnp.arange(200) + rng.normal(size=200)),
)
ESS for no auto-correlation: 200.0
ESS for strong auto-correlation: 2.8766694