普丁(效果处理器)¶
在Pyro内置的推理算法之下,有一个可组合的效果处理程序库,用于创建新的推理算法和处理概率程序。Pyro的推理算法都是通过将这些处理程序应用于随机函数来构建的。为了对效果处理程序是什么以及它们解决了什么问题有一个全面的了解,请阅读Matija Pretnar的代数效应和处理程序简介。
处理器¶
Poutine 是一个用于记录和修改 Pyro 程序行为的可组合效果处理程序库。这些较低层次的组件简化了新推理算法和行为的实现。
处理程序可以用作高阶函数、装饰器或上下文管理器,以修改函数或代码块的行为:
例如,考虑以下Pyro程序:
>>> def model(x):
... s = pyro.param("s", torch.tensor(0.5))
... z = pyro.sample("z", dist.Normal(x, s))
... return z ** 2
我们可以使用condition
将样本站点标记为已观察,它返回一个与model
具有相同输入和输出签名的可调用对象:
>>> conditioned_model = poutine.condition(model, data={"z": 1.0})
我们也可以使用处理程序作为装饰器:
>>> @pyro.condition(data={"z": 1.0})
... def model(x):
... s = pyro.param("s", torch.tensor(0.5))
... z = pyro.sample("z", dist.Normal(x, s))
... return z ** 2
或者作为上下文管理器:
>>> with pyro.condition(data={"z": 1.0}):
... s = pyro.param("s", torch.tensor(0.5))
... z = pyro.sample("z", dist.Normal(0., s))
... y = z ** 2
处理程序可以自由组合:
>>> conditioned_model = poutine.condition(model, data={"z": 1.0})
>>> traced_model = poutine.trace(conditioned_model)
许多推理算法或算法组件可以用几行代码实现:
guide_tr = poutine.trace(guide).get_trace(...)
model_tr = poutine.trace(poutine.replay(conditioned_model, trace=guide_tr)).get_trace(...)
monte_carlo_elbo = model_tr.log_prob_sum() - guide_tr.log_prob_sum()
- block(fn: None = None, hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, hide_all: bool = True, expose_all: bool = False, hide: Optional[List[str]] = None, expose: Optional[List[str]] = None, hide_types: Optional[List[str]] = None, expose_types: Optional[List[str]] = None) pyro.poutine.block_messenger.BlockMessenger [source]¶
- block(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, hide_all: bool = True, expose_all: bool = False, hide: Optional[List[str]] = None, expose: Optional[List[str]] = None, hide_types: Optional[List[str]] = None, expose_types: Optional[List[str]] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
BlockMessenger
的便捷包装器此处理程序选择性地从外部世界隐藏Pyro原始站点。 默认行为:阻止所有内容。
如果满足以下至少一个条件,则站点被隐藏:
hide_fn(msg) is True
或(not expose_fn(msg)) is True
msg["name"] in hide
msg["type"] in hide_types
msg["name"] not in expose and msg["type"] not in expose_types
hide
,hide_types
, 和expose_types
都是None
例如,假设随机函数 fn 有两个样本点“a”和“b”。 那么在
BlockMessenger(fn, hide=["a"])
之外的任何效果 都不会应用于点“a”,而只会看到点“b”:>>> def fn(): ... a = pyro.sample("a", dist.Normal(0., 1.)) ... return pyro.sample("b", dist.Normal(a, 1.)) >>> fn_inner = pyro.poutine.trace(fn) >>> fn_outer = pyro.poutine.trace(pyro.poutine.block(fn_inner, hide=["a"])) >>> trace_inner = fn_inner.get_trace() >>> trace_outer = fn_outer.get_trace() >>> "a" in trace_inner True >>> "a" in trace_outer False >>> "b" in trace_inner True >>> "b" in trace_outer True
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
hide_fn – 一个函数,它接受一个站点并返回True以隐藏该站点,或返回False/None以暴露它。如果指定了此函数,则忽略所有其他参数。只能指定hide_fn或expose_fn中的一个,不能同时指定两者。
expose_fn – 一个函数,接受一个站点并返回True以暴露该站点,或返回False/None以隐藏它。如果指定了此参数,则忽略所有其他参数。只能指定hide_fn或expose_fn中的一个,不能同时指定两者。
hide_all (bool) – 隐藏所有站点
expose_all (bool) – 正常暴露所有站点
隐藏 (list) – 要隐藏的站点名称列表
expose (list) – 要暴露的站点名称列表,其他站点将被隐藏
hide_types (list) – 要隐藏的站点类型列表
expose_types (list) – 要暴露的站点类型列表,其他所有类型将被隐藏
- Returns
使用
BlockMessenger
装饰的随机函数
- broadcast(fn: None = None) pyro.poutine.broadcast_messenger.BroadcastMessenger [source]¶
- broadcast(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
BroadcastMessenger
的便捷包装器当在单个或嵌套的plate上下文中时,自动广播随机函数在采样点的批次形状。现有的batch_shape必须与安装在cond_indep_stack中的
plate
上下文的大小可广播。注意下面的model_automatic_broadcast如何自动扩展分布批处理形状。这使得将Pyro模型模块化变得容易,因为子组件对包装的
plate
上下文是无感知的。>>> def model_broadcast_by_hand(): ... with IndepMessenger("batch", 100, dim=-2): ... with IndepMessenger("components", 3, dim=-1): ... sample = pyro.sample("sample", dist.Bernoulli(torch.ones(3) * 0.5) ... .expand_by(100)) ... assert sample.shape == torch.Size((100, 3)) ... return sample
>>> @poutine.broadcast ... def model_automatic_broadcast(): ... with IndepMessenger("batch", 100, dim=-2): ... with IndepMessenger("components", 3, dim=-1): ... sample = pyro.sample("sample", dist.Bernoulli(torch.tensor(0.5))) ... assert sample.shape == torch.Size((100, 3)) ... return sample
- collapse(fn: None = None, *args: Any, **kwargs: Any) pyro.poutine.collapse_messenger.CollapseMessenger [source]¶
- collapse(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], *args: Any, **kwargs: Any) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
CollapseMessenger
的便捷包装器实验性功能:通过懒采样并尝试使用共轭关系来折叠上下文中的所有站点。如果不知道共轭关系,这将失败。使用采样站点结果的代码必须编写为接受Funsors而不是Tensors。这需要安装
funsor
。警告
这与自动猜测
max_plate_nesting
不兼容。如果在折叠的上下文中出现任何板块,您应该手动向您的推理算法声明max_plate_nesting
(例如Trace_ELBO(max_plate_nesting=1)
)。
- condition(data: Union[Dict[str, torch.Tensor], Trace]) pyro.poutine.condition_messenger.ConditionMessenger [source]¶
- condition(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Union[Dict[str, torch.Tensor], Trace]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
ConditionMessenger
的便捷包装器给定一个包含一些样本语句的随机函数和一个名称的观察字典, 将那些名称处的样本语句更改为具有这些值的观察。
考虑以下Pyro程序:
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
要观察站点 z 的值,我们可以写
>>> conditioned_model = pyro.poutine.condition(model, data={"z": torch.tensor(1.)})
这相当于在model中将obs=value作为关键字参数添加到pyro.sample(“z”, …)中。
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
data – 一个字典或一个
Trace
- Returns
使用
ConditionMessenger
装饰的随机函数
- do(data: Dict[str, Union[torch.Tensor, numbers.Number]]) pyro.poutine.do_messenger.DoMessenger [source]¶
- do(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Dict[str, Union[torch.Tensor, numbers.Number]]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
DoMessenger
的便捷包装器给定一个带有一些样本语句的随机函数和一个名称值的字典, 将这些站点的返回值设置为这些值, 就像它们被硬编码为这些值一样, 并引入具有相同名称的新样本站点, 这些站点的值不会传播。
与
condition()
自由组合,以表示潜在结果的反事实分布。 有关更多细节和理论,请参见单世界干预图[1]。考虑以下Pyro程序:
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
要为站点 z 干预一个值,我们可以写
>>> intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})
这相当于将z = pyro.sample(“z”, …)替换为 z = torch.tensor(1.) 并引入一个新的采样点pyro.sample(“z”, …),其值在其他地方未使用。
参考文献
- [1] Single World Intervention Graphs: A Primer,
托马斯·理查森,詹姆斯·罗宾斯
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
data – 一个将样本站点名称映射到干预措施的
dict
- Returns
使用
DoMessenger
装饰的随机函数
- enum(fn: None = None, first_available_dim: Optional[int] = None) pyro.poutine.enum_messenger.EnumMessenger [source]¶
- enum(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], first_available_dim: Optional[int] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
EnumMessenger
的便捷包装器并行枚举标记为
infer={"enumerate": "parallel"}
的离散样本点。- Parameters
first_available_dim (int) – 第一个可用于并行枚举的张量维度(从右开始计数)。这个维度和所有左侧的维度可能会被Pyro内部使用。这应该是一个负整数或None。
- escape(escape_fn: Callable[[Message], bool]) pyro.poutine.escape_messenger.EscapeMessenger [source]¶
- escape(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], escape_fn: Callable[[Message], bool]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
EscapeMessenger
的便捷包装器通过引发 util.NonlocalExit 异常来进行非局部退出的 Messenger
- infer_config(config_fn: Callable[[Message], InferDict]) pyro.poutine.infer_config_messenger.InferConfigMessenger [source]¶
- infer_config(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], config_fn: Callable[[Message], InferDict]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
InferConfigMessenger
的便捷包装器给定一个包含Pyro原始调用的可调用对象fn和一个接受跟踪站点并返回字典的可调用对象config_fn,将样本站点的infer kwarg的值更新为config_fn(site)。
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
config_fn – 一个可调用对象,接收一个站点并返回一个推断字典
- Returns
使用
InferConfigMessenger
装饰的随机函数
- lift(prior: Union[Callable, Distribution, Dict[str, Union[Distribution, Callable]]]) pyro.poutine.lift_messenger.LiftMessenger [source]¶
- lift(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], prior: Union[Callable, Distribution, Dict[str, Union[Distribution, Callable]]]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
LiftMessenger
的便捷包装器给定一个带有参数调用的随机函数和一个先验分布,创建一个随机函数,其中所有参数调用都被从先验中采样所替代。先验应该是一个可调用对象或一个名称到可调用对象的字典。
考虑以下Pyro程序:
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2 >>> lifted_model = pyro.poutine.lift(model, prior={"s": dist.Exponential(0.3)})
lift
使param
语句的行为类似于sample
语句 使用prior
中的分布。在这个例子中,站点 s 现在将表现得 好像它被替换为s = pyro.sample("s", dist.Exponential(0.3))
:>>> tr = pyro.poutine.trace(lifted_model).get_trace(0.0) >>> tr.nodes["s"]["type"] == "sample" True >>> tr2 = pyro.poutine.trace(lifted_model).get_trace(0.0) >>> bool((tr2.nodes["s"]["value"] == tr.nodes["s"]["value"]).all()) False
- Parameters
fn – 其参数将被提升为随机值的函数
prior – 以Distribution形式或随机函数字典形式存在的先验函数
- Returns
fn
使用LiftMessenger
装饰
- markov(fn: None = None, history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None) pyro.poutine.markov_messenger.MarkovMessenger [source]¶
- markov(fn: Iterable[int] = None, history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None) pyro.poutine.markov_messenger.MarkovMessenger
- markov(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T] = None, history: int = 1, keep: bool = False, dim: Optional[int] = None, name: Optional[str] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
马尔可夫依赖声明。
这可以用于多种方式:
作为上下文管理器
作为递归函数的装饰器
作为马尔可夫链的迭代器
- mask(mask: Union[bool, torch.BoolTensor]) pyro.poutine.mask_messenger.MaskMessenger [source]¶
- mask(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], mask: Union[bool, torch.BoolTensor]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
MaskMessenger
的便捷包装器给定一个带有一些批量样本语句和掩码张量的随机函数,按元素掩码掉一些样本语句。
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
mask (torch.BoolTensor) – 一个
{0,1}
值的掩码张量 (1 包含一个站点,0 排除一个站点)
- Returns
使用
MaskMessenger
装饰的随机函数
- queue(fn=None, queue=None, max_tries=None, extend_fn=None, escape_fn=None, num_samples=None)[source]¶
用于离散变量的顺序枚举。
给定一个随机函数和一个队列,从队列中的完整跟踪返回一个返回值。
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
queue – 一个类似于 multiprocessing.Queue 的队列数据结构,用于保存部分跟踪信息
max_tries – 计算单个完整跟踪的最大尝试次数
extend_fn – 函数(可能是随机的),它接受部分轨迹和一个站点,并返回扩展轨迹的列表
escape_fn – 函数(可能是随机的),它接受部分轨迹和一个站点,并返回一个布尔值来决定是否退出
num_samples – extend_fn 返回的扩展轨迹的可选数量
- Returns
使用poutine逻辑装饰的随机函数
- reparam(config: Union[Dict[str, Reparam], Callable[[Message], Optional[Reparam]]]) pyro.poutine.reparam_messenger.ReparamMessenger [source]¶
- reparam(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], config: Union[Dict[str, Reparam], Callable[[Message], Optional[Reparam]]]) pyro.poutine.reparam_messenger.ReparamHandler[pyro.poutine.handlers._P, pyro.poutine.handlers._T]
ReparamMessenger
的便捷包装器将每个受影响的采样点重新参数化为一个或多个辅助采样点,然后进行确定性变换 [1]。
要指定重新参数化器,请将
config
字典或可调用对象传递给构造函数。有关可用的重新参数化器,请参见pyro.infer.reparam
模块。请注意,一些重新参数化器可以检查它们影响的函数的
*args,**kwargs
输入;这些重新参数化器需要使用poutine.reparam
作为装饰器而不是上下文管理器。- [1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)
“概率程序的自动重新参数化” https://arxiv.org/pdf/1906.03028.pdf
- Parameters
config (dict 或 callable) – 配置,可以是一个将站点名称映射到
Reparameterizer
的字典,或者是一个将站点映射到Reparam
或 None 的函数。有关内置的配置策略,请参见pyro.infer.reparam.strategies
。
- replay(fn: None = None, trace: Optional[Trace] = None, params: Optional[Dict[str, torch.Tensor]] = None) pyro.poutine.replay_messenger.ReplayMessenger [source]¶
- replay(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], trace: Optional[Trace] = None, params: Optional[Dict[str, torch.Tensor]] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
ReplayMessenger
的便捷包装器给定一个包含Pyro原始调用的可调用对象, 返回一个运行原始可调用对象的可调用对象,在新跟踪中重用跟踪中的值
考虑以下Pyro程序:
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
replay
使得sample
语句表现得好像它们在跟踪中的相应位置采样了值一样:>>> old_trace = pyro.poutine.trace(model).get_trace(1.0) >>> replayed_model = pyro.poutine.replay(model, trace=old_trace) >>> bool(replayed_model(0.0) == old_trace.nodes["_RETURN"]["value"]) True
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
trace – 一个用于回放的
Trace
数据结构params – 参数字典,包含参数名称和约束值,用于在函数中重放
- Returns
一个用
ReplayMessenger
装饰的随机函数
- scale(scale: Union[float, torch.Tensor]) pyro.poutine.scale_messenger.ScaleMessenger [source]¶
- scale(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], scale: Union[float, torch.Tensor]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
ScaleMessenger
的便捷包装器给定一个包含一些样本语句和正比例因子的随机函数,缩放函数中所有样本和观察点的分数。
考虑以下Pyro程序:
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... pyro.sample("z", dist.Normal(x, s), obs=torch.tensor(1.0))
scale
以乘法方式缩放样本站点的对数概率:>>> scaled_model = pyro.poutine.scale(model, scale=0.5) >>> scaled_tr = pyro.poutine.trace(scaled_model).get_trace(0.0) >>> unscaled_tr = pyro.poutine.trace(model).get_trace(0.0) >>> bool((scaled_tr.log_prob_sum() == 0.5 * unscaled_tr.log_prob_sum()).all()) True
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
scale – 一个正的比例因子
- Returns
使用
ScaleMessenger
装饰的随机函数
- seed(rng_seed: int) pyro.poutine.seed_messenger.SeedMessenger [source]¶
- seed(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], rng_seed: int) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
SeedMessenger
的便捷包装器处理程序通过设置种子将随机数生成器设置为预定义状态。这与在调用fn之前调用
pyro.set_rng_seed()
相同。此处理程序对标准Pyro后端上的原始语句没有额外影响,但它可能会在其他后端(例如NumPy后端)中拦截pyro.sample
调用。- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)。
rng_seed (int) – 随机数生成器种子。
- substitute(data: Dict[str, torch.Tensor]) pyro.poutine.substitute_messenger.SubstituteMessenger [source]¶
- substitute(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Dict[str, torch.Tensor]) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
SubstituteMessenger
的便捷包装器给定一个带有参数调用的随机函数和一组参数值,创建一个随机函数,其中所有参数调用都被替换为固定值。数据应该是一个名称到值的字典。考虑以下Pyro程序:
>>> def model(x): ... a = pyro.param("a", torch.tensor(0.5)) ... x = pyro.sample("x", dist.Bernoulli(probs=a)) ... return x >>> substituted_model = pyro.poutine.substitute(model, data={"a": torch.tensor(0.3)})
在这个例子中,站点 a 现在将具有值 torch.tensor(0.3)。 :param data: 由站点名称键控的值的字典。 :returns:
fn
装饰有SubstituteMessenger
- trace(fn: None = None, graph_type: Optional[Literal['flat', 'dense']] = None, param_only: Optional[bool] = None) pyro.poutine.trace_messenger.TraceMessenger [source]¶
- trace(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], graph_type: Optional[Literal['flat', 'dense']] = None, param_only: Optional[bool] = None) pyro.poutine.trace_messenger.TraceHandler[pyro.poutine.handlers._P, pyro.poutine.handlers._T]
TraceMessenger
的便捷包装器返回一个记录原始调用输入输出及其依赖关系的处理程序。
考虑以下Pyro程序:
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
我们可以使用
trace
记录其执行过程,并使用生成的数据结构来计算执行过程中所有采样点的对数联合概率或提取所有参数。>>> trace = pyro.poutine.trace(model).get_trace(0.0) >>> logp = trace.log_prob_sum() >>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
graph_type – 指定要构建的图形类型的字符串
param_only – 如果为true,仅记录参数而不记录样本
- Returns
使用
TraceMessenger
装饰的随机函数
- uncondition(fn: None = None) pyro.poutine.uncondition_messenger.UnconditionMessenger [source]¶
- uncondition(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T] = None) Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]
UnconditionMessenger
的便捷包装器Messenger 强制从它们的分布中采样观察节点的值,忽略观察结果。
- config_enumerate(guide=None, default='parallel', expand=False, num_samples=None, tmc='diagonal')[source]¶
配置指南中所有相关站点的枚举。这主要与
TraceEnum_ELBO
一起使用。在配置离散变量的穷举枚举时,这会配置所有分布满足
.has_enumerate_support == True
的采样点。 在通过default="parallel", num_samples=n
配置本地并行蒙特卡洛采样时,这会配置所有采样点。 这不会覆盖现有的注释infer={"enumerate": ...}
。这可以用作一个函数:
guide = config_enumerate(guide)
或者作为装饰器:
@config_enumerate def guide1(*args, **kwargs): ... @config_enumerate(default="sequential", expand=True) def guide2(*args, **kwargs): ...
- Parameters
指南 (可调用) – 一个pyro模型,将用作
SVI
中的指南。默认 (str) – 使用哪种枚举策略,可选“sequential”(顺序)、“parallel”(并行)或None。默认为“parallel”(并行)。
expand (bool) – 是否扩展枚举的样本值。详情请参见
enumerate_support()
。 这仅适用于详尽枚举,其中num_samples=None
。 如果num_samples
不是None
,则样本将始终被扩展。num_samples (int 或 None) – 如果不是
None
,则使用局部蒙特卡洛采样而不是穷举枚举。这对于连续和离散分布都有意义。tmc (string 或 None) – 在Tensor Monte Carlo中使用的“mixture”或“diagonal”策略
- Returns
一个带注释的指南
- Return type
可调用的
追踪¶
- class Trace(graph_type: Literal['flat', 'dense'] = 'flat')[source]¶
基础类:
object
表示执行跟踪中不同pyro原语之间关系的图数据结构。
Pyro程序的执行跟踪是记录该程序单次执行中对
pyro.sample()
和pyro.param()
的每次调用。跟踪是有向图,其节点表示原始调用或输入/输出,边表示这些原始调用之间的条件依赖关系。它们由poutine.trace
创建和填充。跟踪中的每个节点(或站点)包含站点的名称、输入和输出值,以及由推理算法或用户注释添加的额外元数据。在
pyro.sample
的情况下,跟踪还包括站点上的随机函数,以及用户添加的任何观察数据。考虑以下Pyro程序:
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
我们可以使用
pyro.poutine.trace
记录其执行过程,并使用生成的数据结构来计算执行过程中所有采样点的对数联合概率或提取所有参数。>>> trace = pyro.poutine.trace(model).get_trace(0.0) >>> logp = trace.log_prob_sum() >>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]
我们还可以检查或操作跟踪中的单个节点。
trace.nodes
包含一个collections.OrderedDict
的站点名称和元数据,对应于x
,s
,z
, 以及返回值:>>> list(name for name in trace.nodes.keys()) ["_INPUT", "s", "z", "_RETURN"]
trace.nodes
的值是节点元数据的字典:>>> trace.nodes["z"] {'type': 'sample', 'name': 'z', 'is_observed': False, 'fn': Normal(), 'value': tensor(0.6480), 'args': (), 'kwargs': {}, 'infer': {}, 'scale': 1.0, 'cond_indep_stack': (), 'done': True, 'stop': False, 'continuation': None}
'infer'
是用户或算法指定的元数据的字典。'args'
和'kwargs'
是通过pyro.sample
传递给fn.__call__
或fn.log_prob
的参数。'scale'
用于在计算对数联合概率时缩放站点的对数概率。'cond_indep_stack'
包含与执行中出现的pyro.plate
上下文对应的数据结构。'done'
、'stop'
和'continuation'
仅由 Pyro 的内部使用。- Parameters
graph_type (string) – 指定要构建的跟踪图类型的字符串
- add_node(site_name: str, **kwargs: Any) None [source]¶
- Parameters
site_name (string) – 要添加的站点名称
将站点添加到跟踪中。
在尝试添加重复节点时引发错误,而不是静默覆盖。
- compute_log_prob(site_filter: Callable[[str, Message], bool] = <function allow_all_sites>) None [source]¶
计算轨迹的逐点对数概率。 每个
log_prob
的形状等于相应的batch_shape
。 每个log_prob_sum
是一个标量。 这两种计算都被缓存。
- compute_score_parts() None [source]¶
计算跟踪中每个站点的批量局部分数部分。 每个
log_prob
的形状等于相应的batch_shape
。 每个log_prob_sum
是一个标量。 所有计算都被记忆化。
- copy() pyro.poutine.trace_struct.Trace [source]¶
对自身进行浅拷贝,保留节点和边。
- format_shapes(title: str = 'Trace Shapes:', last_site: Optional[str] = None) str [source]¶
返回一个字符串,显示跟踪中所有站点形状的表格。
- log_prob_sum(site_filter: Callable[[str, Message], bool] = <function allow_all_sites>) Union[torch.Tensor, float] [source]¶
计算轨迹的逐点对数概率。 每个
log_prob
的形状等于相应的batch_shape
。 每个log_prob_sum
是一个标量。log_prob_sum
的计算是记忆化的。- Returns
总对数概率。
- Return type
- pack_tensors(plate_to_symbol: Optional[Dict[str, str]] = None) None [source]¶
计算跟踪中张量的打包表示。 这应该在调用
compute_log_prob()
或compute_score_parts()
之后调用。
运行时¶
- class InferDict[source]¶
基础:
typing_extensions.TypedDict
一个包含推理信息的字典。
这可以用于配置每个站点的推理策略,例如:
pyro.sample( "x", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}, )
- Keys:
- enumerate (str):
如果字符串是“sequential”或“parallel”之一,则启用枚举。并行枚举通常更快,但需要广播安全的操作和静态结构。
- expand (bool):
是否在枚举期间扩展分布。如果缺失,默认为False。
- is_auxiliary (bool):
样本站点是否为辅助站点,例如用于确定性地转换辅助变量的指南中。如果缺失,则默认为False。
- is_observed (bool):
样本站点是否被观察到(即不是潜在的)。如果缺失,默认为False。
- num_samples (int):
要抽取的样本数量。如果缺失,默认为1。
- obs (optional torch.Tensor):
观测值,或对于潜在变量为None。如果缺失,默认为None。
- prior (optional torch.distributions.Distribution):
(内部)用于在GuideMessenger中存储模型的先验分布(基于上游站点条件)。
- tmc (str):
是否在TraceTMC_ELBO中使用对角线或混合近似进行张量蒙特卡洛。
- was_observed (bool):
(内部)样本站点是否最初被观察到,在通过重加权唤醒睡眠或编译顺序重要性采样的推断背景下。
- enumerate: typing_extensions.Literal[sequential, parallel]¶
- obs: Optional[torch.Tensor]¶
- prior: TorchDistributionMixin¶
- tmc: typing_extensions.Literal[diagonal, mixture]¶
- class Message[source]¶
基础类:
typing_extensions.TypedDict
,Generic
[pyro.poutine.runtime._P
,pyro.poutine.runtime._T
]Pyro 内部用于效果处理的消息类型。
消息存储在跟踪对象中,例如:
trace.nodes["my_site_name"] # This is a Message.
- Keys:
- type (str):
消息类型,通常是字符串“sample”、“param”、“plate”或“markov”之一,但也可能是自定义的。
- name (str):
站点名称,通常用于命名样本或参数。
- fn (callable):
用于生成样本的分布或函数。
- is_observed (bool):
一个标志,用于指示该值是否被观察到。
- args (tuple):
分布或函数的位置参数。
- kwargs (dict):
分布或函数的关键字参数。
- value (torch.Tensor):
样本的值(无论是观察到的还是采样的)。
- scale (torch.Tensor):
对数概率的缩放因子。
- mask (bool torch.Tensor):
一个布尔值或张量,用于屏蔽对数概率。
- cond_indep_stack (tuple):
站点的条件独立性元数据的本地堆栈。 不可变。
- done (bool):
一个标志,用于指示消息是否已被处理。
- stop (bool):
一个标志,用于停止进一步处理消息。
- continuation (callable):
处理消息后调用的函数。
- infer (optional InferDict):
推理参数的字典。
- obs (torch.Tensor):
观察值。
- log_prob (torch.Tensor):
样本的对数概率。
- log_prob_sum (torch.Tensor):
对数概率的总和。
- unscaled_log_prob (torch.Tensor):
未缩放的日志概率。
- score_parts (pyro.distributions.ScoreParts):
一组分数部分。
- packed (Message):
一个打包的消息,用于枚举期间。
- args: Tuple¶
- cond_indep_stack: Tuple[CondIndepStackFrame, ...]¶
- fn: Callable[[pyro.poutine.runtime._P], pyro.poutine.runtime._T]¶
- infer: Optional[pyro.poutine.runtime.InferDict]¶
- kwargs: Dict¶
- log_prob: torch.Tensor¶
- log_prob_sum: torch.Tensor¶
- mask: Optional[Union[bool, torch.Tensor]]¶
- obs: Optional[torch.Tensor]¶
- scale: Union[torch.Tensor, float]¶
- score_parts: ScoreParts¶
- unscaled_log_prob: torch.Tensor¶
- value: Optional[pyro.poutine.runtime._T]¶
- exception NonlocalExit(site: pyro.poutine.runtime.Message, *args, **kwargs)[source]¶
基础类:
Exception
用于从poutine执行中非局部退出的异常。
由 poutine.EscapeMessenger 使用以返回站点信息。
- apply_stack(initial_msg: pyro.poutine.runtime.Message) None [source]¶
根据以下方案在单个站点执行效果堆栈:
对于堆栈中从下到上的每个
Messenger
, 使用消息执行Messenger._process_message
; 如果消息字段“stop”为True,则停止; 否则,继续应用默认行为(
default_process_message
)以完成剩余的站点执行对于堆栈中的每个
Messenger
,从上到下执行_postprocess_message
,以使用站点结果更新消息和内部信使状态如果消息字段“continuation”不是
None
,则使用消息调用它
- Parameters
initial_msg (dict) – 跟踪站点的初始版本
- Returns
None
- default_process_message(msg: pyro.poutine.runtime.Message) None [source]¶
推理中处理消息的默认方法。
- Parameters
msg – 要处理的消息
- Returns
无
- effectful(fn: None = None, type: Optional[str] = None) Callable[[Callable[[pyro.poutine.runtime._P], pyro.poutine.runtime._T]], Callable[[...], pyro.poutine.runtime._T]] [source]¶
- effectful(fn: Callable[[pyro.poutine.runtime._P], pyro.poutine.runtime._T] = None, type: Optional[str] = None) Callable[[...], pyro.poutine.runtime._T]
- Parameters
fn – 执行有效计算的函数或可调用对象
type (str) – 操作的类型标签,例如 “sample”
用于调用
apply_stack()
以应用任何活动效果的包装器。
- get_mask() Optional[Union[bool, torch.Tensor]] [source]¶
记录包含
poutine.mask
处理程序的效果。这对于在预测期间避免昂贵的
pyro.factor()
计算非常有用,当不需要计算对数密度时,例如:def model(): # ... if poutine.get_mask() is not False: log_density = my_expensive_computation() pyro.factor("foo", log_density) # ...
- Returns
掩码。
- Return type
None, bool, 或 torch.Tensor
- get_plates() Tuple[CondIndepStackFrame, ...] [source]¶
记录包含
pyro.plate
上下文的效果。- Returns
一个由
pyro.poutine.indep_messenger.CondIndepStackFrame
对象组成的元组。- Return type
实用工具¶
- all_escape(trace: Trace, msg: Message) bool [source]¶
- Parameters
trace – 部分追踪
msg – Pyro 原始站点的消息
- Returns
布尔决策值
实用函数,用于检查站点是否尚未在跟踪中。
由EscapeMessenger使用,用于决定是否在某个站点进行非本地退出。 用于近似积分变量以减少方差的子程序。
- discrete_escape(trace: Trace, msg: Message) bool [source]¶
- Parameters
trace – 部分追踪
msg – Pyro 原始站点的消息
- Returns
布尔决策值
实用函数,用于检查采样点是否是离散的且尚未在跟踪中。
由EscapeMessenger使用,用于决定是否在某个站点进行非本地退出。 用于整合离散变量以减少方差的子程序。
- enum_extend(trace: Trace, msg: Message, num_samples: Optional[int] = None) List[Trace] [source]¶
- Parameters
trace – 部分追踪
msg – Pyro 原始站点的消息
num_samples – 返回的最大扩展轨迹数。
- Returns
一个轨迹列表,输入轨迹的副本,每个副本多一个站点
实用函数,用于复制并扩展一个基于输入站点的轨迹,其值从输入站点分布的支撑中枚举。
用于精确推断和整合离散变量。
消息传递者¶
Messenger 对象包含由处理程序公开的效果的实现。 高级用户可以修改现有处理程序背后的 messenger 实现,或者编写新的 messenger, 这些 messenger 实现新的效果,并与库的其余部分正确组合。
信使¶
- class Messenger[source]¶
基础类:
object
上下文管理器类,用于修改行为 并向随机函数添加副作用 即包含Pyro原始语句的可调用对象。
这是基础的Messenger类。 它实现了所有Pyro原语的默认行为, 因此由随机函数fn引起的联合分布 与由
Messenger()(fn)
引起的联合分布相同。用于推理过程中传递消息的转换器类。 大多数推理操作都在此类的子类中实现。
- classmethod register(fn: Optional[Callable] = None, type: Optional[str] = None, post: Optional[bool] = None) Callable [source]¶
- Parameters
fn – 实现操作的函数
type (str) – 操作的名称 (也传递给
effectful()
)post (bool) – 如果为True,则将此操作用作后处理
动态地向效果添加操作。 对于生成库的包装器非常有用。
示例:
@SomeMessengerClass.register def some_function(msg) ...do_something... return msg
- classmethod unregister(fn: Optional[Callable] = None, type: Optional[str] = None) Optional[Callable] [source]¶
- Parameters
fn – 实现操作的函数
type (str) – 操作的名称 (也传递给
effectful()
)
动态地从效果中移除操作。 对于从库中移除包装器非常有用。
示例:
SomeMessengerClass.unregister(some_function, "name")
- block_messengers(predicate: Callable[[pyro.poutine.messenger.Messenger], bool]) Iterator[List[pyro.poutine.messenger.Messenger]] [source]¶
实验性的上下文管理器,用于临时从 _PYRO_STACK 中移除匹配的信使。注意这不会调用
.__exit__()
和.__enter__()
方法。这对于有选择性地阻止封闭处理程序非常有用。
- Parameters
predicate (callable) – 一个将信使实例映射到布尔值的谓词。 这将静音所有满足
bool(predicate(m)) is True
的信使m
。- Yields
被阻止的匹配信使列表。
BlockMessenger¶
- class BlockMessenger(hide_fn: Optional[Callable[[Message], Optional[bool]]] = None, expose_fn: Optional[Callable[[Message], Optional[bool]]] = None, hide_all: bool = True, expose_all: bool = False, hide: Optional[List[str]] = None, expose: Optional[List[str]] = None, hide_types: Optional[List[str]] = None, expose_types: Optional[List[str]] = None)[source]¶
基础类:
pyro.poutine.messenger.Messenger
此处理程序选择性地从外部世界隐藏Pyro原始站点。 默认行为:阻止所有内容。
如果满足以下至少一个条件,则站点被隐藏:
hide_fn(msg) is True
或(not expose_fn(msg)) is True
msg["name"] in hide
msg["type"] in hide_types
msg["name"] not in expose and msg["type"] not in expose_types
hide
,hide_types
, 和expose_types
都是None
例如,假设随机函数 fn 有两个样本点“a”和“b”。 那么在
BlockMessenger(fn, hide=["a"])
之外的任何效果 都不会应用于点“a”,而只会看到点“b”:>>> def fn(): ... a = pyro.sample("a", dist.Normal(0., 1.)) ... return pyro.sample("b", dist.Normal(a, 1.)) >>> fn_inner = pyro.poutine.trace(fn) >>> fn_outer = pyro.poutine.trace(pyro.poutine.block(fn_inner, hide=["a"])) >>> trace_inner = fn_inner.get_trace() >>> trace_outer = fn_outer.get_trace() >>> "a" in trace_inner True >>> "a" in trace_outer False >>> "b" in trace_inner True >>> "b" in trace_outer True
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
hide_fn – 一个函数,它接受一个站点并返回True以隐藏该站点,或返回False/None以暴露它。如果指定了此函数,则忽略所有其他参数。只能指定hide_fn或expose_fn中的一个,不能同时指定两者。
expose_fn – 一个函数,接受一个站点并返回True以暴露该站点,或返回False/None以隐藏它。如果指定了此参数,则忽略所有其他参数。只能指定hide_fn或expose_fn中的一个,不能同时指定两者。
hide_all (bool) – 隐藏所有站点
expose_all (bool) – 正常暴露所有站点
隐藏 (list) – 要隐藏的站点名称列表
expose (list) – 要暴露的站点名称列表,其他站点将被隐藏
hide_types (list) – 要隐藏的站点类型列表
expose_types (list) – 要暴露的站点类型列表,其他所有类型将被隐藏
- Returns
使用
BlockMessenger
装饰的随机函数
广播信使¶
- class BroadcastMessenger[source]¶
基础类:
pyro.poutine.messenger.Messenger
当在单个或嵌套的plate上下文中时,自动广播随机函数在采样点的批次形状。现有的batch_shape必须与安装在cond_indep_stack中的
plate
上下文的大小可广播。注意下面的model_automatic_broadcast如何自动扩展分布批处理形状。这使得将Pyro模型模块化变得容易,因为子组件对包装的
plate
上下文是无感知的。>>> def model_broadcast_by_hand(): ... with IndepMessenger("batch", 100, dim=-2): ... with IndepMessenger("components", 3, dim=-1): ... sample = pyro.sample("sample", dist.Bernoulli(torch.ones(3) * 0.5) ... .expand_by(100)) ... assert sample.shape == torch.Size((100, 3)) ... return sample
>>> @poutine.broadcast ... def model_automatic_broadcast(): ... with IndepMessenger("batch", 100, dim=-2): ... with IndepMessenger("components", 3, dim=-1): ... sample = pyro.sample("sample", dist.Bernoulli(torch.tensor(0.5))) ... assert sample.shape == torch.Size((100, 3)) ... return sample
折叠消息器¶
- class CollapseMessenger(*args: Any, **kwargs: Any)[source]¶
基础类:
pyro.poutine.trace_messenger.TraceMessenger
实验性功能:通过懒采样并尝试使用共轭关系来折叠上下文中的所有站点。如果不知道共轭关系,这将失败。使用采样站点结果的代码必须编写为接受Funsors而不是Tensors。这需要安装
funsor
。警告
这与自动猜测
max_plate_nesting
不兼容。如果在折叠的上下文中出现任何板块,您应该手动向您的推理算法声明max_plate_nesting
(例如Trace_ELBO(max_plate_nesting=1)
)。
条件信使¶
- class ConditionMessenger(data: Union[Dict[str, torch.Tensor], pyro.poutine.trace_struct.Trace])[source]¶
基础类:
pyro.poutine.messenger.Messenger
给定一个包含一些样本语句的随机函数和一个名称的观察字典, 将那些名称处的样本语句更改为具有这些值的观察。
考虑以下Pyro程序:
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
要观察站点 z 的值,我们可以写
>>> conditioned_model = pyro.poutine.condition(model, data={"z": torch.tensor(1.)})
这相当于在model中将obs=value作为关键字参数添加到pyro.sample(“z”, …)中。
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
data – 一个字典或一个
Trace
- Returns
使用
ConditionMessenger
装饰的随机函数
DoMessenger¶
- class DoMessenger(data: Dict[str, Union[torch.Tensor, numbers.Number]])[source]¶
基础类:
pyro.poutine.messenger.Messenger
给定一个带有一些样本语句的随机函数和一个名称值的字典, 将这些站点的返回值设置为这些值, 就像它们被硬编码为这些值一样, 并引入具有相同名称的新样本站点, 这些站点的值不会传播。
与
condition()
自由组合,以表示潜在结果的反事实分布。 有关更多细节和理论,请参见单世界干预图[1]。考虑以下Pyro程序:
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
要为站点 z 干预一个值,我们可以写
>>> intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})
这相当于将z = pyro.sample(“z”, …)替换为 z = torch.tensor(1.) 并引入一个新的采样点pyro.sample(“z”, …),其值在其他地方未使用。
参考文献
- [1] Single World Intervention Graphs: A Primer,
托马斯·理查森,詹姆斯·罗宾斯
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
data – 一个将样本站点名称映射到干预措施的
dict
- Returns
使用
DoMessenger
装饰的随机函数
枚举消息传递器¶
- class EnumMessenger(first_available_dim: Optional[int] = None)[source]¶
基础类:
pyro.poutine.messenger.Messenger
并行枚举标记为
infer={"enumerate": "parallel"}
的离散样本点。- Parameters
first_available_dim (int) – 第一个可用于并行枚举的张量维度(从右开始计数)。这个维度和所有左侧的维度可能会被Pyro内部使用。这应该是一个负整数或None。
- enumerate_site(msg: pyro.poutine.runtime.Message) torch.Tensor [source]¶
EscapeMessenger¶
- class EscapeMessenger(escape_fn: Callable[[pyro.poutine.runtime.Message], bool])[source]¶
基础类:
pyro.poutine.messenger.Messenger
通过引发 util.NonlocalExit 异常来进行非局部退出的 Messenger
独立消息传递器¶
- class IndepMessenger(name: str, size: int, dim: Optional[int] = None, device: Optional[str] = None)[源代码]¶
基础类:
pyro.poutine.messenger.Messenger
这个信使跟踪由嵌套的
plate
上下文声明的独立性信息堆栈。这些信息存储在cond_indep_stack
中,在每个采样/观察站点供TraceMessenger
使用。示例:
x_axis = IndepMessenger('outer', 320, dim=-1) y_axis = IndepMessenger('inner', 200, dim=-2) with x_axis: x_noise = sample("x_noise", dist.Normal(loc, scale).expand_by([320])) with y_axis: y_noise = sample("y_noise", dist.Normal(loc, scale).expand_by([200, 1])) with x_axis, y_axis: xy_noise = sample("xy_noise", dist.Normal(loc, scale).expand_by([200, 320]))
- property indices: torch.Tensor¶
InferConfigMessenger¶
- class InferConfigMessenger(config_fn: Callable[[Message], InferDict])[source]¶
基础类:
pyro.poutine.messenger.Messenger
给定一个包含Pyro原始调用的可调用对象fn和一个接受跟踪站点并返回字典的可调用对象config_fn,将样本站点的infer kwarg的值更新为config_fn(site)。
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
config_fn – 一个可调用对象,接收一个站点并返回一个推断字典
- Returns
使用
InferConfigMessenger
装饰的随机函数
LiftMessenger¶
- class LiftMessenger(prior: Union[Callable, pyro.distributions.distribution.Distribution, Dict[str, Union[pyro.distributions.distribution.Distribution, Callable]]])[源代码]¶
基础类:
pyro.poutine.messenger.Messenger
给定一个带有参数调用的随机函数和一个先验分布,创建一个随机函数,其中所有参数调用都被从先验中采样所替代。先验应该是一个可调用对象或一个名称到可调用对象的字典。
考虑以下Pyro程序:
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2 >>> lifted_model = pyro.poutine.lift(model, prior={"s": dist.Exponential(0.3)})
lift
使param
语句的行为类似于sample
语句 使用prior
中的分布。在这个例子中,站点 s 现在将表现得 好像它被替换为s = pyro.sample("s", dist.Exponential(0.3))
:>>> tr = pyro.poutine.trace(lifted_model).get_trace(0.0) >>> tr.nodes["s"]["type"] == "sample" True >>> tr2 = pyro.poutine.trace(lifted_model).get_trace(0.0) >>> bool((tr2.nodes["s"]["value"] == tr.nodes["s"]["value"]).all()) False
- Parameters
fn – 其参数将被提升为随机值的函数
prior – 以Distribution形式或随机函数字典形式存在的先验函数
- Returns
fn
使用LiftMessenger
装饰
马尔科夫信使¶
MaskMessenger¶
- class MaskMessenger(mask: Union[bool, torch.BoolTensor])[source]¶
基础类:
pyro.poutine.messenger.Messenger
给定一个带有一些批量样本语句和掩码张量的随机函数,按元素掩码掉一些样本语句。
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
mask (torch.BoolTensor) – 一个
{0,1}
值的掩码张量 (1 包含一个站点,0 排除一个站点)
- Returns
使用
MaskMessenger
装饰的随机函数
PlateMessenger¶
- class PlateMessenger(name: str, size: Optional[int] = None, subsample_size: Optional[int] = None, subsample: Optional[torch.Tensor] = None, dim: Optional[int] = None, use_cuda: Optional[bool] = None, device: Optional[str] = None)[source]¶
基础类:
pyro.poutine.subsample_messenger.SubsampleMessenger
广播功能的瑞士军刀: 结合了形状推断、独立性注释和子采样
- block_plate(name: Optional[str] = None, dim: Optional[int] = None, *, strict: bool = True) Iterator[None] [source]¶
实验性的上下文管理器,用于临时阻止单个封闭的板块。
这对于采样辅助变量或在需要时延迟采样全局变量非常有用。例如,以下模型是等效的:
示例:
def model_1(data): loc = pyro.sample("loc", dist.Normal(0, 1)) with pyro.plate("data", len(data)): with block_plate("data"): scale = pyro.sample("scale", dist.LogNormal(0, 1)) pyro.sample("x", dist.Normal(loc, scale)) def model_2(data): loc = pyro.sample("loc", dist.Normal(0, 1)) scale = pyro.sample("scale", dist.LogNormal(0, 1)) with pyro.plate("data", len(data)): pyro.sample("x", dist.Normal(loc, scale))
可重入信使¶
ReparamMessenger¶
- class ReparamHandler(msngr, fn: Callable[[pyro.poutine.reparam_messenger._P], pyro.poutine.reparam_messenger._T])[source]¶
基础类:
Generic
[pyro.poutine.reparam_messenger._P
,pyro.poutine.reparam_messenger._T
]重新参数化程序。
- class ReparamMessenger(config: Union[Dict[str, Reparam], Callable[[Message], Optional[Reparam]]])[源代码]¶
基础类:
pyro.poutine.messenger.Messenger
将每个受影响的采样点重新参数化为一个或多个辅助采样点,然后进行确定性变换 [1]。
要指定重新参数化器,请将
config
字典或可调用对象传递给构造函数。有关可用的重新参数化器,请参见pyro.infer.reparam
模块。请注意,一些重新参数化器可以检查它们影响的函数的
*args,**kwargs
输入;这些重新参数化器需要使用poutine.reparam
作为装饰器而不是上下文管理器。- [1] Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019)
“概率程序的自动重新参数化” https://arxiv.org/pdf/1906.03028.pdf
- Parameters
config (dict 或 callable) – 配置,可以是一个将站点名称映射到
Reparameterizer
的字典,或者是一个将站点映射到Reparam
或 None 的函数。有关内置的配置策略,请参见pyro.infer.reparam.strategies
。
重放消息器¶
- class ReplayMessenger(trace: Optional[Trace] = None, params: Optional[Dict[str, torch.Tensor]] = None)[source]¶
基础类:
pyro.poutine.messenger.Messenger
给定一个包含Pyro原始调用的可调用对象, 返回一个运行原始可调用对象的可调用对象,在新跟踪中重用跟踪中的值
考虑以下Pyro程序:
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
replay
使得sample
语句表现得好像它们在跟踪中的相应位置采样了值一样:>>> old_trace = pyro.poutine.trace(model).get_trace(1.0) >>> replayed_model = pyro.poutine.replay(model, trace=old_trace) >>> bool(replayed_model(0.0) == old_trace.nodes["_RETURN"]["value"]) True
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
trace – 一个用于回放的
Trace
数据结构params – 参数字典,包含参数名称和约束值,用于在函数中重放
- Returns
一个用
ReplayMessenger
装饰的随机函数
ScaleMessenger¶
- class ScaleMessenger(scale: Union[float, torch.Tensor])[source]¶
基础类:
pyro.poutine.messenger.Messenger
给定一个包含一些样本语句和正比例因子的随机函数,缩放函数中所有样本和观察点的分数。
考虑以下Pyro程序:
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... pyro.sample("z", dist.Normal(x, s), obs=torch.tensor(1.0))
scale
以乘法方式缩放样本站点的对数概率:>>> scaled_model = pyro.poutine.scale(model, scale=0.5) >>> scaled_tr = pyro.poutine.trace(scaled_model).get_trace(0.0) >>> unscaled_tr = pyro.poutine.trace(model).get_trace(0.0) >>> bool((scaled_tr.log_prob_sum() == 0.5 * unscaled_tr.log_prob_sum()).all()) True
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
scale – 一个正的比例因子
- Returns
使用
ScaleMessenger
装饰的随机函数
SeedMessenger¶
- class SeedMessenger(rng_seed: int)[source]¶
基础类:
pyro.poutine.messenger.Messenger
处理程序通过设置种子将随机数生成器设置为预定义状态。这与在调用fn之前调用
pyro.set_rng_seed()
相同。此处理程序对标准Pyro后端上的原始语句没有额外影响,但它可能会在其他后端(例如NumPy后端)中拦截pyro.sample
调用。- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)。
rng_seed (int) – 随机数生成器种子。
子样本信使¶
- class SubsampleMessenger(name: str, size: Optional[int] = None, subsample_size: Optional[int] = None, subsample: Optional[torch.Tensor] = None, dim: Optional[int] = None, use_cuda: Optional[bool] = None, device: Optional[str] = None)[source]¶
基础类:
pyro.poutine.indep_messenger.IndepMessenger
IndepMessenger 的扩展,包括子采样。
替代信使¶
- class SubstituteMessenger(data: Dict[str, torch.Tensor])[source]¶
基础类:
pyro.poutine.messenger.Messenger
给定一个带有参数调用的随机函数和一组参数值,创建一个随机函数,其中所有参数调用都被替换为固定值。数据应该是一个名称到值的字典。考虑以下Pyro程序:
>>> def model(x): ... a = pyro.param("a", torch.tensor(0.5)) ... x = pyro.sample("x", dist.Bernoulli(probs=a)) ... return x >>> substituted_model = pyro.poutine.substitute(model, data={"a": torch.tensor(0.3)})
在这个例子中,站点 a 现在将具有值 torch.tensor(0.3)。 :param data: 由站点名称键控的值的字典。 :returns:
fn
装饰有SubstituteMessenger
TraceMessenger¶
- class TraceHandler(msngr: pyro.poutine.trace_messenger.TraceMessenger, fn: Callable[[pyro.poutine.trace_messenger._P], pyro.poutine.trace_messenger._T])[源代码]¶
基础类:
Generic
[pyro.poutine.trace_messenger._P
,pyro.poutine.trace_messenger._T
]执行跟踪程序。
TraceHandler 记录每个 Pyro 原语的输入和输出,并将它们存储为 Trace() 中的一个站点。 理论上,这应该足以提供每个推理算法所需的信息 (以及 Variables 中的隐式计算图?)
我们也可以将其用于可视化。
- get_trace(*args, **kwargs) pyro.poutine.trace_struct.Trace [源代码]¶
- Returns
数据结构
- Return type
用于非常常见的使用场景的辅助方法。 调用此poutine并返回其跟踪,而不是函数的返回值。
- property trace: pyro.poutine.trace_struct.Trace¶
- class TraceMessenger(graph_type: Optional[Literal['flat', 'dense']] = None, param_only: Optional[bool] = None)[源代码]¶
基础类:
pyro.poutine.messenger.Messenger
返回一个记录原始调用输入输出及其依赖关系的处理程序。
考虑以下Pyro程序:
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
我们可以使用
trace
记录其执行过程,并使用生成的数据结构来计算执行过程中所有采样点的对数联合概率或提取所有参数。>>> trace = pyro.poutine.trace(model).get_trace(0.0) >>> logp = trace.log_prob_sum() >>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
graph_type – 指定要构建的图形类型的字符串
param_only – 如果为true,仅记录参数而不记录样本
- Returns
使用
TraceMessenger
装饰的随机函数
- get_trace() pyro.poutine.trace_struct.Trace [source]¶
- Returns
数据结构
- Return type
用于非常常见用例的辅助方法。 返回
self.trace
的浅拷贝。
- identify_dense_edges(trace: pyro.poutine.trace_struct.Trace) None [source]¶
通过根据每个站点存储的cond_indep_stack信息添加所有边来就地修改跟踪。
无条件信使¶
- class UnconditionMessenger[source]¶
基础类:
pyro.poutine.messenger.Messenger
Messenger 强制从它们的分布中采样观察节点的值,忽略观察结果。
指南信使¶
- class GuideMessenger(model: Callable)[source]¶
基础类:
pyro.poutine.trace_messenger.TraceMessenger
,abc.ABC
基于效果的指南的抽象基类。
派生类必须实现
get_posterior()
方法。- property model: Callable¶
- __call__(*args, **kwargs) Dict[str, torch.Tensor] [source]¶
从指南中抽取后验样本,并针对这些样本重放模型。
- Returns
一个将样本站点名称映射到样本值的字典。 这包括潜在值、确定值和观测值。
- Return type
- abstract get_posterior(name: str, prior: TorchDistributionMixin) Union[TorchDistributionMixin, torch.Tensor] [source]¶
抽象方法,用于计算后验分布或在给定上游后验样本的条件下采样后验值。
实现可以在此函数内部使用
pyro.param
和pyro.sample
,但pyro.sample
语句应设置infer={"is_auxiliary": True"}
。实现可能会访问更多信息以进行计算:
value = self.upstream_value(name)
is the value of an upstream样本或确定性站点。
self.trace
是上游站点的跟踪信息,可能对其他信息有用,例如self.trace.nodes["my_site"]["fn"]
或self.trace.nodes["my_site"]["cond_indep_stack"]
。args, kwargs = self.args_kwargs
are the inputs to the model, and可能对摊销有用。
- Parameters
name (str) – 要采样的样本站点的名称。
prior (Distribution) – 该样本点的先验分布 (以后验的上游样本为条件)。
- Returns
后验分布或来自后验分布的样本。
- Return type
- upstream_value(name: str) Optional[torch.Tensor] [source]¶
用于
get_posterior()
。- Returns
上游样本或确定性站点的值
- Return type
- get_traces() Tuple[pyro.poutine.trace_struct.Trace, pyro.poutine.trace_struct.Trace] [source]¶
这可以在运行
__call__()
之后调用,以提取一对轨迹。与生成一对轨迹的跟踪重放模式不同,
GuideMessenger
交错进行模型和指导计算,因此 只需要一次guide(*args, **kwargs)
调用即可创建两个 轨迹。此函数仅从该指导的.trace
属性中提取相关信息。- Returns
一对
(model_trace, guide_trace)
- Return type