推理工具

enable_validation(is_validate)[source]
is_validation_enabled()[source]
validation_enabled(is_validate=True)[source]

模型检查

get_dependencies(model: Callable, model_args: Optional[tuple] = None, model_kwargs: Optional[dict] = None, include_deterministic: bool = False) Dict[str, object][source]

推断关于条件模型的依赖结构。

这将返回一个结构如下的嵌套字典:

{
    "prior_dependencies": {
        "variable1": {"variable1": set()},
        "variable2": {"variable1": set(), "variable2": set()},
        ...
    },
    "posterior_dependencies": {
        "variable1": {"variable1": {"plate1"}, "variable2": set()},
        ...
    },
}

其中

  • prior_dependencies 是一个字典,将下游的潜在变量和观测变量映射到它们所依赖的上游潜在变量的字典,这些字典映射到引入完全依赖关系的板块集合。也就是说,包含的板块引入了像完全二分图中那样的二次依赖关系,而排除的板块仅引入了像独立平行边集合中的线性依赖关系。先验依赖关系遵循原始模型的顺序。

  • posterior_dependencies 是一个类似的字典,但将潜在变量映射到它们在后验中依赖的潜在或观察到的站点。后验依赖关系与模型顺序相反。

依赖关系省略了 pyro.deterministic 站点和 pyro.sample(..., Delta(...)) 站点。

示例

这里是一个没有板块的简单示例。我们看到每个节点都依赖于自身,并且只有潜在变量出现在后验中:

def model_1():
    a = pyro.sample("a", dist.Normal(0, 1))
    pyro.sample("b", dist.Normal(a, 1), obs=torch.tensor(0.0))

assert get_dependencies(model_1) == {
    "prior_dependencies": {
        "a": {"a": set()},
        "b": {"a": set(), "b": set()},
    },
    "posterior_dependencies": {
        "a": {"a": set(), "b": set()},
    },
}

这是一个例子,其中两个变量 ab 在先前是条件独立的,但在后验中变得条件依赖于所谓的碰撞变量 c,它们都依赖于这个变量。这在图模型文献中被称为“道德化”:

def model_2():
    a = pyro.sample("a", dist.Normal(0, 1))
    b = pyro.sample("b", dist.LogNormal(0, 1))
    c = pyro.sample("c", dist.Normal(a, b))
    pyro.sample("d", dist.Normal(c, 1), obs=torch.tensor(0.))

assert get_dependencies(model_2) == {
    "prior_dependencies": {
        "a": {"a": set()},
        "b": {"b": set()},
        "c": {"a": set(), "b": set(), "c": set()},
        "d": {"c": set(), "d": set()},
    },
    "posterior_dependencies": {
        "a": {"a": set(), "b": set(), "c": set()},
        "b": {"b": set(), "c": set()},
        "c": {"c": set(), "d": set()},
    },
}

在存在板块的情况下,依赖关系可能会更加复杂。到目前为止,所有的字典值都是空的板块集合,但在以下的后验中,我们看到a在板块p上依赖于自身。这意味着,在a的元素中,例如a[0]依赖于a[1](这就是为什么我们明确允许变量依赖于自身):

def model_3():
    with pyro.plate("p", 5):
        a = pyro.sample("a", dist.Normal(0, 1))
    pyro.sample("b", dist.Normal(a.sum(), 1), obs=torch.tensor(0.0))

assert get_dependencies(model_3) == {
    "prior_dependencies": {
        "a": {"a": set()},
        "b": {"a": set(), "b": set()},
    },
    "posterior_dependencies": {
        "a": {"a": {"p"}, "b": set()},
    },
}
[1] S.Webb, A.Goliński, R.Zinkov, N.Siddharth, T.Rainforth, Y.W.Teh, F.Wood (2018)

“生成模型的有效摊销推理的忠实反演” https://dl.acm.org/doi/10.5555/3327144.3327229

Parameters
  • model (callable) – 一个模型。

  • model_args (tuple) – 可选的模型参数元组。

  • model_kwargs (dict) – 可选的模型参数字典。

  • include_deterministic (bool) – 是否包含确定性站点。

Returns

元数据的字典(见上文)。

Return type

dict

render_model(model: Callable, model_args: Optional[Union[tuple, List[tuple]]] = None, model_kwargs: Optional[Union[dict, List[dict]]] = None, filename: Optional[str] = None, render_distributions: bool = False, render_params: bool = False, render_deterministic: bool = False) graphviz.graphs.Digraph[source]

使用graphviz渲染模型。

如果提供了filename,这将保存一个图像;否则这将绘制图形。例如用法请参见模型渲染教程

Parameters
  • model – 要渲染的模型。

  • model_args – 传递给模型的元组位置参数,或半监督模型的元组列表。

  • model_kwargs – 传递给模型的关键字参数字典,或用于半监督模型的字典列表。

  • filename (str) – 要保存渲染模型的文件的名称或文件路径。

  • render_distributions (bool) – 是否在图中包含RV分布注释(和参数约束)。

  • render_params (bool) – 是否在图中显示参数。

  • render_deterministic (bool) – 是否包含确定性站点。

Returns

模型图。

Return type

graphviz.Digraph

交互式先验调优

class Resampler(guide: Callable, simulator: Optional[Callable] = None, *, num_guide_samples: int, max_plate_nesting: Optional[int] = None)[source]

用于生成模型交互式调优的重采样器,通常在贝叶斯工作流程的早期步骤中执行先验预测检查时使用。

这旨在作为一个计算缓存,以加速基于下游模拟样本的先验分布参数的交互式调整。其思想是模拟可能非常昂贵,但当稍微调整参数分布的参数时,可以通过重要性重采样重用大部分先前的样本。

Parameters
  • 指南 (可调用) – 一个不需要参数的pyro模型。该指南应该是扩散的,覆盖比后续传递给sample()model更多的空间。必须能够通过pyro.plate进行向量化。

  • simulator (callable) – 一个可选的更大的pyro模型,包含指南的潜在变量的超集。必须能够通过pyro.plate进行向量化。

  • num_guide_samples (int) – 从指南中抽取的初始样本数量。这个数量应该比后续调用sample()时请求的num_samples大得多。

  • max_plate_nesting (int) – 模型中的最大板块嵌套。 如果未指定,将通过运行指南来猜测。

sample(model: Callable, num_samples: int, stable: bool = True) Dict[str, torch.Tensor][source]

绘制一组最多为num_samples的模型样本,可以选择性地通过simulator进行扩展。

在内部,这种重要性重新采样了由guide.__init__()中生成的样本,并且不会重新运行guidesimulator。如果原始引导样本对模型分布的覆盖较差,样本将显示出较低的多样性。

Parameters
  • model (可调用的) – 一个与原始 guide 具有相同潜在变量的模型。必须能够通过 pyro.plate 进行向量化。

  • num_samples (int) – 要抽取的样本数量。

  • stable (bool) – 是否使用分段常数多项式采样。设置为True用于可视化,False用于蒙特卡罗积分。默认为True。

Returns

堆叠样本的字典。

Return type

Dict[str, torch.Tensor]