pymc.sampling.jax.sample_blackjax_nuts#

pymc.sampling.jax.sample_blackjax_nuts(draws=1000, tune=1000, chains=4, target_accept=0.8, random_seed=None, initvals=None, model=None, var_names=None, progress_bar=False, keep_untransformed=False, chain_method='parallel', postprocessing_backend=None, postprocessing_vectorize='scan', idata_kwargs=None, adaptation_kwargs=None, postprocessing_chunks=None)[源代码]#

使用 blackjax 库中的 NUTS 方法从后验分布中抽取样本。

参数:
抽取 : int, 默认值 1000python:int, 默认值 1000

要抽取的样本数量。默认情况下,调整后的样本数量会被丢弃。

tune : int, 默认值 1000python:int, 默认值 1000

调整的迭代次数。采样器在调整过程中会调整步长、缩放比例等。调整样本将额外抽取,数量由 draws 参数指定。

链条 : int, 默认值为 4python:int, 默认值为 4

要采样的链数。

target_acceptpython:float 在 [0, 1] 范围内。

步长调整到接近这个接受率。对于有问题的后验分布,像0.9或0.95这样的较高值通常效果更好。

random_seedpython:int, RandomState 或 Generator, 可选

采样步骤中使用的随机种子。

initvals: StartDict 或 Sequence[Optional[StartDict]],可选

作为字典(或字典序列)提供的随机变量的初始值,将随机变量(按名称或引用)映射到所需的初始值。

模型 : Model, 可选模型,可选

要从其采样的模型。该模型需要有自由随机变量。当在 with 模型上下文中时,它默认为该模型,否则必须显式传递模型。

var_names : sequencestr,可选python:str 的 python:sequence,可选

要计算后验样本的变量名称。默认为后验中的所有变量。

keep_untransformed : 布尔值,默认值 Falsebool, 默认 python:False

在后验样本中包含未转换的变量。默认为 False。

chain_method : str, 默认值为 “parallel”python:str, 默认值 “parallel”

指定样本应如何绘制。选项包括“并行”和“向量化”。

postprocessing_backend: 可选[Literal[“cpu”, “gpu”]], 默认 None,

指定后处理应如何计算。gpu 或 cpu

postprocessing_vectorize: Literal[“vmap”, “scan”], 默认 “scan”

如何向量化后处理:vmap 或顺序扫描

idata_kwargs : dict, 可选python:dict, 可选

用于 arviz.from_dict() 的关键字参数。它还接受一个布尔值作为 log_likelihood 键的值,以指示不应在返回的对象中包含逐点对数似然。如果未在 idata_kwargs 中提供 observed_dataconstant_datacoordsdims 的值,则从 model 参数推断这些值。如果提供了 coordsdims,则使用它们来更新推断的字典。

返回:
InferenceData

ArviZ InferenceData 对象包含后验样本,以及它们各自的样本统计和逐点对数似然值(除非使用 idata_kwargs 跳过)。