如何在 PyMC 中使用 JAX 函数#

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt

from pytensor.graph import Apply, Op
RANDOM_SEED = 104109109
rng = np.random.default_rng(RANDOM_SEED)
az.style.use("arviz-darkgrid")
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import pymc.sampling_jax

from pytensor.link.jax.dispatch import jax_funcify
/home/ricardo/miniconda3/envs/pymc-examples/lib/python3.10/site-packages/pytensor/link/jax/dispatch.py:87: UserWarning: JAX omnistaging couldn't be disabled: Disabling of omnistaging is no longer supported in JAX version 0.2.12 and higher: see https://github.com/google/jax/blob/main/design_notes/omnistaging.md.
  warnings.warn(f"JAX omnistaging couldn't be disabled: {e}")
/home/ricardo/Documents/Projects/pymc/pymc/sampling_jax.py:36: UserWarning: This module is experimental.
  warnings.warn("This module is experimental.")

简介:PyTensor 及其后端#

PyMC 使用 PyTensor 库来创建和操作概率图。PyTensor 是后端无关的,这意味着它可以利用用不同语言或框架编写的函数,包括纯 Python、NumPy、C、Cython、Numba 和 JAX

所有需要做的就是将这种函数封装在一个 PyTensor Op 中,它强制规定了关于如何处理纯“操作”的输入和输出的特定 API。它还实现了用于可选额外功能的方法,如符号形状推断和自动微分。这在 PyTensor Op 文档 和我们的 使用“黑盒”似然函数 pymc-示例中有详细介绍。

最近,PyTensor 已经能够直接编译到这些语言/框架中的一些,这意味着我们可以将一个完整的 PyTensor 图转换为 JAX 或 NUMBA 的 jitted 函数,而传统上它们只能转换为 Python 或 C。

这有一些有趣的用途,例如使用纯JAX采样器对在PyMC中定义的模型进行采样,如在NumPyroBlackJax中实现的那些。

本笔记本展示了我们如何实现一个新的 PyTensor Op,该操作包装了一个 JAX 函数。

大纲#

  1. 我们从一个与使用“黑箱”似然函数中相似的路径开始,该路径将一个NumPy函数封装在PyTensor的Op中,这次封装的是一个JAX jitted函数。

  2. 然后,我们使 PyTensor 能够“解开”刚刚包装的 JAX 函数,以便整个图可以编译为 JAX。我们利用这一点通过 JAX NumPyro NUTS 采样器来采样我们的 PyMC 模型。

一个激励性的例子:边际HMM#

为了说明的目的,我们将模拟遵循一个简单的隐马尔可夫模型(HMM)的数据,具有3个可能的潜在状态 \(S \in \{0, 1, 2\}\) 和正态发射似然。

\[Y \sim \text{Normal}((S + 1) \cdot \text{signal}, \text{noise})\]

我们的HMM将有一个固定的Categorical概率\(P\)在状态之间切换,这仅取决于上一个状态

\[S_{t+1} \sim \text{Categorical}(P_{S_t})\]

为了完成我们的模型,我们假设每个可能的初始状态 \(S_{t0}\) 都有一个固定的概率 \(P_{t0}\)

\[S_{t0} \sim \text{Categorical}(P_{t0})\]

模拟数据#

让我们根据这个模型生成数据!第一步是为我们模型中的参数设置一些值

# Emission signal and noise parameters
emission_signal_true = 1.15
emission_noise_true = 0.15

p_initial_state_true = np.array([0.9, 0.09, 0.01])

# Probability of switching from state_t to state_t+1
p_transition_true = np.array(
    [
        #    0,   1,   2
        [0.9, 0.09, 0.01],  # 0
        [0.1, 0.8, 0.1],  # 1
        [0.2, 0.1, 0.7],  # 2
    ]
)

# Confirm that we have defined valid probabilities
assert np.isclose(np.sum(p_initial_state_true), 1)
assert np.allclose(np.sum(p_transition_true, axis=-1), 1)
# Let's compute the log of the probalitiy transition matrix for later use
with np.errstate(divide="ignore"):
    logp_initial_state_true = np.log(p_initial_state_true)
    logp_transition_true = np.log(p_transition_true)

logp_initial_state_true, logp_transition_true
(array([-0.10536052, -2.40794561, -4.60517019]),
 array([[-0.10536052, -2.40794561, -4.60517019],
        [-2.30258509, -0.22314355, -2.30258509],
        [-1.60943791, -2.30258509, -0.35667494]]))
# We will observe 70 HMM processes, each with a total of 50 steps
n_obs = 70
n_steps = 50

我们编写一个辅助函数来生成一个单一的HMM过程,并创建我们的模拟数据

def simulate_hmm(p_initial_state, p_transition, emission_signal, emission_noise, n_steps, rng):
    """Generate hidden state and emission from our HMM model."""

    possible_states = np.array([0, 1, 2])

    hidden_states = []
    initial_state = rng.choice(possible_states, p=p_initial_state)
    hidden_states.append(initial_state)
    for step in range(n_steps):
        new_hidden_state = rng.choice(possible_states, p=p_transition[hidden_states[-1]])
        hidden_states.append(new_hidden_state)
    hidden_states = np.array(hidden_states)

    emissions = rng.normal(
        (hidden_states + 1) * emission_signal,
        emission_noise,
    )

    return hidden_states, emissions
single_hmm_hidden_state, single_hmm_emission = simulate_hmm(
    p_initial_state_true,
    p_transition_true,
    emission_signal_true,
    emission_noise_true,
    n_steps,
    rng,
)
print(single_hmm_hidden_state)
print(np.round(single_hmm_emission, 2))
[0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 1
 1 1 0 0 0 0 0 0 0 0 0 0 0 0]
[1.34 0.79 1.07 1.25 1.33 0.98 1.97 2.45 2.21 2.19 2.21 2.15 1.24 1.16
 0.78 1.18 1.34 2.21 2.44 2.14 2.15 2.38 2.27 2.33 2.26 2.37 2.45 2.36
 2.35 2.32 2.36 2.21 2.27 2.32 3.68 3.32 2.39 2.14 1.99 1.32 1.15 1.31
 1.25 1.17 1.06 0.91 0.88 1.17 1.   1.01 0.87]
hidden_state_true = []
emission_observed = []

for i in range(n_obs):
    hidden_state, emission = simulate_hmm(
        p_initial_state_true,
        p_transition_true,
        emission_signal_true,
        emission_noise_true,
        n_steps,
        rng,
    )
    hidden_state_true.append(hidden_state)
    emission_observed.append(emission)

hidden_state = np.array(hidden_state_true)
emission_observed = np.array(emission_observed)
fig, ax = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
# Plot first five hmm processes
for i in range(4):
    ax[0].plot(hidden_state_true[i] + i * 0.02, color=f"C{i}", lw=2, alpha=0.4)
    ax[1].plot(emission_observed[i], color=f"C{i}", lw=2, alpha=0.4)
ax[0].set_yticks([0, 1, 2])
ax[0].set_ylabel("hidden state")
ax[1].set_ylabel("observed emmission")
ax[1].set_xlabel("step")
fig.suptitle("Simulated data");
../../../_images/810705521c6d764cc52621beb88fb0f1160e1a92c804d6708c7c4f84aa34a535.png

上图显示了我们模拟数据的隐藏状态及其相应的观察发射。稍后,我们将使用这些数据来对真实模型参数进行后验推断。

使用JAX计算边际HMM似然#

我们将编写一个JAX函数来计算我们HMM模型的似然性,对隐藏状态进行边缘化。这使得对剩余模型参数的采样更加高效。为了实现这一点,我们将使用著名的前向算法,在数值稳定性方面使用对数尺度。

我们将利用JAX scan 来获得一个高效且可微分的对数似然,并使用方便的 vmap 来自动向量化这个对数似然,使其适用于多个观察过程。

我们的核心JAX函数计算单个HMM过程的边际对数似然

def hmm_logp(
    emission_observed,
    emission_signal,
    emission_noise,
    logp_initial_state,
    logp_transition,
):
    """Compute the marginal log-likelihood of a single HMM process."""

    hidden_states = np.array([0, 1, 2])

    # Compute log-likelihood of observed emissions for each (step x possible hidden state)
    logp_emission = jsp.stats.norm.logpdf(
        emission_observed[:, None],
        (hidden_states + 1) * emission_signal,
        emission_noise,
    )

    # We use the forward_algorithm to compute log_alpha(x_t) = logp(x_t, y_1:t)
    log_alpha = logp_initial_state + logp_emission[0]
    log_alpha, _ = jax.lax.scan(
        f=lambda log_alpha_prev, logp_emission: (
            jsp.special.logsumexp(log_alpha_prev + logp_transition.T, axis=-1) + logp_emission,
            None,
        ),
        init=log_alpha,
        xs=logp_emission[1:],
    )

    return jsp.special.logsumexp(log_alpha)

让我们用真实参数和第一个模拟的HMM过程来测试它

hmm_logp(
    emission_observed[0],
    emission_signal_true,
    emission_noise_true,
    logp_initial_state_true,
    logp_transition_true,
)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
DeviceArray(-3.93533794, dtype=float64)

我们现在使用vmap来对多个观测值进行向量化核心函数。

def vec_hmm_logp(*args):
    vmap = jax.vmap(
        hmm_logp,
        # Only the first argument, needs to be vectorized
        in_axes=(0, None, None, None, None),
    )
    # For simplicity we sum across all the HMM processes
    return jnp.sum(vmap(*args))


# We jit it for better performance!
jitted_vec_hmm_logp = jax.jit(vec_hmm_logp)

传递一个仅包含第一个模拟HMM过程的行矩阵应返回相同的结果

jitted_vec_hmm_logp(
    emission_observed[0][None, :],
    emission_signal_true,
    emission_noise_true,
    logp_initial_state_true,
    logp_transition_true,
)
DeviceArray(-3.93533794, dtype=float64)

然而,我们的目标是计算所有模拟数据的联合对数似然

jitted_vec_hmm_logp(
    emission_observed,
    emission_signal_true,
    emission_noise_true,
    logp_initial_state_true,
    logp_transition_true,
)
DeviceArray(-37.00348857, dtype=float64)

我们还将要求JAX为我们提供关于每个输入的梯度函数。这在后面会派上用场。

jitted_vec_hmm_logp_grad = jax.jit(jax.grad(vec_hmm_logp, argnums=list(range(5))))

让我们打印出相对于 emission_signal 的梯度。我们将检查这个值在将我们的函数包装在 PyTensor 中后是否保持不变。

jitted_vec_hmm_logp_grad(
    emission_observed,
    emission_signal_true,
    emission_noise_true,
    logp_initial_state_true,
    logp_transition_true,
)[1]
DeviceArray(-297.86490611, dtype=float64, weak_type=True)

在 PyTensor 中包装 JAX 函数#

现在我们准备将 JAX jitted 函数封装在一个 PyTensor Op 中,然后我们可以在 PyMC 模型中使用它。我们建议您查看 PyTensor 的官方 Op 文档,如果您想更详细地了解它。

简而言之,我们将继承自 Op 并定义以下方法:

  1. make_node: 创建一个Apply节点,该节点将我们的操作的符号输入和输出组合在一起

  2. perform: 返回给定具体输入值的操作评估的Python代码

  3. grad: 返回一个表示输出成本相对于其输入的梯度表达式的PyTensor符号图

对于 grad,我们将创建第二个 Op,它包装了我们上面jit版本的梯度

class HMMLogpOp(Op):
    def make_node(
        self,
        emission_observed,
        emission_signal,
        emission_noise,
        logp_initial_state,
        logp_transition,
    ):
        # Convert our inputs to symbolic variables
        inputs = [
            pt.as_tensor_variable(emission_observed),
            pt.as_tensor_variable(emission_signal),
            pt.as_tensor_variable(emission_noise),
            pt.as_tensor_variable(logp_initial_state),
            pt.as_tensor_variable(logp_transition),
        ]
        # Define the type of the output returned by the wrapped JAX function
        outputs = [pt.dscalar()]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        result = jitted_vec_hmm_logp(*inputs)
        # PyTensor raises an error if the dtype of the returned output is not
        # exactly the one expected from the Apply node (in this case
        # `dscalar`, which stands for float64 scalar), so we make sure
        # to convert to the expected dtype. To avoid unnecessary conversions
        # you should make sure the expected output defined in `make_node`
        # is already of the correct dtype
        outputs[0][0] = np.asarray(result, dtype=node.outputs[0].dtype)

    def grad(self, inputs, output_gradients):
        (
            grad_wrt_emission_obsered,
            grad_wrt_emission_signal,
            grad_wrt_emission_noise,
            grad_wrt_logp_initial_state,
            grad_wrt_logp_transition,
        ) = hmm_logp_grad_op(*inputs)
        # If there are inputs for which the gradients will never be needed or cannot
        # be computed, `pytensor.gradient.grad_not_implemented` should  be used as the
        # output gradient for that input.
        output_gradient = output_gradients[0]
        return [
            output_gradient * grad_wrt_emission_obsered,
            output_gradient * grad_wrt_emission_signal,
            output_gradient * grad_wrt_emission_noise,
            output_gradient * grad_wrt_logp_initial_state,
            output_gradient * grad_wrt_logp_transition,
        ]


class HMMLogpGradOp(Op):
    def make_node(
        self,
        emission_observed,
        emission_signal,
        emission_noise,
        logp_initial_state,
        logp_transition,
    ):
        inputs = [
            pt.as_tensor_variable(emission_observed),
            pt.as_tensor_variable(emission_signal),
            pt.as_tensor_variable(emission_noise),
            pt.as_tensor_variable(logp_initial_state),
            pt.as_tensor_variable(logp_transition),
        ]
        # This `Op` will return one gradient per input. For simplicity, we assume
        # each output is of the same type as the input. In practice, you should use
        # the exact dtype to avoid overhead when saving the results of the computation
        # in `perform`
        outputs = [inp.type() for inp in inputs]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        (
            grad_wrt_emission_obsered_result,
            grad_wrt_emission_signal_result,
            grad_wrt_emission_noise_result,
            grad_wrt_logp_initial_state_result,
            grad_wrt_logp_transition_result,
        ) = jitted_vec_hmm_logp_grad(*inputs)
        outputs[0][0] = np.asarray(grad_wrt_emission_obsered_result, dtype=node.outputs[0].dtype)
        outputs[1][0] = np.asarray(grad_wrt_emission_signal_result, dtype=node.outputs[1].dtype)
        outputs[2][0] = np.asarray(grad_wrt_emission_noise_result, dtype=node.outputs[2].dtype)
        outputs[3][0] = np.asarray(grad_wrt_logp_initial_state_result, dtype=node.outputs[3].dtype)
        outputs[4][0] = np.asarray(grad_wrt_logp_transition_result, dtype=node.outputs[4].dtype)


# Initialize our `Op`s
hmm_logp_op = HMMLogpOp()
hmm_logp_grad_op = HMMLogpGradOp()

我们建议使用调试助手 eval 方法来确认我们正确指定了所有内容。我们应该得到与之前相同的输出:

hmm_logp_op(
    emission_observed,
    emission_signal_true,
    emission_noise_true,
    logp_initial_state_true,
    logp_transition_true,
).eval()
array(-37.00348857)
hmm_logp_grad_op(
    emission_observed,
    emission_signal_true,
    emission_noise_true,
    logp_initial_state_true,
    logp_transition_true,
)[1].eval()
array(-297.86490611)

检查我们的 Op 的梯度是否可以通过 PyTensor 的 grad 接口请求也很有用:

# We define the symbolic `emission_signal` variable outside of the `Op`
# so that we can request the gradient wrt to it
emission_signal_variable = pt.as_tensor_variable(emission_signal_true)
x = hmm_logp_op(
    emission_observed,
    emission_signal_variable,
    emission_noise_true,
    logp_initial_state_true,
    logp_transition_true,
)
x_grad_wrt_emission_signal = pt.grad(x, wrt=emission_signal_variable)
x_grad_wrt_emission_signal.eval()
array(-297.86490611)

使用PyMC进行采样#

我们现在准备使用PyMC对我们的HMM模型进行推断。我们将为每个模型参数定义先验,并使用Potential将联合对数似然项添加到我们的模型中。

with pm.Model() as model:
    emission_signal = pm.Normal("emission_signal", 0, 1)
    emission_noise = pm.HalfNormal("emission_noise", 1)

    p_initial_state = pm.Dirichlet("p_initial_state", np.ones(3))
    logp_initial_state = pt.log(p_initial_state)

    p_transition = pm.Dirichlet("p_transition", np.ones(3), size=3)
    logp_transition = pt.log(p_transition)

    loglike = pm.Potential(
        "hmm_loglike",
        hmm_logp_op(
            emission_observed,
            emission_signal,
            emission_noise,
            logp_initial_state,
            logp_transition,
        ),
    )

在开始采样之前,我们检查模型初始点处每个变量的logp。错误通常以初始概率的nan-inf的形式表现出来。

initial_point = model.initial_point()
initial_point
{'emission_signal': array(0.),
 'emission_noise_log__': array(0.),
 'p_initial_state_simplex__': array([0., 0.]),
 'p_transition_simplex__': array([[0., 0.],
        [0., 0.],
        [0., 0.]])}
model.point_logps(initial_point)
{'emission_signal': -0.92,
 'emission_noise': -0.73,
 'p_initial_state': -1.5,
 'p_transition': -4.51,
 'hmm_loglike': -9812.67}

我们现在准备好采样了!

with model:
    idata = pm.sample(chains=2, cores=1)
Auto-assigning NUTS sampler...
INFO:pymc:Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
INFO:pymc:Initializing NUTS using jitter+adapt_diag...
/home/ricardo/Documents/Projects/pymc/pymc/pytensorf.py:1005: UserWarning: The parameter 'updates' of pytensor.function() expects an OrderedDict, got <class 'dict'>. Using a standard dictionary here results in non-deterministic behavior. You should use an OrderedDict if you are using Python 2.7 (collections.OrderedDict for older python), or use a list of (shared, update) pairs. Do not just convert your dictionary to this type before the call as the conversion will still be non-deterministic.
  pytensor_function = pytensor.function(
Sequential sampling (2 chains in 1 job)
INFO:pymc:Sequential sampling (2 chains in 1 job)
NUTS: [emission_signal, emission_noise, p_initial_state, p_transition]
INFO:pymc:NUTS: [emission_signal, emission_noise, p_initial_state, p_transition]
100.00% [2000/2000 00:52<00:00 Sampling chain 0, 0 divergences]
100.00% [2000/2000 00:56<00:00 Sampling chain 1, 0 divergences]
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 109 seconds.
INFO:pymc:Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 109 seconds.
az.plot_trace(idata);
../../../_images/5090654651d389883cc48f20c05357c2372bfad01a483ca801fb963cdf59d17f.png
true_values = [
    emission_signal_true,
    emission_noise_true,
    *p_initial_state_true,
    *p_transition_true.ravel(),
]

az.plot_posterior(idata, ref_val=true_values, grid=(3, 5));
../../../_images/a8769280c97becd0815d1078af137e8cd18169bba7576a300dc231cf207bd4bf.png

后验分布看起来合理地集中在用于生成我们数据的真值周围。

解包包装的JAX函数#

如前所述,PyTensor可以将整个图编译为JAX。为此,它需要知道如何将图中的每个Op转换为JAX函数。这可以通过调度pytensor.link.jax.dispatch.jax_funcify()来实现。大多数默认的PyTensor Op已经具有这样的调度函数,但我们需要为我们的自定义HMMLogpOp添加一个新的调度函数,因为PyTensor之前从未见过它。

为此,我们需要一个函数,该函数返回(另一个)JAX函数,该函数执行与我们的 perform 方法相同的计算。幸运的是,我们正是从这样一个函数开始的,所以这只需要3行简短的代码。

@jax_funcify.register(HMMLogpOp)
def hmm_logp_dispatch(op, **kwargs):
    return vec_hmm_logp

注意

我们不会返回已编译的函数,以便在转换为JAX后,整个PyTensor图可以一起编译。

为了更好地理解 Op JAX 转换,我们建议阅读 PyTensor 的 为 Ops 添加 JAX 和 Numba 支持指南

我们可以通过使用mode="JAX"编译pytensor.function()来测试我们的转换函数是否正常工作:

out = hmm_logp_op(
    emission_observed,
    emission_signal_true,
    emission_noise_true,
    logp_initial_state_true,
    logp_transition_true,
)
jax_fn = pytensor.function(inputs=[], outputs=out, mode="JAX")
jax_fn()
DeviceArray(-37.00348857, dtype=float64)

我们还可以编译一个JAX函数,该函数计算我们PyMC模型中每个变量的对数概率,类似于point_logps()。我们将使用辅助方法compile_fn()

model_logp_jax_fn = model.compile_fn(model.logp(sum=False), mode="JAX")
model_logp_jax_fn(initial_point)
[DeviceArray(-0.91893853, dtype=float64),
 DeviceArray(-0.72579135, dtype=float64),
 DeviceArray(-1.5040774, dtype=float64),
 DeviceArray([-1.5040774, -1.5040774, -1.5040774], dtype=float64),
 DeviceArray(-9812.66649064, dtype=float64)]

请注意,我们可以添加一个同样简单的函数来转换我们的 HMMLogpGradOp,以防我们想要将 PyTensor 梯度图转换为 JAX。在我们的例子中,我们不需要这样做,因为我们将会依赖 JAX 的 grad 函数(或者更准确地说,NumPyro 将会依赖它)从我们编译的 JAX 函数中再次获取这些信息。

我们在本文档的末尾包含了一个简短讨论,以帮助您更好地理解使用 PyTensor 图与 JAX 函数之间的权衡,以及何时可能需要使用其中之一。

使用NumPyro进行采样#

既然我们知道我们的模型 logp 可以完全编译为 JAX,我们可以使用方便的 pymc.sampling_jax.sample_numpyro_nuts() 来使用 NumPyro 中实现的纯 JAX 采样器对我们的模型进行采样。

with model:
    idata_numpyro = pm.sampling_jax.sample_numpyro_nuts(chains=2, progressbar=False)
/home/ricardo/miniconda3/envs/pymc-examples/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
/home/ricardo/Documents/Projects/pymc/pymc/pytensorf.py:1005: UserWarning: The parameter 'updates' of pytensor.function() expects an OrderedDict, got <class 'dict'>. Using a standard dictionary here results in non-deterministic behavior. You should use an OrderedDict if you are using Python 2.7 (collections.OrderedDict for older python), or use a list of (shared, update) pairs. Do not just convert your dictionary to this type before the call as the conversion will still be non-deterministic.
  pytensor_function = pytensor.function(
Compiling...
Compilation time =  0:00:01.897853
Sampling...
Sampling time =  0:00:47.542330
Transforming variables...
Transformation time =  0:00:00.399051
az.plot_trace(idata_numpyro);
../../../_images/656b656d2270592f99945b396ef0eec4d113b030f9f6bf55c2972550e9c00bae.png
az.plot_posterior(idata_numpyro, ref_val=true_values, grid=(3, 5));
../../../_images/6faef3869dcdd442e09dca9df2504f2ca00d5421493b5312dc5248951494cd89.png

正如预期的那样,采样结果看起来非常相似!

根据您使用的模型和计算机架构,纯 JAX 采样器可以提供显著的速度提升。

关于使用 PyTensor 与 JAX 的一些简要说明#

何时应该使用JAX?#

正如我们所见,在 PyTensor 图和 JAX 函数之间进行接口对接是非常直接的。

当你想将之前实现的JAX函数与PyMC模型结合时,这会非常方便。在这个例子中,我们使用了边缘化的HMM对数似然,但同样的策略可以用于使用深度神经网络或微分方程进行贝叶斯推断,或者在贝叶斯模型的上下文中使用JAX实现的几乎任何其他函数。

如果你需要利用 JAX 的独特功能,如矢量化、对树结构的支持或其细粒度的并行化,以及 GPU 和 TPU 功能,那么这也是值得的。

什么时候不应该使用JAX?#

与JAX类似,PyTensor的目标是模仿NumPy和Scipy的API,因此在PyTensor中编写代码的感觉应该与在这些库中编写代码非常相似。

然而,使用PyTensor有一些优势:

  1. PyTensor 图比 JAX 函数更容易 检查和调试

  2. PyTensor 拥有智能的 优化和稳定化程序,这些在 JAX 中是不可能实现或未实现的

  3. PyTensor 图可以在创建后轻松进行操作

第二点意味着如果你的图表是用PyTensor编写的,它们可能会表现得更好。通常情况下,你不需要担心使用像log1plogsumexp这样的专门函数,因为PyTensor能够检测到等效的简单表达式,并将它们替换为它们的专门版本。重要的是,当你的图表稍后编译为JAX时,你仍然可以从这些优化中受益。

关键在于 PyTensor 无法推理 JAX 函数,以及与之相关的 Op 包装它们。这意味着图表中“隐藏”在 JAX 函数内的部分越大,用户从 PyTensor 的重写和调试能力中获得的收益就越少。

第三点对于库开发者来说更为重要。这也是PyMC开发者选择使用PyTensor(以及之前的Theano)作为其后端的主要原因。PyMC提供的许多面向用户的工具都依赖于能够轻松解析和操作PyTensor图的能力。

奖励:使用一个可以计算自身梯度的单一操作#

我们不得不创建两个 Op,一个用于我们关心的函数,另一个用于其梯度。然而,JAX 提供了一个 value_and_grad 工具,可以返回函数的值及其梯度。我们可以做一些类似的事情,并且通过巧妙处理,只需一个 Op 即可实现。

通过这样做,我们可以(潜在地)节省内存并重用函数及其梯度之间共享的计算。这在处理非常大的 JAX 函数时可能相关。

请注意,这仅在你对使用 PyTensor 计算关于你的 Op 的梯度感兴趣时才有用。如果你的最终目标是编译你的图到 JAX,并且只在那时计算梯度(如 NumPyro 所做的),那么最好使用第一种方法。在这种情况下,你甚至不需要实现 grad 方法和相关的 Op

jitted_hmm_logp_value_and_grad = jax.jit(jax.value_and_grad(vec_hmm_logp, argnums=list(range(5))))
class HmmLogpValueGradOp(Op):
    # By default only show the first output, and "hide" the other ones
    default_output = 0

    def make_node(self, *inputs):
        inputs = [pt.as_tensor_variable(inp) for inp in inputs]
        # We now have one output for the function value, and one output for each gradient
        outputs = [pt.dscalar()] + [inp.type() for inp in inputs]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        result, grad_results = jitted_hmm_logp_value_and_grad(*inputs)
        outputs[0][0] = np.asarray(result, dtype=node.outputs[0].dtype)
        for i, grad_result in enumerate(grad_results, start=1):
            outputs[i][0] = np.asarray(grad_result, dtype=node.outputs[i].dtype)

    def grad(self, inputs, output_gradients):
        # The `Op` computes its own gradients, so we call it again.
        value = self(*inputs)
        # We hid the gradient outputs by setting `default_update=0`, but we
        # can retrieve them anytime by accessing the `Apply` node via `value.owner`
        gradients = value.owner.outputs[1:]

        # Make sure the user is not trying to take the gradient with respect to
        # the gradient outputs! That would require computing the second order
        # gradients
        assert all(
            isinstance(g.type, pytensor.gradient.DisconnectedType) for g in output_gradients[1:]
        )

        return [output_gradients[0] * grad for grad in gradients]


hmm_logp_value_grad_op = HmmLogpValueGradOp()

我们再次检查是否可以使用 PyTensor grad 接口获取梯度

emission_signal_variable = pt.as_tensor_variable(emission_signal_true)
# Only the first output is assigned to the variable `x`, due to `default_output=0`
x = hmm_logp_value_grad_op(
    emission_observed,
    emission_signal_variable,
    emission_noise_true,
    logp_initial_state_true,
    logp_transition_true,
)
pt.grad(x, emission_signal_variable).eval()
array(-297.86490611)

作者#

Ricardo Vieira 于2022年3月24日撰写(pymc-examples#302

水印#

%load_ext watermark
%watermark -n -u -v -iv -w -p pytensor,aeppl,xarray
Last updated: Mon Apr 11 2022

Python implementation: CPython
Python version       : 3.10.2
IPython version      : 8.1.1

pytensor: 2.5.1
aeppl : 0.0.27
xarray: 2022.3.0

matplotlib: 3.5.1
jax       : 0.3.4
pytensor    : 2.5.1
arviz     : 0.12.0
pymc      : 4.0.0b6
numpy     : 1.22.3

Watermark: 2.3.0