普丁(效果处理器)

在Pyro内置的推理算法之下,有一个可组合的效果处理程序库,用于创建新的推理算法和处理概率程序。Pyro的推理算法都是通过将这些处理程序应用于随机函数来构建的。为了对效果处理程序是什么以及它们解决了什么问题有一个全面的了解,请阅读Matija Pretnar的代数效应和处理程序简介

处理器

Poutine 是一个用于记录和修改 Pyro 程序行为的可组合效果处理程序库。这些较低层次的组件简化了新推理算法和行为的实现。

处理程序可以用作高阶函数、装饰器或上下文管理器,以修改函数或代码块的行为:

例如,考虑以下Pyro程序:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

我们可以使用condition将样本站点标记为已观察,它返回一个与model具有相同输入和输出签名的可调用对象:

>>> conditioned_model = poutine.condition(model, data={"z": 1.0})

我们也可以使用处理程序作为装饰器:

>>> @pyro.condition(data={"z": 1.0})
... def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

或者作为上下文管理器:

>>> with pyro.condition(data={"z": 1.0}):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(0., s))
...     y = z ** 2

处理程序可以自由组合:

>>> conditioned_model = poutine.condition(model, data={"z": 1.0})
>>> traced_model = poutine.trace(conditioned_model)

许多推理算法或算法组件可以用几行代码实现:

guide_tr = poutine.trace(guide).get_trace(...)
model_tr = poutine.trace(poutine.replay(conditioned_model, trace=guide_tr)).get_trace(...)
monte_carlo_elbo = model_tr.log_prob_sum() - guide_tr.log_prob_sum()
block(fn: None = None, hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, hide_all: bool = True, expose_all: bool = False, hide: Optional[List[str]] = None, expose: Optional[List[str]] = None, hide_types: Optional[List[str]] = None, expose_types: Optional[List[str]] = None) pyro.poutine.block_messenger.BlockMessenger[source]
block(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, hide_all: bool = True, expose_all: bool = False, hide: Optional[List[str]] = None, expose: Optional[List[str]] = None, hide_types: Optional[List[str]] = None, expose_types: Optional[List[str]] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

BlockMessenger的便捷包装器

此处理程序选择性地从外部世界隐藏Pyro原始站点。 默认行为:阻止所有内容。

如果满足以下至少一个条件,则站点被隐藏:

  1. hide_fn(msg) is True(not expose_fn(msg)) is True

  2. msg["name"] in hide

  3. msg["type"] in hide_types

  4. msg["name"] not in expose and msg["type"] not in expose_types

  5. hide, hide_types, 和 expose_types 都是 None

例如,假设随机函数 fn 有两个样本点“a”和“b”。 那么在 BlockMessenger(fn, hide=["a"]) 之外的任何效果 都不会应用于点“a”,而只会看到点“b”:

>>> def fn():
...     a = pyro.sample("a", dist.Normal(0., 1.))
...     return pyro.sample("b", dist.Normal(a, 1.))
>>> fn_inner = pyro.poutine.trace(fn)
>>> fn_outer = pyro.poutine.trace(pyro.poutine.block(fn_inner, hide=["a"]))
>>> trace_inner = fn_inner.get_trace()
>>> trace_outer  = fn_outer.get_trace()
>>> "a" in trace_inner
True
>>> "a" in trace_outer
False
>>> "b" in trace_inner
True
>>> "b" in trace_outer
True
Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • hide_fn – 一个函数,它接受一个站点并返回True以隐藏该站点,或返回False/None以暴露它。如果指定了此函数,则忽略所有其他参数。只能指定hide_fn或expose_fn中的一个,不能同时指定两者。

  • expose_fn – 一个函数,接受一个站点并返回True以暴露该站点,或返回False/None以隐藏它。如果指定了此参数,则忽略所有其他参数。只能指定hide_fn或expose_fn中的一个,不能同时指定两者。

  • hide_all (bool) – 隐藏所有站点

  • expose_all (bool) – 正常暴露所有站点

  • 隐藏 (list) – 要隐藏的站点名称列表

  • expose (list) – 要暴露的站点名称列表,其他站点将被隐藏

  • hide_types (list) – 要隐藏的站点类型列表

  • expose_types (list) – 要暴露的站点类型列表,其他所有类型将被隐藏

Returns

使用BlockMessenger装饰的随机函数

broadcast(fn: None = None) pyro.poutine.broadcast_messenger.BroadcastMessenger[source]
broadcast(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

BroadcastMessenger的便捷包装器

当在单个或嵌套的plate上下文中时,自动广播随机函数在采样点的批次形状。现有的batch_shape必须与安装在cond_indep_stack中的plate上下文的大小可广播。

注意下面的model_automatic_broadcast如何自动扩展分布批处理形状。这使得将Pyro模型模块化变得容易,因为子组件对包装的plate上下文是无感知的。

>>> def model_broadcast_by_hand():
...     with IndepMessenger("batch", 100, dim=-2):
...         with IndepMessenger("components", 3, dim=-1):
...             sample = pyro.sample("sample", dist.Bernoulli(torch.ones(3) * 0.5)
...                                                .expand_by(100))
...             assert sample.shape == torch.Size((100, 3))
...     return sample
>>> @poutine.broadcast
... def model_automatic_broadcast():
...     with IndepMessenger("batch", 100, dim=-2):
...         with IndepMessenger("components", 3, dim=-1):
...             sample = pyro.sample("sample", dist.Bernoulli(torch.tensor(0.5)))
...             assert sample.shape == torch.Size((100, 3))
...     return sample
collapse(fn: None = None, *args: Any, **kwargs: Any) pyro.poutine.collapse_messenger.CollapseMessenger[source]
collapse(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], *args: Any, **kwargs: Any) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

CollapseMessenger的便捷包装器

实验性功能:通过懒采样并尝试使用共轭关系来折叠上下文中的所有站点。如果不知道共轭关系,这将失败。使用采样站点结果的代码必须编写为接受Funsors而不是Tensors。这需要安装funsor

警告

这与自动猜测max_plate_nesting不兼容。如果在折叠的上下文中出现任何板块,您应该手动向您的推理算法声明max_plate_nesting(例如Trace_ELBO(max_plate_nesting=1))。

condition(data: Union[Dict[str, torch.Tensor], Trace]) pyro.poutine.condition_messenger.ConditionMessenger[source]
condition(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Union[Dict[str, torch.Tensor], Trace]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

ConditionMessenger 的便捷包装器

给定一个包含一些样本语句的随机函数和一个名称的观察字典, 将那些名称处的样本语句更改为具有这些值的观察。

考虑以下Pyro程序:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

要观察站点 z 的值,我们可以写

>>> conditioned_model = pyro.poutine.condition(model, data={"z": torch.tensor(1.)})

这相当于在model中将obs=value作为关键字参数添加到pyro.sample(“z”, …)中。

Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • data – 一个字典或一个 Trace

Returns

使用ConditionMessenger装饰的随机函数

do(data: Dict[str, Union[torch.Tensor, numbers.Number]]) pyro.poutine.do_messenger.DoMessenger[source]
do(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Dict[str, Union[torch.Tensor, numbers.Number]]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

DoMessenger的便捷包装器

给定一个带有一些样本语句的随机函数和一个名称值的字典, 将这些站点的返回值设置为这些值, 就像它们被硬编码为这些值一样, 并引入具有相同名称的新样本站点, 这些站点的值不会传播。

condition()自由组合,以表示潜在结果的反事实分布。 有关更多细节和理论,请参见单世界干预图[1]。

考虑以下Pyro程序:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

要为站点 z 干预一个值,我们可以写

>>> intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})

这相当于将z = pyro.sample(“z”, …)替换为 z = torch.tensor(1.) 并引入一个新的采样点pyro.sample(“z”, …),其值在其他地方未使用。

参考文献

[1] Single World Intervention Graphs: A Primer,

托马斯·理查森,詹姆斯·罗宾斯

Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • data – 一个将样本站点名称映射到干预措施的 dict

Returns

使用DoMessenger装饰的随机函数

enum(fn: None = None, first_available_dim: Optional[int] = None) pyro.poutine.enum_messenger.EnumMessenger[source]
enum(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], first_available_dim: Optional[int] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

EnumMessenger的便捷包装器

并行枚举标记为infer={"enumerate": "parallel"}的离散样本点。

Parameters

first_available_dim (int) – 第一个可用于并行枚举的张量维度(从右开始计数)。这个维度和所有左侧的维度可能会被Pyro内部使用。这应该是一个负整数或None。

escape(escape_fn: Callable[[Message], bool]) pyro.poutine.escape_messenger.EscapeMessenger[source]
escape(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], escape_fn: Callable[[Message], bool]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

EscapeMessenger 的便捷包装器

通过引发 util.NonlocalExit 异常来进行非局部退出的 Messenger

infer_config(config_fn: Callable[[Message], InferDict]) pyro.poutine.infer_config_messenger.InferConfigMessenger[source]
infer_config(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], config_fn: Callable[[Message], InferDict]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

InferConfigMessenger的便捷包装器

给定一个包含Pyro原始调用的可调用对象fn和一个接受跟踪站点并返回字典的可调用对象config_fn,将样本站点的infer kwarg的值更新为config_fn(site)。

Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • config_fn – 一个可调用对象,接收一个站点并返回一个推断字典

Returns

使用InferConfigMessenger装饰的随机函数

lift(prior: Union[Callable, Distribution, Dict[str, Union[Distribution, Callable]]]) pyro.poutine.lift_messenger.LiftMessenger[source]
lift(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], prior: Union[Callable, Distribution, Dict[str, Union[Distribution, Callable]]]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

LiftMessenger的便捷包装器

给定一个带有参数调用的随机函数和一个先验分布,创建一个随机函数,其中所有参数调用都被从先验中采样所替代。先验应该是一个可调用对象或一个名称到可调用对象的字典。

考虑以下Pyro程序:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
>>> lifted_model = pyro.poutine.lift(model, prior={"s": dist.Exponential(0.3)})

lift 使 param 语句的行为类似于 sample 语句 使用 prior 中的分布。在这个例子中,站点 s 现在将表现得 好像它被替换为 s = pyro.sample("s", dist.Exponential(0.3)):

>>> tr = pyro.poutine.trace(lifted_model).get_trace(0.0)
>>> tr.nodes["s"]["type"] == "sample"
True
>>> tr2 = pyro.poutine.trace(lifted_model).get_trace(0.0)
>>> bool((tr2.nodes["s"]["value"] == tr.nodes["s"]["value"]).all())
False
Parameters
  • fn – 其参数将被提升为随机值的函数

  • prior – 以Distribution形式或随机函数字典形式存在的先验函数

Returns

fn 使用 LiftMessenger 装饰

markov(fn: None = None, history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None) pyro.poutine.markov_messenger.MarkovMessenger[source]
markov(fn: Iterable[int] = None, history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None) pyro.poutine.markov_messenger.MarkovMessenger
markov(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T] = None, history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

马尔可夫依赖声明。

这可以用于多种方式:

  • 作为上下文管理器

  • 作为递归函数的装饰器

  • 作为马尔可夫链的迭代器

Parameters
  • 历史 (int) – 从当前上下文中可见的先前上下文的数量。默认为1。如果为零,这类似于 pyro.plate

  • keep (bool) – 如果为真,帧是可重放的。这在分支时很重要:如果 keep=True,同一级别的相邻分支可以相互依赖;如果 keep=False,相邻分支是独立的(取决于它们的共享条件)

  • dim (int) – 用于此独立性索引的可选维度。 接口存根,行为尚未实现。

  • name (str) – 一个可选的唯一名称,用于帮助推理算法在模型和指南之间匹配 pyro.markov() 站点。 接口存根,行为尚未实现。

mask(mask: Union[bool, torch.BoolTensor]) pyro.poutine.mask_messenger.MaskMessenger[source]
mask(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], mask: Union[bool, torch.BoolTensor]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

MaskMessenger的便捷包装器

给定一个带有一些批量样本语句和掩码张量的随机函数,按元素掩码掉一些样本语句。

Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • mask (torch.BoolTensor) – 一个{0,1}值的掩码张量 (1 包含一个站点,0 排除一个站点)

Returns

使用MaskMessenger装饰的随机函数

queue(fn=None, queue=None, max_tries=None, extend_fn=None, escape_fn=None, num_samples=None)[source]

用于离散变量的顺序枚举。

给定一个随机函数和一个队列,从队列中的完整跟踪返回一个返回值。

Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • queue – 一个类似于 multiprocessing.Queue 的队列数据结构,用于保存部分跟踪信息

  • max_tries – 计算单个完整跟踪的最大尝试次数

  • extend_fn – 函数(可能是随机的),它接受部分轨迹和一个站点,并返回扩展轨迹的列表

  • escape_fn – 函数(可能是随机的),它接受部分轨迹和一个站点,并返回一个布尔值来决定是否退出

  • num_samples – extend_fn 返回的扩展轨迹的可选数量

Returns

使用poutine逻辑装饰的随机函数

reparam(config: Union[Dict[str, Reparam], Callable[[Message], Optional[Reparam]]]) pyro.poutine.reparam_messenger.ReparamMessenger[source]
reparam(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], config: Union[Dict[str, Reparam], Callable[[Message], Optional[Reparam]]]) pyro.poutine.reparam_messenger.ReparamHandler[pyro.poutine.handlers._P, pyro.poutine.handlers._T]

ReparamMessenger的便捷包装器

将每个受影响的采样点重新参数化为一个或多个辅助采样点,然后进行确定性变换 [1]。

要指定重新参数化器,请将config字典或可调用对象传递给构造函数。有关可用的重新参数化器,请参见pyro.infer.reparam模块。

请注意,一些重新参数化器可以检查它们影响的函数的*args,**kwargs输入;这些重新参数化器需要使用poutine.reparam作为装饰器而不是上下文管理器。

[1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)

“概率程序的自动重新参数化” https://arxiv.org/pdf/1906.03028.pdf

Parameters

config (dictcallable) – 配置,可以是一个将站点名称映射到 Reparameterizer 的字典,或者是一个将站点映射到 Reparam 或 None 的函数。有关内置的配置策略,请参见 pyro.infer.reparam.strategies

replay(fn: None = None, trace: Optional[Trace] = None, params: Optional[Dict[str, torch.Tensor]] = None) pyro.poutine.replay_messenger.ReplayMessenger[source]
replay(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], trace: Optional[Trace] = None, params: Optional[Dict[str, torch.Tensor]] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

ReplayMessenger 的便捷包装器

给定一个包含Pyro原始调用的可调用对象, 返回一个运行原始可调用对象的可调用对象,在新跟踪中重用跟踪中的值

考虑以下Pyro程序:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

replay 使得 sample 语句表现得好像它们在跟踪中的相应位置采样了值一样:

>>> old_trace = pyro.poutine.trace(model).get_trace(1.0)
>>> replayed_model = pyro.poutine.replay(model, trace=old_trace)
>>> bool(replayed_model(0.0) == old_trace.nodes["_RETURN"]["value"])
True
Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • trace – 一个用于回放的Trace数据结构

  • params – 参数字典,包含参数名称和约束值,用于在函数中重放

Returns

一个用ReplayMessenger装饰的随机函数

scale(scale: Union[float, torch.Tensor]) pyro.poutine.scale_messenger.ScaleMessenger[source]
scale(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], scale: Union[float, torch.Tensor]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

ScaleMessenger的便捷包装器

给定一个包含一些样本语句和正比例因子的随机函数,缩放函数中所有样本和观察点的分数。

考虑以下Pyro程序:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     pyro.sample("z", dist.Normal(x, s), obs=torch.tensor(1.0))

scale 以乘法方式缩放样本站点的对数概率:

>>> scaled_model = pyro.poutine.scale(model, scale=0.5)
>>> scaled_tr = pyro.poutine.trace(scaled_model).get_trace(0.0)
>>> unscaled_tr = pyro.poutine.trace(model).get_trace(0.0)
>>> bool((scaled_tr.log_prob_sum() == 0.5 * unscaled_tr.log_prob_sum()).all())
True
Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • scale – 一个正的比例因子

Returns

使用ScaleMessenger装饰的随机函数

seed(rng_seed: int) pyro.poutine.seed_messenger.SeedMessenger[source]
seed(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], rng_seed: int) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

SeedMessenger的便捷包装器

处理程序通过设置种子将随机数生成器设置为预定义状态。这与在调用fn之前调用pyro.set_rng_seed()相同。此处理程序对标准Pyro后端上的原始语句没有额外影响,但它可能会在其他后端(例如NumPy后端)中拦截pyro.sample调用。

Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)。

  • rng_seed (int) – 随机数生成器种子。

substitute(data: Dict[str, torch.Tensor]) pyro.poutine.substitute_messenger.SubstituteMessenger[source]
substitute(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Dict[str, torch.Tensor]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

SubstituteMessenger的便捷包装器

给定一个带有参数调用的随机函数和一组参数值,创建一个随机函数,其中所有参数调用都被替换为固定值。数据应该是一个名称到值的字典。考虑以下Pyro程序:

>>> def model(x):
...     a = pyro.param("a", torch.tensor(0.5))
...     x = pyro.sample("x", dist.Bernoulli(probs=a))
...     return x
>>> substituted_model = pyro.poutine.substitute(model, data={"a": torch.tensor(0.3)})

在这个例子中,站点 a 现在将具有值 torch.tensor(0.3)。 :param data: 由站点名称键控的值的字典。 :returns: fn 装饰有 SubstituteMessenger

trace(fn: None = None, graph_type: Optional[Literal['flat', 'dense']] = None, param_only: Optional[bool] = None) pyro.poutine.trace_messenger.TraceMessenger[source]
trace(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], graph_type: Optional[Literal['flat', 'dense']] = None, param_only: Optional[bool] = None) pyro.poutine.trace_messenger.TraceHandler[pyro.poutine.handlers._P, pyro.poutine.handlers._T]

TraceMessenger的便捷包装器

返回一个记录原始调用输入输出及其依赖关系的处理程序。

考虑以下Pyro程序:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

我们可以使用trace记录其执行过程,并使用生成的数据结构来计算执行过程中所有采样点的对数联合概率或提取所有参数。

>>> trace = pyro.poutine.trace(model).get_trace(0.0)
>>> logp = trace.log_prob_sum()
>>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]
Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • graph_type – 指定要构建的图形类型的字符串

  • param_only – 如果为true,仅记录参数而不记录样本

Returns

使用TraceMessenger装饰的随机函数

uncondition(fn: None = None) pyro.poutine.uncondition_messenger.UnconditionMessenger[source]
uncondition(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]

UnconditionMessenger的便捷包装器

Messenger 强制从它们的分布中采样观察节点的值,忽略观察结果。

config_enumerate(guide=None, default='parallel', expand=False, num_samples=None, tmc='diagonal')[source]

配置指南中所有相关站点的枚举。这主要与TraceEnum_ELBO一起使用。

在配置离散变量的穷举枚举时,这会配置所有分布满足.has_enumerate_support == True的采样点。 在通过default="parallel", num_samples=n配置本地并行蒙特卡洛采样时,这会配置所有采样点。 这不会覆盖现有的注释infer={"enumerate": ...}

这可以用作一个函数:

guide = config_enumerate(guide)

或者作为装饰器:

@config_enumerate
def guide1(*args, **kwargs):
    ...

@config_enumerate(default="sequential", expand=True)
def guide2(*args, **kwargs):
    ...
Parameters
  • 指南 (可调用) – 一个pyro模型,将用作SVI中的指南。

  • 默认 (str) – 使用哪种枚举策略,可选“sequential”(顺序)、“parallel”(并行)或None。默认为“parallel”(并行)。

  • expand (bool) – 是否扩展枚举的样本值。详情请参见 enumerate_support()。 这仅适用于详尽枚举,其中 num_samples=None。 如果 num_samples 不是 None,则样本将始终被扩展。

  • num_samples (intNone) – 如果不是 None,则使用局部蒙特卡洛采样而不是穷举枚举。这对于连续和离散分布都有意义。

  • tmc (stringNone) – 在Tensor Monte Carlo中使用的“mixture”或“diagonal”策略

Returns

一个带注释的指南

Return type

可调用的

追踪

class Trace(graph_type: Literal['flat', 'dense'] = 'flat')[source]

基础类:object

表示执行跟踪中不同pyro原语之间关系的图数据结构。

Pyro程序的执行跟踪是记录该程序单次执行中对pyro.sample()pyro.param()的每次调用。跟踪是有向图,其节点表示原始调用或输入/输出,边表示这些原始调用之间的条件依赖关系。它们由poutine.trace创建和填充。

跟踪中的每个节点(或站点)包含站点的名称、输入和输出值,以及由推理算法或用户注释添加的额外元数据。在pyro.sample的情况下,跟踪还包括站点上的随机函数,以及用户添加的任何观察数据。

考虑以下Pyro程序:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

我们可以使用pyro.poutine.trace记录其执行过程,并使用生成的数据结构来计算执行过程中所有采样点的对数联合概率或提取所有参数。

>>> trace = pyro.poutine.trace(model).get_trace(0.0)
>>> logp = trace.log_prob_sum()
>>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]

我们还可以检查或操作跟踪中的单个节点。 trace.nodes 包含一个 collections.OrderedDict 的站点名称和元数据,对应于 x, s, z, 以及返回值:

>>> list(name for name in trace.nodes.keys())  
["_INPUT", "s", "z", "_RETURN"]

trace.nodes 的值是节点元数据的字典:

>>> trace.nodes["z"]  
{'type': 'sample', 'name': 'z', 'is_observed': False,
 'fn': Normal(), 'value': tensor(0.6480), 'args': (), 'kwargs': {},
 'infer': {}, 'scale': 1.0, 'cond_indep_stack': (),
 'done': True, 'stop': False, 'continuation': None}

'infer' 是用户或算法指定的元数据的字典。 'args''kwargs' 是通过 pyro.sample 传递给 fn.__call__fn.log_prob 的参数。 'scale' 用于在计算对数联合概率时缩放站点的对数概率。 'cond_indep_stack' 包含与执行中出现的 pyro.plate 上下文对应的数据结构。 'done''stop''continuation' 仅由 Pyro 的内部使用。

Parameters

graph_type (string) – 指定要构建的跟踪图类型的字符串

add_edge(site1: str, site2: str) None[source]
add_node(site_name: str, **kwargs: Any) None[source]
Parameters

site_name (string) – 要添加的站点名称

将站点添加到跟踪中。

在尝试添加重复节点时引发错误,而不是静默覆盖。

compute_log_prob(site_filter: Callable[[str, Message], bool] = <function allow_all_sites>) None[source]

计算轨迹的逐点对数概率。 每个 log_prob 的形状等于相应的 batch_shape。 每个 log_prob_sum 是一个标量。 这两种计算都被缓存。

compute_score_parts() None[source]

计算跟踪中每个站点的批量局部分数部分。 每个log_prob的形状等于相应的batch_shape。 每个log_prob_sum是一个标量。 所有计算都被记忆化。

copy() pyro.poutine.trace_struct.Trace[source]

对自身进行浅拷贝,保留节点和边。

detach_() None[source]

在每个采样点分离值(就地)。

property edges: Iterable[Tuple[str, str]]
format_shapes(title: str = 'Trace Shapes:', last_site: Optional[str] = None) str[source]

返回一个字符串,显示跟踪中所有站点形状的表格。

iter_stochastic_nodes() Iterator[Tuple[str, Message]][source]
Returns

遍历跟踪中的随机节点的迭代器。

log_prob_sum(site_filter: Callable[[str, Message], bool] = <function allow_all_sites>) Union[torch.Tensor, float][source]

计算轨迹的逐点对数概率。 每个log_prob的形状等于相应的batch_shape。 每个log_prob_sum是一个标量。 log_prob_sum的计算是记忆化的。

Returns

总对数概率。

Return type

torch.Tensor

property nonreparam_stochastic_nodes: List[str]

一个样本站点名称的列表,其随机函数不可重新参数化为原始分布

Type

返回

property observation_nodes: List[str]

观测站点名称列表

Type

返回

pack_tensors(plate_to_symbol: Optional[Dict[str, str]] = None) None[source]

计算跟踪中张量的打包表示。 这应该在调用compute_log_prob()compute_score_parts()之后调用。

property param_nodes: List[str]

参数站点的名称列表

Type

返回

predecessors(site_name: str) Set[str][source]
remove_node(site_name: str) None[source]
property reparameterized_nodes: List[str]

一个样本站点名称的列表,其随机函数是可重新参数化的原始分布

Type

返回

property stochastic_nodes: List[str]

样本站点名称列表

Type

返回

successors(site_name: str) Set[str][source]
symbolize_dims(plate_to_symbol: Optional[Dict[str, str]] = None) None[source]

为所有张量维度分配唯一符号。

topological_sort(reverse: bool = False) List[str][source]

返回按拓扑排序顺序排列的节点(站点名称)列表。

Parameters

reverse (bool) – 以相反的顺序返回列表。

Returns

拓扑排序节点列表(站点名称)。

运行时

class InferDict[source]

基础:typing_extensions.TypedDict

一个包含推理信息的字典。

这可以用于配置每个站点的推理策略,例如:

pyro.sample(
    "x",
    dist.Bernoulli(0.5),
    infer={"enumerate": "parallel"},
)
Keys:
enumerate (str):

如果字符串是“sequential”或“parallel”之一,则启用枚举。并行枚举通常更快,但需要广播安全的操作和静态结构。

expand (bool):

是否在枚举期间扩展分布。如果缺失,默认为False。

is_auxiliary (bool):

样本站点是否为辅助站点,例如用于确定性地转换辅助变量的指南中。如果缺失,则默认为False。

is_observed (bool):

样本站点是否被观察到(即不是潜在的)。如果缺失,默认为False。

num_samples (int):

要抽取的样本数量。如果缺失,默认为1。

obs (optional torch.Tensor):

观测值,或对于潜在变量为None。如果缺失,默认为None。

prior (optional torch.distributions.Distribution):

(内部)用于在GuideMessenger中存储模型的先验分布(基于上游站点条件)。

tmc (str):

是否在TraceTMC_ELBO中使用对角线或混合近似进行张量蒙特卡洛。

was_observed (bool):

(内部)样本站点是否最初被观察到,在通过重加权唤醒睡眠或编译顺序重要性采样的推断背景下。

enumerate: typing_extensions.Literal[sequential, parallel]
expand: bool
is_auxiliary: bool
is_observed: bool
num_samples: int
obs: Optional[torch.Tensor]
prior: TorchDistributionMixin
tmc: typing_extensions.Literal[diagonal, mixture]
was_observed: bool
class Message[source]

基础类:typing_extensions.TypedDict, Generic[pyro.poutine.runtime._P, pyro.poutine.runtime._T]

Pyro 内部用于效果处理的消息类型。

消息存储在跟踪对象中,例如:

trace.nodes["my_site_name"]  # This is a Message.
Keys:
type (str):

消息类型,通常是字符串“sample”、“param”、“plate”或“markov”之一,但也可能是自定义的。

name (str):

站点名称,通常用于命名样本或参数。

fn (callable):

用于生成样本的分布或函数。

is_observed (bool):

一个标志,用于指示该值是否被观察到。

args (tuple):

分布或函数的位置参数。

kwargs (dict):

分布或函数的关键字参数。

value (torch.Tensor):

样本的值(无论是观察到的还是采样的)。

scale (torch.Tensor):

对数概率的缩放因子。

mask (bool torch.Tensor):

一个布尔值或张量,用于屏蔽对数概率。

cond_indep_stack (tuple):

站点的条件独立性元数据的本地堆栈。 不可变。

done (bool):

一个标志,用于指示消息是否已被处理。

stop (bool):

一个标志,用于停止进一步处理消息。

continuation (callable):

处理消息后调用的函数。

infer (optional InferDict):

推理参数的字典。

obs (torch.Tensor):

观察值。

log_prob (torch.Tensor):

样本的对数概率。

log_prob_sum (torch.Tensor):

对数概率的总和。

unscaled_log_prob (torch.Tensor):

未缩放的日志概率。

score_parts (pyro.distributions.ScoreParts):

一组分数部分。

packed (Message):

一个打包的消息,用于枚举期间。

args: Tuple
cond_indep_stack: Tuple[CondIndepStackFrame, ...]
continuation: Optional[Callable[[Message], None]]
done: bool
fn: Callable[[pyro.poutine.runtime._P], pyro.poutine.runtime._T]
infer: Optional[pyro.poutine.runtime.InferDict]
is_observed: bool
kwargs: Dict
log_prob: torch.Tensor
log_prob_sum: torch.Tensor
mask: Optional[Union[bool, torch.Tensor]]
name: Optional[str]
obs: Optional[torch.Tensor]
packed: Message
scale: Union[torch.Tensor, float]
score_parts: ScoreParts
stop: bool
type: str
unscaled_log_prob: torch.Tensor
value: Optional[pyro.poutine.runtime._T]
exception NonlocalExit(site: pyro.poutine.runtime.Message, *args, **kwargs)[source]

基础类:Exception

用于从poutine执行中非局部退出的异常。

由 poutine.EscapeMessenger 使用以返回站点信息。

reset_stack() None[source]

重置堆栈中剩余帧的状态。 对于在poutine.queue中的多次重新执行是必要的。

am_i_wrapped() bool[source]

检查当前计算是否被包裹在一个poutine中。 :返回: bool

apply_stack(initial_msg: pyro.poutine.runtime.Message) None[source]

根据以下方案在单个站点执行效果堆栈:

  1. 对于堆栈中从下到上的每个Messenger, 使用消息执行Messenger._process_message; 如果消息字段“stop”为True,则停止; 否则,继续

  2. 应用默认行为(default_process_message)以完成剩余的站点执行

  3. 对于堆栈中的每个Messenger,从上到下执行_postprocess_message,以使用站点结果更新消息和内部信使状态

  4. 如果消息字段“continuation”不是None,则使用消息调用它

Parameters

initial_msg (dict) – 跟踪站点的初始版本

Returns

None

default_process_message(msg: pyro.poutine.runtime.Message) None[source]

推理中处理消息的默认方法。

Parameters

msg – 要处理的消息

Returns

effectful(fn: None = None, type: Optional[str] = None) Callable[[Callable[[pyro.poutine.runtime._P], pyro.poutine.runtime._T]], Callable[[...], pyro.poutine.runtime._T]][source]
effectful(fn: Callable[[pyro.poutine.runtime._P], pyro.poutine.runtime._T] = None, type: Optional[str] = None) Callable[[...], pyro.poutine.runtime._T]
Parameters
  • fn – 执行有效计算的函数或可调用对象

  • type (str) – 操作的类型标签,例如 “sample”

用于调用apply_stack()以应用任何活动效果的包装器。

get_mask() Optional[Union[bool, torch.Tensor]][source]

记录包含poutine.mask处理程序的效果。

这对于在预测期间避免昂贵的pyro.factor()计算非常有用,当不需要计算对数密度时,例如:

def model():
    # ...
    if poutine.get_mask() is not False:
        log_density = my_expensive_computation()
        pyro.factor("foo", log_density)
    # ...
Returns

掩码。

Return type

None, bool, 或 torch.Tensor

get_plates() Tuple[CondIndepStackFrame, ...][source]

记录包含pyro.plate上下文的效果。

Returns

一个由 pyro.poutine.indep_messenger.CondIndepStackFrame 对象组成的元组。

Return type

tuple

实用工具

all_escape(trace: Trace, msg: Message) bool[source]
Parameters
  • trace – 部分追踪

  • msg – Pyro 原始站点的消息

Returns

布尔决策值

实用函数,用于检查站点是否尚未在跟踪中。

由EscapeMessenger使用,用于决定是否在某个站点进行非本地退出。 用于近似积分变量以减少方差的子程序。

discrete_escape(trace: Trace, msg: Message) bool[source]
Parameters
  • trace – 部分追踪

  • msg – Pyro 原始站点的消息

Returns

布尔决策值

实用函数,用于检查采样点是否是离散的且尚未在跟踪中。

由EscapeMessenger使用,用于决定是否在某个站点进行非本地退出。 用于整合离散变量以减少方差的子程序。

enable_validation(is_validate: bool) None[source]
enum_extend(trace: Trace, msg: Message, num_samples: Optional[int] = None) List[Trace][source]
Parameters
  • trace – 部分追踪

  • msg – Pyro 原始站点的消息

  • num_samples – 返回的最大扩展轨迹数。

Returns

一个轨迹列表,输入轨迹的副本,每个副本多一个站点

实用函数,用于复制并扩展一个基于输入站点的轨迹,其值从输入站点分布的支撑中枚举。

用于精确推断和整合离散变量。

is_validation_enabled() bool[source]
mc_extend(trace: Trace, msg: Message, num_samples: Optional[int] = None) List[Trace][source]
Parameters
  • trace – 部分追踪

  • msg – Pyro 原始站点的消息

  • num_samples – 返回的最大扩展轨迹数。

Returns

一组轨迹,输入轨迹的副本,带有一个额外的站点

实用函数,用于复制和扩展基于输入站点的轨迹,其值从输入站点的函数中采样。

用于单个样本点的蒙特卡罗边际化。

prune_subsample_sites(trace: Trace) Trace[source]

复制并移除轨迹中的所有子样本点。

site_is_factor(site: Message) bool[source]

确定跟踪站点是否源自因子语句。

site_is_subsample(site: Message) bool[source]

确定一个跟踪站点是否源自plate内的子样本语句。

消息传递者

Messenger 对象包含由处理程序公开的效果的实现。 高级用户可以修改现有处理程序背后的 messenger 实现,或者编写新的 messenger, 这些 messenger 实现新的效果,并与库的其余部分正确组合。

信使

class Messenger[source]

基础类:object

上下文管理器类,用于修改行为 并向随机函数添加副作用 即包含Pyro原始语句的可调用对象。

这是基础的Messenger类。 它实现了所有Pyro原语的默认行为, 因此由随机函数fn引起的联合分布 与由Messenger()(fn)引起的联合分布相同。

用于推理过程中传递消息的转换器类。 大多数推理操作都在此类的子类中实现。

classmethod register(fn: Optional[Callable] = None, type: Optional[str] = None, post: Optional[bool] = None) Callable[source]
Parameters
  • fn – 实现操作的函数

  • type (str) – 操作的名称 (也传递给 effectful()

  • post (bool) – 如果为True,则将此操作用作后处理

动态地向效果添加操作。 对于生成库的包装器非常有用。

示例:

@SomeMessengerClass.register
def some_function(msg)
    ...do_something...
    return msg
classmethod unregister(fn: Optional[Callable] = None, type: Optional[str] = None) Optional[Callable][source]
Parameters
  • fn – 实现操作的函数

  • type (str) – 操作的名称 (也传递给 effectful())

动态地从效果中移除操作。 对于从库中移除包装器非常有用。

示例:

SomeMessengerClass.unregister(some_function, "name")
block_messengers(predicate: Callable[[pyro.poutine.messenger.Messenger], bool]) Iterator[List[pyro.poutine.messenger.Messenger]][source]

实验性的上下文管理器,用于临时从 _PYRO_STACK 中移除匹配的信使。注意这不会调用 .__exit__().__enter__() 方法。

这对于有选择性地阻止封闭处理程序非常有用。

Parameters

predicate (callable) – 一个将信使实例映射到布尔值的谓词。 这将静音所有满足 bool(predicate(m)) is True 的信使 m

Yields

被阻止的匹配信使列表。

unwrap(fn: Callable) Callable[source]

递归地展开poutines。

BlockMessenger

class BlockMessenger(hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, hide_all: bool = True, expose_all: bool = False, hide: Optional[List[str]] = None, expose: Optional[List[str]] = None, hide_types: Optional[List[str]] = None, expose_types: Optional[List[str]] = None)[source]

基础类:pyro.poutine.messenger.Messenger

此处理程序选择性地从外部世界隐藏Pyro原始站点。 默认行为:阻止所有内容。

如果满足以下至少一个条件,则站点被隐藏:

  1. hide_fn(msg) is True(not expose_fn(msg)) is True

  2. msg["name"] in hide

  3. msg["type"] in hide_types

  4. msg["name"] not in expose and msg["type"] not in expose_types

  5. hide, hide_types, 和 expose_types 都是 None

例如,假设随机函数 fn 有两个样本点“a”和“b”。 那么在 BlockMessenger(fn, hide=["a"]) 之外的任何效果 都不会应用于点“a”,而只会看到点“b”:

>>> def fn():
...     a = pyro.sample("a", dist.Normal(0., 1.))
...     return pyro.sample("b", dist.Normal(a, 1.))
>>> fn_inner = pyro.poutine.trace(fn)
>>> fn_outer = pyro.poutine.trace(pyro.poutine.block(fn_inner, hide=["a"]))
>>> trace_inner = fn_inner.get_trace()
>>> trace_outer  = fn_outer.get_trace()
>>> "a" in trace_inner
True
>>> "a" in trace_outer
False
>>> "b" in trace_inner
True
>>> "b" in trace_outer
True
Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • hide_fn – 一个函数,它接受一个站点并返回True以隐藏该站点,或返回False/None以暴露它。如果指定了此函数,则忽略所有其他参数。只能指定hide_fn或expose_fn中的一个,不能同时指定两者。

  • expose_fn – 一个函数,接受一个站点并返回True以暴露该站点,或返回False/None以隐藏它。如果指定了此参数,则忽略所有其他参数。只能指定hide_fn或expose_fn中的一个,不能同时指定两者。

  • hide_all (bool) – 隐藏所有站点

  • expose_all (bool) – 正常暴露所有站点

  • 隐藏 (list) – 要隐藏的站点名称列表

  • expose (list) – 要暴露的站点名称列表,其他站点将被隐藏

  • hide_types (list) – 要隐藏的站点类型列表

  • expose_types (list) – 要暴露的站点类型列表,其他所有类型将被隐藏

Returns

使用BlockMessenger装饰的随机函数

广播信使

class BroadcastMessenger[source]

基础类:pyro.poutine.messenger.Messenger

当在单个或嵌套的plate上下文中时,自动广播随机函数在采样点的批次形状。现有的batch_shape必须与安装在cond_indep_stack中的plate上下文的大小可广播。

注意下面的model_automatic_broadcast如何自动扩展分布批处理形状。这使得将Pyro模型模块化变得容易,因为子组件对包装的plate上下文是无感知的。

>>> def model_broadcast_by_hand():
...     with IndepMessenger("batch", 100, dim=-2):
...         with IndepMessenger("components", 3, dim=-1):
...             sample = pyro.sample("sample", dist.Bernoulli(torch.ones(3) * 0.5)
...                                                .expand_by(100))
...             assert sample.shape == torch.Size((100, 3))
...     return sample
>>> @poutine.broadcast
... def model_automatic_broadcast():
...     with IndepMessenger("batch", 100, dim=-2):
...         with IndepMessenger("components", 3, dim=-1):
...             sample = pyro.sample("sample", dist.Bernoulli(torch.tensor(0.5)))
...             assert sample.shape == torch.Size((100, 3))
...     return sample

折叠消息器

class CollapseMessenger(*args: Any, **kwargs: Any)[source]

基础类:pyro.poutine.trace_messenger.TraceMessenger

实验性功能:通过懒采样并尝试使用共轭关系来折叠上下文中的所有站点。如果不知道共轭关系,这将失败。使用采样站点结果的代码必须编写为接受Funsors而不是Tensors。这需要安装funsor

警告

这与自动猜测max_plate_nesting不兼容。如果在折叠的上下文中出现任何板块,您应该手动向您的推理算法声明max_plate_nesting(例如Trace_ELBO(max_plate_nesting=1))。

条件信使

class ConditionMessenger(data: Union[Dict[str, torch.Tensor], pyro.poutine.trace_struct.Trace])[source]

基础类:pyro.poutine.messenger.Messenger

给定一个包含一些样本语句的随机函数和一个名称的观察字典, 将那些名称处的样本语句更改为具有这些值的观察。

考虑以下Pyro程序:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

要观察站点 z 的值,我们可以写

>>> conditioned_model = pyro.poutine.condition(model, data={"z": torch.tensor(1.)})

这相当于在model中将obs=value作为关键字参数添加到pyro.sample(“z”, …)中。

Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • data – 一个字典或一个 Trace

Returns

使用ConditionMessenger装饰的随机函数

DoMessenger

class DoMessenger(data: Dict[str, Union[torch.Tensor, numbers.Number]])[source]

基础类:pyro.poutine.messenger.Messenger

给定一个带有一些样本语句的随机函数和一个名称值的字典, 将这些站点的返回值设置为这些值, 就像它们被硬编码为这些值一样, 并引入具有相同名称的新样本站点, 这些站点的值不会传播。

condition()自由组合,以表示潜在结果的反事实分布。 有关更多细节和理论,请参见单世界干预图[1]。

考虑以下Pyro程序:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

要为站点 z 干预一个值,我们可以写

>>> intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})

这相当于将z = pyro.sample(“z”, …)替换为 z = torch.tensor(1.) 并引入一个新的采样点pyro.sample(“z”, …),其值在其他地方未使用。

参考文献

[1] Single World Intervention Graphs: A Primer,

托马斯·理查森,詹姆斯·罗宾斯

Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • data – 一个将样本站点名称映射到干预措施的 dict

Returns

使用DoMessenger装饰的随机函数

枚举消息传递器

class EnumMessenger(first_available_dim: Optional[int] = None)[source]

基础类:pyro.poutine.messenger.Messenger

并行枚举标记为infer={"enumerate": "parallel"}的离散样本点。

Parameters

first_available_dim (int) – 第一个可用于并行枚举的张量维度(从右开始计数)。这个维度和所有左侧的维度可能会被Pyro内部使用。这应该是一个负整数或None。

enumerate_site(msg: pyro.poutine.runtime.Message) torch.Tensor[source]

EscapeMessenger

class EscapeMessenger(escape_fn: Callable[[pyro.poutine.runtime.Message], bool])[source]

基础类:pyro.poutine.messenger.Messenger

通过引发 util.NonlocalExit 异常来进行非局部退出的 Messenger

独立消息传递器

class CondIndepStackFrame(name, dim, size, counter, full_size)[source]

基础类:tuple

counter: int

字段编号3的别名

dim: Optional[int]

字段编号1的别名

full_size: Optional[int]

字段编号4的别名

name: str

字段编号 0 的别名

size: int

字段编号2的别名

property vectorized: bool
class IndepMessenger(name: str, size: int, dim: Optional[int] = None, device: Optional[str] = None)[源代码]

基础类:pyro.poutine.messenger.Messenger

这个信使跟踪由嵌套的plate上下文声明的独立性信息堆栈。这些信息存储在cond_indep_stack中,在每个采样/观察站点供TraceMessenger使用。

示例:

x_axis = IndepMessenger('outer', 320, dim=-1)
y_axis = IndepMessenger('inner', 200, dim=-2)
with x_axis:
    x_noise = sample("x_noise", dist.Normal(loc, scale).expand_by([320]))
with y_axis:
    y_noise = sample("y_noise", dist.Normal(loc, scale).expand_by([200, 1]))
with x_axis, y_axis:
    xy_noise = sample("xy_noise", dist.Normal(loc, scale).expand_by([200, 320]))
property indices: torch.Tensor
next_context() None[source]

增加计数器。

InferConfigMessenger

class InferConfigMessenger(config_fn: Callable[[Message], InferDict])[source]

基础类:pyro.poutine.messenger.Messenger

给定一个包含Pyro原始调用的可调用对象fn和一个接受跟踪站点并返回字典的可调用对象config_fn,将样本站点的infer kwarg的值更新为config_fn(site)。

Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • config_fn – 一个可调用对象,接收一个站点并返回一个推断字典

Returns

使用InferConfigMessenger装饰的随机函数

LiftMessenger

class LiftMessenger(prior: Union[Callable, pyro.distributions.distribution.Distribution, Dict[str, Union[pyro.distributions.distribution.Distribution, Callable]]])[源代码]

基础类:pyro.poutine.messenger.Messenger

给定一个带有参数调用的随机函数和一个先验分布,创建一个随机函数,其中所有参数调用都被从先验中采样所替代。先验应该是一个可调用对象或一个名称到可调用对象的字典。

考虑以下Pyro程序:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2
>>> lifted_model = pyro.poutine.lift(model, prior={"s": dist.Exponential(0.3)})

lift 使 param 语句的行为类似于 sample 语句 使用 prior 中的分布。在这个例子中,站点 s 现在将表现得 好像它被替换为 s = pyro.sample("s", dist.Exponential(0.3)):

>>> tr = pyro.poutine.trace(lifted_model).get_trace(0.0)
>>> tr.nodes["s"]["type"] == "sample"
True
>>> tr2 = pyro.poutine.trace(lifted_model).get_trace(0.0)
>>> bool((tr2.nodes["s"]["value"] == tr.nodes["s"]["value"]).all())
False
Parameters
  • fn – 其参数将被提升为随机值的函数

  • prior – 以Distribution形式或随机函数字典形式存在的先验函数

Returns

fn 使用 LiftMessenger 装饰

马尔科夫信使

class MarkovMessenger(history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None)[source]

基础类:pyro.poutine.reentrant_messenger.ReentrantMessenger

马尔可夫依赖声明。

这是一个统计上等同于内存管理区域的概念。

Parameters
  • 历史 (int) – 从当前上下文中可见的先前上下文的数量。默认为1。如果为零,这类似于 pyro.plate

  • keep (bool) – 如果为真,帧是可重放的。这在分支时很重要:如果 keep=True,同一级别的相邻分支可以相互依赖;如果 keep=False,相邻分支是独立的(取决于它们的共同祖先)。

  • dim (int) – 用于此独立性索引的可选维度。 接口存根,行为尚未实现。

  • name (str) – 一个可选的唯一名称,用于帮助推理算法在模型和指南之间匹配 pyro.markov() 站点。 接口存根,行为尚未实现。

generator(iterable: Iterable[int]) typing_extensions.Self[source]

MaskMessenger

class MaskMessenger(mask: Union[bool, torch.BoolTensor])[source]

基础类:pyro.poutine.messenger.Messenger

给定一个带有一些批量样本语句和掩码张量的随机函数,按元素掩码掉一些样本语句。

Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • mask (torch.BoolTensor) – 一个{0,1}值的掩码张量 (1 包含一个站点,0 排除一个站点)

Returns

使用MaskMessenger装饰的随机函数

PlateMessenger

class PlateMessenger(name: str, size: Optional[int] = None, subsample_size: Optional[int] = None, subsample: Optional[torch.Tensor] = None, dim: Optional[int] = None, use_cuda: Optional[bool] = None, device: Optional[str] = None)[source]

基础类:pyro.poutine.subsample_messenger.SubsampleMessenger

广播功能的瑞士军刀: 结合了形状推断、独立性注释和子采样

block_plate(name: Optional[str] = None, dim: Optional[int] = None, *, strict: bool = True) Iterator[None][source]

实验性的上下文管理器,用于临时阻止单个封闭的板块。

这对于采样辅助变量或在需要时延迟采样全局变量非常有用。例如,以下模型是等效的:

示例:

def model_1(data):
    loc = pyro.sample("loc", dist.Normal(0, 1))
    with pyro.plate("data", len(data)):
        with block_plate("data"):
            scale = pyro.sample("scale", dist.LogNormal(0, 1))
        pyro.sample("x", dist.Normal(loc, scale))

def model_2(data):
    loc = pyro.sample("loc", dist.Normal(0, 1))
    scale = pyro.sample("scale", dist.LogNormal(0, 1))
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Normal(loc, scale))
Parameters
  • name (str) – 可选的板块名称以匹配。

  • dim (int) – 可选的板子维度以匹配。必须为负数。

  • strict (bool) – 如果未找到匹配的板块,是否报错。默认为 True。

Raises

如果没有找到封闭的板块并且strict=True,则抛出ValueError。

可重入信使

class ReentrantMessenger[源代码]

基础类:pyro.poutine.messenger.Messenger

ReparamMessenger

class ReparamHandler(msngr, fn: Callable[[pyro.poutine.reparam_messenger._P], pyro.poutine.reparam_messenger._T])[source]

基础类: Generic[pyro.poutine.reparam_messenger._P, pyro.poutine.reparam_messenger._T]

重新参数化程序。

class ReparamMessenger(config: Union[Dict[str, Reparam], Callable[[Message], Optional[Reparam]]])[源代码]

基础类:pyro.poutine.messenger.Messenger

将每个受影响的采样点重新参数化为一个或多个辅助采样点,然后进行确定性变换 [1]。

要指定重新参数化器,请将config字典或可调用对象传递给构造函数。有关可用的重新参数化器,请参见pyro.infer.reparam模块。

请注意,一些重新参数化器可以检查它们影响的函数的*args,**kwargs输入;这些重新参数化器需要使用poutine.reparam作为装饰器而不是上下文管理器。

[1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)

“概率程序的自动重新参数化” https://arxiv.org/pdf/1906.03028.pdf

Parameters

config (dictcallable) – 配置,可以是一个将站点名称映射到 Reparameterizer 的字典,或者是一个将站点映射到 Reparam 或 None 的函数。有关内置的配置策略,请参见 pyro.infer.reparam.strategies

重放消息器

class ReplayMessenger(trace: Optional[Trace] = None, params: Optional[Dict[str, torch.Tensor]] = None)[source]

基础类:pyro.poutine.messenger.Messenger

给定一个包含Pyro原始调用的可调用对象, 返回一个运行原始可调用对象的可调用对象,在新跟踪中重用跟踪中的值

考虑以下Pyro程序:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

replay 使得 sample 语句表现得好像它们在跟踪中的相应位置采样了值一样:

>>> old_trace = pyro.poutine.trace(model).get_trace(1.0)
>>> replayed_model = pyro.poutine.replay(model, trace=old_trace)
>>> bool(replayed_model(0.0) == old_trace.nodes["_RETURN"]["value"])
True
Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • trace – 一个用于回放的Trace数据结构

  • params – 参数字典,包含参数名称和约束值,用于在函数中重放

Returns

一个用ReplayMessenger装饰的随机函数

ScaleMessenger

class ScaleMessenger(scale: Union[float, torch.Tensor])[source]

基础类:pyro.poutine.messenger.Messenger

给定一个包含一些样本语句和正比例因子的随机函数,缩放函数中所有样本和观察点的分数。

考虑以下Pyro程序:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     pyro.sample("z", dist.Normal(x, s), obs=torch.tensor(1.0))

scale 以乘法方式缩放样本站点的对数概率:

>>> scaled_model = pyro.poutine.scale(model, scale=0.5)
>>> scaled_tr = pyro.poutine.trace(scaled_model).get_trace(0.0)
>>> unscaled_tr = pyro.poutine.trace(model).get_trace(0.0)
>>> bool((scaled_tr.log_prob_sum() == 0.5 * unscaled_tr.log_prob_sum()).all())
True
Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • scale – 一个正的比例因子

Returns

使用ScaleMessenger装饰的随机函数

SeedMessenger

class SeedMessenger(rng_seed: int)[source]

基础类:pyro.poutine.messenger.Messenger

处理程序通过设置种子将随机数生成器设置为预定义状态。这与在调用fn之前调用pyro.set_rng_seed()相同。此处理程序对标准Pyro后端上的原始语句没有额外影响,但它可能会在其他后端(例如NumPy后端)中拦截pyro.sample调用。

Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)。

  • rng_seed (int) – 随机数生成器种子。

子样本信使

class SubsampleMessenger(name: str, size: Optional[int] = None, subsample_size: Optional[int] = None, subsample: Optional[torch.Tensor] = None, dim: Optional[int] = None, use_cuda: Optional[bool] = None, device: Optional[str] = None)[source]

基础类: pyro.poutine.indep_messenger.IndepMessenger

IndepMessenger 的扩展,包括子采样。

替代信使

class SubstituteMessenger(data: Dict[str, torch.Tensor])[source]

基础类:pyro.poutine.messenger.Messenger

给定一个带有参数调用的随机函数和一组参数值,创建一个随机函数,其中所有参数调用都被替换为固定值。数据应该是一个名称到值的字典。考虑以下Pyro程序:

>>> def model(x):
...     a = pyro.param("a", torch.tensor(0.5))
...     x = pyro.sample("x", dist.Bernoulli(probs=a))
...     return x
>>> substituted_model = pyro.poutine.substitute(model, data={"a": torch.tensor(0.3)})

在这个例子中,站点 a 现在将具有值 torch.tensor(0.3)。 :param data: 由站点名称键控的值的字典。 :returns: fn 装饰有 SubstituteMessenger

TraceMessenger

class TraceHandler(msngr: pyro.poutine.trace_messenger.TraceMessenger, fn: Callable[[pyro.poutine.trace_messenger._P], pyro.poutine.trace_messenger._T])[源代码]

基础类: Generic[pyro.poutine.trace_messenger._P, pyro.poutine.trace_messenger._T]

执行跟踪程序。

TraceHandler 记录每个 Pyro 原语的输入和输出,并将它们存储为 Trace() 中的一个站点。 理论上,这应该足以提供每个推理算法所需的信息 (以及 Variables 中的隐式计算图?)

我们也可以将其用于可视化。

get_trace(*args, **kwargs) pyro.poutine.trace_struct.Trace[源代码]
Returns

数据结构

Return type

pyro.poutine.Trace

用于非常常见的使用场景的辅助方法。 调用此poutine并返回其跟踪,而不是函数的返回值。

property trace: pyro.poutine.trace_struct.Trace
class TraceMessenger(graph_type: Optional[Literal['flat', 'dense']] = None, param_only: Optional[bool] = None)[源代码]

基础类:pyro.poutine.messenger.Messenger

返回一个记录原始调用输入输出及其依赖关系的处理程序。

考虑以下Pyro程序:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

我们可以使用trace记录其执行过程,并使用生成的数据结构来计算执行过程中所有采样点的对数联合概率或提取所有参数。

>>> trace = pyro.poutine.trace(model).get_trace(0.0)
>>> logp = trace.log_prob_sum()
>>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]
Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • graph_type – 指定要构建的图形类型的字符串

  • param_only – 如果为true,仅记录参数而不记录样本

Returns

使用TraceMessenger装饰的随机函数

get_trace() pyro.poutine.trace_struct.Trace[source]
Returns

数据结构

Return type

pyro.poutine.Trace

用于非常常见用例的辅助方法。 返回self.trace的浅拷贝。

identify_dense_edges(trace: pyro.poutine.trace_struct.Trace) None[source]

通过根据每个站点存储的cond_indep_stack信息添加所有边来就地修改跟踪。

无条件信使

class UnconditionMessenger[source]

基础类:pyro.poutine.messenger.Messenger

Messenger 强制从它们的分布中采样观察节点的值,忽略观察结果。

指南信使

class GuideMessenger(model: Callable)[source]

基础类: pyro.poutine.trace_messenger.TraceMessenger, abc.ABC

基于效果的指南的抽象基类。

派生类必须实现 get_posterior() 方法。

property model: Callable
__call__(*args, **kwargs) Dict[str, torch.Tensor][source]

从指南中抽取后验样本,并针对这些样本重放模型。

Returns

一个将样本站点名称映射到样本值的字典。 这包括潜在值、确定值和观测值。

Return type

dict

abstract get_posterior(name: str, prior: TorchDistributionMixin) Union[TorchDistributionMixin, torch.Tensor][source]

抽象方法,用于计算后验分布或在给定上游后验样本的条件下采样后验值。

实现可以在此函数内部使用pyro.parampyro.sample,但pyro.sample语句应设置infer={"is_auxiliary": True"}

实现可能会访问更多信息以进行计算:

  • value = self.upstream_value(name) is the value of an upstream

    样本或确定性站点。

  • self.trace 是上游站点的跟踪信息,可能对其他信息有用,例如 self.trace.nodes["my_site"]["fn"]self.trace.nodes["my_site"]["cond_indep_stack"]

  • args, kwargs = self.args_kwargs are the inputs to the model, and

    可能对摊销有用。

Parameters
  • name (str) – 要采样的样本站点的名称。

  • prior (Distribution) – 该样本点的先验分布 (以后验的上游样本为条件)。

Returns

后验分布或来自后验分布的样本。

Return type

Distributiontorch.Tensor

upstream_value(name: str) Optional[torch.Tensor][source]

用于 get_posterior()

Returns

上游样本或确定性站点的值

Return type

torch.Tensor

get_traces() Tuple[pyro.poutine.trace_struct.Trace, pyro.poutine.trace_struct.Trace][source]

这可以在运行__call__()之后调用,以提取一对轨迹。

与生成一对轨迹的跟踪重放模式不同, GuideMessenger 交错进行模型和指导计算,因此 只需要一次 guide(*args, **kwargs) 调用即可创建两个 轨迹。此函数仅从该指导的 .trace 属性中提取相关信息。

Returns

一对 (model_trace, guide_trace)

Return type

tuple