Faster Sampling with JAX and Numba#

PyMC can compile its models to various execution backends through PyTensor, including:

  • C

  • JAX

  • Numba

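By default, PyMC uses the C backend, which is then called by the Python-based samplers.

However, by compiling to other backends, we can use samplers written in other languages that call the PyMC model without any Python overhead.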

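For the JAX backend, the NumPyro and BlackJAX NUTS samplers are available. To use them, you need to have numpyro and blackjax installed. Both are available through conda/mamba: mamba install -c conda-forge numpyro blackjax

For the Numba backend, there is the Nutpie sampler, which is written in Rust. To use it, you need to have nutpie installed: mamba install -c conda-forge nutpie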
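Because each external sampler depends on an optional package, it can help to check what is importable before choosing a value for `nuts_sampler`. The following is a minimal sketch using only the standard library. The helper name `available_nuts_samplers` and the `REQUIRED_PACKAGE` mapping are illustrative and not part of PyMC, and the sketch assumes PyMC itself is installed, so the default `"pymc"` sampler is always listed.

```python
from importlib.util import find_spec

# Hypothetical mapping: value of `nuts_sampler` -> package that must be importable.
REQUIRED_PACKAGE = {
    "numpyro": "numpyro",
    "blackjax": "blackjax",
    "nutpie": "nutpie",
}


def available_nuts_samplers():
    """Return the `nuts_sampler` names whose backing package can be imported."""
    samplers = ["pymc"]  # default Python NUTS sampler, assumed to be installed
    for sampler, package in REQUIRED_PACKAGE.items():
        # find_spec returns None when a top-level package is not installed
        if find_spec(package) is not None:
            samplers.append(sampler)
    return samplers


print(available_nuts_samplers())
```

A name from the returned list could then be passed as `pm.sample(nuts_sampler=...)`.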

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm

rng = np.random.default_rng(seed=42)
print(f"Running on PyMC v{pm.__version__}")
Running on PyMC v5.6.0
%config InlineBackend.figure_format = 'retina'
az.style.use("arviz-darkgrid")

We will use a simple probabilistic PCA model as our example.

def build_toy_dataset(N, D, K, sigma=1):
    x_train = np.zeros((D, N))
    w = rng.normal(
        0.0,
        2.0,
        size=(D, K),
    )
    z = rng.normal(0.0, 1.0, size=(K, N))
    mean = np.dot(w, z)
    for d in range(D):
        for n in range(N):
            x_train[d, n] = rng.normal(mean[d, n], sigma)

    print("True principal axes:")
    print(w)
    return x_train


N = 5000  # number of data points
D = 2  # data dimensionality
K = 1  # latent dimensionality

data = build_toy_dataset(N, D, K)
True principal axes:
[[ 0.60943416]
 [-2.07996821]]
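The nested loops in `build_toy_dataset` draw one observation at a time. As a sketch of an alternative, NumPy can draw the whole noise matrix in one call, because `Generator.normal` broadcasts over an array of means. The function name `build_toy_dataset_vectorized`, its `seed` argument, and the choice to return `w` alongside the data are illustrative and do not appear in the original notebook.

```python
import numpy as np


def build_toy_dataset_vectorized(N, D, K, sigma=1.0, seed=42):
    """Simulate D-dimensional data from K latent dimensions without Python loops."""
    rng = np.random.default_rng(seed)
    w = rng.normal(0.0, 2.0, size=(D, K))  # true principal axes
    z = rng.normal(0.0, 1.0, size=(K, N))  # latent coordinates
    mean = w @ z                           # shape (D, N)
    # One draw per entry of `mean`, all with standard deviation `sigma`
    x_train = rng.normal(mean, sigma)
    return x_train, w


x_demo, w_demo = build_toy_dataset_vectorized(N=100, D=2, K=1)
print(x_demo.shape, w_demo.shape)
```

For the data sizes used here the loop is fast enough, so this is mainly a readability choice.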
plt.scatter(data[0, :], data[1, :], color="blue", alpha=0.1)
plt.axis([-10, 10, -10, 10])
plt.title("Simulated data set")
Text(0.5, 1.0, 'Simulated data set')
[Figure: scatter plot titled "Simulated data set"]
with pm.Model() as PPCA:
    w = pm.Normal("w", mu=0, sigma=2, shape=[D, K], transform=pm.distributions.transforms.Ordered())
    z = pm.Normal("z", mu=0, sigma=1, shape=[N, K])
    x = pm.Normal("x", mu=w.dot(z.T), sigma=1, shape=[D, N], observed=data)
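Read off from the code above, the model can be written element-wise as follows, with d = 1..D, n = 1..N and k = 1..K. This is a sketch of the distributional assumptions only; the `Ordered` transform applied to `w` is not reflected in these equations.

```latex
\begin{align*}
w_{dk} &\sim \mathcal{N}(0,\, 2^2) \\
z_{nk} &\sim \mathcal{N}(0,\, 1) \\
x_{dn} \mid w, z &\sim \mathcal{N}\!\Big(\sum_{k=1}^{K} w_{dk}\, z_{nk},\; 1\Big)
\end{align*}
```

The mean term is the (d, n) entry of `w.dot(z.T)`, which has shape (D, N) and matches the shape of the observed data.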

Sampling using the Python NUTS sampler#

%%time
with PPCA:
    idata_pymc = pm.sample()
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [w, z]
100.00% [8000/8000 00:28<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 29 seconds.
/Users/twiecki/micromamba/envs/pymc5/lib/python3.11/site-packages/arviz/utils.py:184: NumbaDeprecationWarning: The 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
  numba_fn = numba.jit(**self.kwargs)(self.function)
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
CPU times: user 19.7 s, sys: 971 ms, total: 20.7 s
Wall time: 47.6 s

Sampling using the NumPyro JAX NUTS sampler#

%%time
with PPCA:
    idata_numpyro = pm.sample(nuts_sampler="numpyro", progressbar=False)
/Users/twiecki/micromamba/envs/pymc5/lib/python3.11/site-packages/pymc/sampling/mcmc.py:273: UserWarning: Use of external NUTS sampler is still experimental
  warnings.warn("Use of external NUTS sampler is still experimental", UserWarning)
/Users/twiecki/micromamba/envs/pymc5/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Compiling...
Compilation time =  0:00:00.619901
Sampling...
Sampling time =  0:00:11.469112
Transforming variables...
Transformation time =  0:00:00.118111
CPU times: user 40.5 s, sys: 6.66 s, total: 47.2 s
Wall time: 12.9 s

Sampling using the BlackJAX NUTS sampler#

%%time
with PPCA:
    idata_blackjax = pm.sample(nuts_sampler="blackjax")
/Users/twiecki/micromamba/envs/pymc5/lib/python3.11/site-packages/pymc/sampling/mcmc.py:273: UserWarning: Use of external NUTS sampler is still experimental
  warnings.warn("Use of external NUTS sampler is still experimental", UserWarning)
Compiling...
Compilation time =  0:00:00.607693
Sampling...
Sampling time =  0:00:02.132882
Transforming variables...
Transformation time =  0:00:08.410508
CPU times: user 35.4 s, sys: 6.73 s, total: 42.1 s
Wall time: 11.6 s

Sampling using the Nutpie Rust NUTS sampler#

%%time
with PPCA:
    idata_nutpie = pm.sample(nuts_sampler="nutpie")
/Users/twiecki/micromamba/envs/pymc5/lib/python3.11/site-packages/pymc/sampling/mcmc.py:273: UserWarning: Use of external NUTS sampler is still experimental
  warnings.warn("Use of external NUTS sampler is still experimental", UserWarning)
/Users/twiecki/micromamba/envs/pymc5/lib/python3.11/site-packages/pymc/util.py:501: FutureWarning: The tag attribute observations is deprecated. Use model.rvs_to_values[rv] instead
  warnings.warn(
100.00% [8000/8000 00:09<00:00 Chains in warmup: 0, Divergences: 0]
CPU times: user 37.6 s, sys: 3.34 s, total: 41 s
Wall time: 16.1 s
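To put the four runs side by side, the wall times reported by `%%time` above can be turned into speedup factors relative to the default Python sampler. This is a small sketch, not a benchmark: the numbers come from a single run on one machine and include compilation and variable transformation time. The `speedups` helper is illustrative.

```python
# Wall times in seconds, copied from the %%time output of each cell above.
wall_times = {
    "pymc": 47.6,
    "numpyro": 12.9,
    "blackjax": 11.6,
    "nutpie": 16.1,
}


def speedups(times, baseline="pymc"):
    """Return baseline wall time divided by each sampler's wall time."""
    base = times[baseline]
    return {name: base / seconds for name, seconds in times.items()}


for name, factor in speedups(wall_times).items():
    print(f"{name:>9}: {factor:.1f}x")
```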

Authors#

Authored by Thomas Wiecki in July 2023

%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor,arviz,pymc,numpyro,blackjax,nutpie
Last updated: Tue Jul 11 2023

Python implementation: CPython
Python version       : 3.11.4
IPython version      : 8.14.0

pytensor: 2.12.3
arviz   : 0.15.1
pymc    : 5.6.0
numpyro : 0.12.1
blackjax: 0.9.6
nutpie  : 0.6.0

numpy     : 1.24.4
pymc      : 5.6.0
matplotlib: 3.7.1
arviz     : 0.15.1

Watermark: 2.4.3