基于Funsor的Pyro¶
基本元素¶
- clear_param_store() None[source]¶
清除全局的
ParamStoreDict。这在REPL环境中尤其有用。我们建议在每次训练循环之前调用此函数(以避免从过去的模型中泄露参数),并在每个单元测试之前调用(以避免在测试之间泄露参数)。
- condition(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Union[Dict[str, torch.Tensor], Trace]) Union[pyro.poutine.condition_messenger.ConditionMessenger, Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]][source]¶
ConditionMessenger的便捷包装器给定一个包含一些样本语句的随机函数和一个名称的观察字典, 将那些名称处的样本语句更改为具有这些值的观察。
考虑以下Pyro程序:
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
要观察站点 z 的值,我们可以写
>>> conditioned_model = pyro.poutine.condition(model, data={"z": torch.tensor(1.)})
这相当于在model中将obs=value作为关键字参数添加到pyro.sample(“z”, …)中。
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
data – 一个字典或一个
Trace
- Returns
使用
ConditionMessenger装饰的随机函数
- deterministic(name: str, value: torch.Tensor, event_dim: Optional[int] = None) torch.Tensor[source]¶
确定性语句,用于向跟踪中添加一个名为name且值为value的
Delta站点。当我们想要记录完全由其父节点确定的值时,这非常有用。例如:x = pyro.sample("x", dist.Normal(0, 1)) x2 = pyro.deterministic("x2", x ** 2)
注意
该站点不影响模型密度。目前这转换为一个
sample()语句,但未来可能会改变。- Parameters
name (str) – 站点的名称。
value (torch.Tensor) – 站点的值。
event_dim (int) – 可选的事件维度,默认为 value.ndim。
- do(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Dict[str, Union[torch.Tensor, numbers.Number]]) Union[pyro.poutine.do_messenger.DoMessenger, Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]][source]¶
DoMessenger的便捷包装器给定一个带有一些样本语句的随机函数和一个名称值的字典, 将这些站点的返回值设置为这些值, 就像它们被硬编码为这些值一样, 并引入具有相同名称的新样本站点, 这些站点的值不会传播。
与
condition()自由组合,以表示潜在结果的反事实分布。有关更多细节和理论,请参见单世界干预图[1]。考虑以下Pyro程序:
>>> def model(x): ... s = pyro.param("s", torch.tensor(0.5)) ... z = pyro.sample("z", dist.Normal(x, s)) ... return z ** 2
要为站点 z 干预一个值,我们可以写
>>> intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})
这相当于将z = pyro.sample(“z”, …)替换为 z = torch.tensor(1.) 并引入一个新的采样点pyro.sample(“z”, …),其值在其他地方未使用。
参考文献
- [1] Single World Intervention Graphs: A Primer,
托马斯·理查森,詹姆斯·罗宾斯
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
data – 一个将样本站点名称映射到干预措施的
dict
- Returns
使用
DoMessenger装饰的随机函数
- enable_validation(is_validate: bool = True) None[source]¶
启用或禁用Pyro中的验证检查。验证检查提供了有用的警告和错误,例如NaN检查、验证分布参数和支持值、检测ELBO和MCMC的错误使用。由于其中一些检查可能很昂贵,您可能希望禁用成熟模型的验证以加快推理速度。
默认行为模仿了Python的
assert语句:默认情况下验证是开启的,但如果Python以优化模式运行(通过python -O),则会被禁用。同样地,默认行为取决于Python的全局__debug__值,通过pyro.enable_validation(__debug__)来实现。在JIT编译期间,所有支持PyTorch JIT的推理算法暂时禁用验证。我们建议使用非JIT编译的推理算法开发模型以便于调试,然后在模型正确后选择性地切换到JIT编译的推理算法。
- Parameters
is_validate (bool) – (可选;默认为 True) 是否启用验证检查。
- factor(name: str, log_factor: torch.Tensor, *, has_rsample: Optional[bool] = None) None[source]¶
因子语句,用于向概率模型添加任意对数概率因子。
警告
在指南中使用因子语句时,您需要指定因子语句是来自完全重新参数化的采样(例如,重新参数化变量变换的雅可比行列式)还是来自非重新参数化的采样(例如,离散样本)。对于完全重新参数化的情况,设置
has_rsample=True;对于非重新参数化的情况,设置has_rsample=False。这仅在指南中需要,而不是在模型中。- Parameters
name (str) – 简单样本的名称
log_factor (torch.Tensor) – 一个可能是批处理的对数概率因子。
has_rsample (bool) – 是否
log_factor来自完全重新参数化的分布。在模型中使用时默认为False,但在指南中使用时必须指定。
- get_param_store() pyro.params.param_store.ParamStoreDict[source]¶
返回全局的
ParamStoreDict。
- markov(fn=None, history=1, keep=False)[源代码]¶
MarkovMessenger的便捷包装器用于与Pyro的位置批处理维度一致的funsors之间转换的处理程序。
- module(name: str, nn_module: torch.nn.modules.module.Module, update_module_params: bool = False) torch.nn.modules.module.Module[源代码]¶
将
torch.nn.Module的所有参数注册到Pyro的param_store中。结合ParamStoreDict的save()和load()功能,这允许用户保存和加载模块。注意
考虑使用
PyroModule,这是一个较新的替代方案,相比pyro.module(),它在以下方面有更好的支持:即时编译(jitting)、在C++中服务以及将参数转换为随机变量。详情请参阅模块教程。- Parameters
name (str) – 模块的名称
nn_module (torch.nn.Module) – 要注册到Pyro的模块
update_module_params – 确定是否用ParamStore中找到的值覆盖PyTorch模块中的参数(如果有的话)。默认为False
- Returns
torch.nn.Module
- param(name: str, init_tensor: Optional[Union[torch.Tensor, Callable[[], torch.Tensor]]] = None, constraint: torch.distributions.constraints.Constraint = Real(), event_dim: Optional[int] = None) torch.Tensor[source]¶
将变量保存为参数存储中的参数。 要与参数存储交互或写入磁盘, 请参阅参数。
- Parameters
name (str) – 参数的名称
init_tensor (torch.Tensor 或 可调用) – 初始张量或返回张量的惰性可调用对象。 对于大型张量,编写例如
lambda: torch.randn(100000)可能更便宜,它只会在初始语句中求值。约束 (torch.distributions.constraints.Constraint) – torch 约束,默认为
constraints.real。event_dim (int) – (可选)与批处理无关的最右侧维度的数量。此维度左侧的维度将被视为批处理维度;如果参数语句位于子采样的plate中,则参数的相应批处理维度将相应地进行子采样。如果未指定,所有维度将被视为事件维度,并且不会执行子采样。
- Returns
一个受约束的参数。可以通过
pyro.param(...).unconstrained()访问底层未受约束的参数,其中.unconstrained是一个弱引用属性。- Return type
- random_module(name, nn_module, prior, *args, **kwargs)[source]¶
警告
random_module 原语已被弃用,并将在未来的版本中移除。请改用
PyroModule来从torch.nn.Module实例创建贝叶斯模块。有关示例,请参阅 贝叶斯回归教程。已弃用 在模块 nn_module 的参数上放置先验。 返回一个可调用的 nn.Module 分布,调用时 返回一个采样的 nn.Module。
- Parameters
name (str) – pyro模块的名称
nn_module (torch.nn.Module) – 要注册到pyro的模块
prior – pyro 分布、随机函数或带有参数名称作为键和相应分布/随机函数作为值的 python 字典。
- Returns
一个可调用的函数,返回一个采样模块
- sample(name: str, fn: pyro.distributions.torch_distribution.TorchDistributionMixin, *args, obs: Optional[torch.Tensor] = None, obs_mask: Optional[torch.BoolTensor] = None, infer: Optional[pyro.poutine.runtime.InferDict] = None, **kwargs) torch.Tensor[source]¶
调用随机函数
fn,根据name和封闭上下文(例如推理算法)产生额外的副作用。有关讨论,请参见 Introduction to Pyro。
- subsample(data: torch.Tensor, event_dim: int) torch.Tensor[source]¶
子采样语句,用于基于封闭的
plate对数据张量进行子采样。这通常在自动通过
plate进行子采样时,通过传递subsample或subsample_size关键字参数来调用model()的参数。例如,以下是等价的:# Version 1. using indexing def model(data): with pyro.plate("data", len(data), subsample_size=10, dim=-data.dim()) as ind: data = data[ind] # ... # Version 2. using pyro.subsample() def model(data): with pyro.plate("data", len(data), subsample_size=10, dim=-data.dim()): data = pyro.subsample(data, event_dim=0) # ...
- vectorized_markov(fn=None, name=None, size=None, dim=None, history=1)[source]¶
VectorizedMarkovMessenger的便捷封装为变量设计的马尔可夫链构造,旨在使用并行扫描算法高效消除马尔可夫维度。在允许的情况下,
vectorized_markov可以与markov互换使用。for循环生成
int和一维torch.Tensor索引:(0, ..., history-1, torch.arange(0, size-history), ..., torch.arange(history, size))。int索引用于初始化马尔可夫链,而torch.Tensor索引 用于构建向量化的转移概率,以便通过并行扫描算法进行高效消除。当
history==0时,vectorized_markov的行为类似于plate。在for循环运行后,Markov变量被识别,然后
step信息被构建并添加到跟踪中。step通知推理算法哪些变量属于Markov链。data = torch.ones(3, dtype=torch.float) def model(data, vectorized=True): init = pyro.param("init", lambda: torch.rand(3), constraint=constraints.simplex) trans = pyro.param("trans", lambda: torch.rand((3, 3)), constraint=constraints.simplex) locs = pyro.param("locs", lambda: torch.rand(3,)) markov_chain = \ pyro.vectorized_markov(name="time", size=len(data), dim=-1) if vectorized \ else pyro.markov(range(len(data))) for i in markov_chain: x_curr = pyro.sample("x_{}".format(i), dist.Categorical( init if isinstance(i, int) and i < 1 else trans[x_prev]), pyro.sample("y_{}".format(i), dist.Normal(Vindex(locs)[..., x_curr], 1.), obs=data[i]) x_prev = x_curr # trace.nodes["time"]["value"] # frozenset({('x_0', 'x_slice(0, 2, None)', 'x_slice(1, 3, None)')}) # # pyro.vectorized_markov trace # ... # Sample Sites: # locs dist | 3 # value | 3 # log_prob | # x_0 dist | # value 3 1 1 1 1 | # log_prob 3 1 1 1 1 | # y_0 dist 3 1 1 1 1 | # value | # log_prob 3 1 1 1 1 | # x_slice(1, 3, None) dist 3 1 1 1 1 2 | # value 3 1 1 1 1 1 1 | # log_prob 3 3 1 1 1 1 2 | # y_slice(1, 3, None) dist 3 1 1 1 1 1 2 | # value 2 | # log_prob 3 1 1 1 1 1 2 | # # pyro.markov trace # ... # Sample Sites: # locs dist | 3 # value | 3 # log_prob | # x_0 dist | # value 3 1 1 1 1 | # log_prob 3 1 1 1 1 | # y_0 dist 3 1 1 1 1 | # value | # log_prob 3 1 1 1 1 | # x_1 dist 3 1 1 1 1 | # value 3 1 1 1 1 1 | # log_prob 3 3 1 1 1 1 | # y_1 dist 3 1 1 1 1 1 | # value | # log_prob 3 1 1 1 1 1 | # x_2 dist 3 1 1 1 1 1 | # value 3 1 1 1 1 | # log_prob 3 3 1 1 1 1 | # y_2 dist 3 1 1 1 1 | # value | # log_prob 3 1 1 1 1 |
警告
这仅在每个分支只有一个马尔可夫维度时有效。
- Parameters
- Returns
返回
int和一维torch.Tensor索引:(0, ..., history-1, torch.arange(size-history), ..., torch.arange(history, size))。
效果处理器¶
- enum(fn=None, first_available_dim=None)[source]¶
EnumMessenger的便捷包装器此版本的
EnumMessenger使用to_data()为每个离散样本站点分配一个新的枚举维度。
- markov(fn=None, history=1, keep=False)[源代码]¶
MarkovMessenger的便捷包装器用于与Pyro的位置批处理维度一致的funsors之间转换的处理程序。
- named(fn=None, first_available_dim=None)[source]¶
NamedMessenger的便捷包装器用于
to_funsor()和to_data()原语的基础效果处理类。 任何在内部调用这些原语或包装调用这些原语的代码的效果处理程序都应继承自NamedMessenger。这种设计确保了在处理器退出时全局名称-维度映射会被重置,而不是可能持续到整个程序终止。
- plate(fn=None, name=None, size=None, subsample_size=None, subsample=None, dim=None, use_cuda=None, device=None)[source]¶
PlateMessenger的便捷包装器将新的
IndepMessenger实现与现有的pyro.poutine.BroadcastMessenger结合。最终应该可以替代pyro.plate。
- replay(fn=None, trace=None, params=None)[source]¶
ReplayMessenger的便捷包装器这个版本的
ReplayMessenger几乎与原始版本相同, 除了它在重放的funsor值上调用to_data()。 这可能会导致不同的解包形状,但应该会产生正确的分配。
- trace(fn=None, graph_type=None, param_only=None, pack_online=True)[source]¶
TraceMessenger的便捷包装器设置
pack_online=True会在线打包而不是事后打包,一旦所有分布和值可用,立即将它们转换为 Funsors。设置
pack_online=False计算执行后打包所需的信息。 每个样本站点都带有一个dim_to_name字典的注释, 可以直接传递给to_funsor()。
- vectorized_markov(fn=None, name=None, size=None, dim=None, history=1)[source]¶
VectorizedMarkovMessenger的便捷封装为变量设计的马尔可夫链构造,旨在使用并行扫描算法高效消除马尔可夫维度。在允许的情况下,
vectorized_markov可以与markov互换使用。for循环生成
int和一维torch.Tensor索引:(0, ..., history-1, torch.arange(0, size-history), ..., torch.arange(history, size))。int索引用于初始化马尔可夫链,而torch.Tensor索引 用于构建向量化的转移概率,以便通过并行扫描算法进行高效消除。当
history==0时,vectorized_markov的行为类似于plate。在for循环运行后,Markov变量被识别,然后
step信息被构建并添加到跟踪中。step通知推理算法哪些变量属于Markov链。data = torch.ones(3, dtype=torch.float) def model(data, vectorized=True): init = pyro.param("init", lambda: torch.rand(3), constraint=constraints.simplex) trans = pyro.param("trans", lambda: torch.rand((3, 3)), constraint=constraints.simplex) locs = pyro.param("locs", lambda: torch.rand(3,)) markov_chain = \ pyro.vectorized_markov(name="time", size=len(data), dim=-1) if vectorized \ else pyro.markov(range(len(data))) for i in markov_chain: x_curr = pyro.sample("x_{}".format(i), dist.Categorical( init if isinstance(i, int) and i < 1 else trans[x_prev]), pyro.sample("y_{}".format(i), dist.Normal(Vindex(locs)[..., x_curr], 1.), obs=data[i]) x_prev = x_curr # trace.nodes["time"]["value"] # frozenset({('x_0', 'x_slice(0, 2, None)', 'x_slice(1, 3, None)')}) # # pyro.vectorized_markov trace # ... # Sample Sites: # locs dist | 3 # value | 3 # log_prob | # x_0 dist | # value 3 1 1 1 1 | # log_prob 3 1 1 1 1 | # y_0 dist 3 1 1 1 1 | # value | # log_prob 3 1 1 1 1 | # x_slice(1, 3, None) dist 3 1 1 1 1 2 | # value 3 1 1 1 1 1 1 | # log_prob 3 3 1 1 1 1 2 | # y_slice(1, 3, None) dist 3 1 1 1 1 1 2 | # value 2 | # log_prob 3 1 1 1 1 1 2 | # # pyro.markov trace # ... # Sample Sites: # locs dist | 3 # value | 3 # log_prob | # x_0 dist | # value 3 1 1 1 1 | # log_prob 3 1 1 1 1 | # y_0 dist 3 1 1 1 1 | # value | # log_prob 3 1 1 1 1 | # x_1 dist 3 1 1 1 1 | # value 3 1 1 1 1 1 | # log_prob 3 3 1 1 1 1 | # y_1 dist 3 1 1 1 1 1 | # value | # log_prob 3 1 1 1 1 1 | # x_2 dist 3 1 1 1 1 1 | # value 3 1 1 1 1 | # log_prob 3 3 1 1 1 1 | # y_2 dist 3 1 1 1 1 | # value | # log_prob 3 1 1 1 1 |
警告
这仅在每个分支只有一个马尔可夫维度时有效。
- Parameters
- Returns
返回
int和一维torch.Tensor索引:(0, ..., history-1, torch.arange(size-history), ..., torch.arange(history, size))。
- class NamedMessenger(first_available_dim=None)[源代码]¶
基础类:
pyro.poutine.reentrant_messenger.ReentrantMessenger用于
to_funsor()和to_data()原语的基础效果处理程序类。 任何在内部调用这些原语或包装执行这些操作的代码的效果处理程序都应继承自NamedMessenger。这种设计确保了在处理器退出时全局名称-维度映射会被重置,而不是可能持续到整个程序终止。
- class MarkovMessenger(history=1, keep=False)[源代码]¶
基础类:
pyro.contrib.funsor.handlers.named_messenger.NamedMessenger用于与Pyro的位置批处理维度一致的funsors之间转换的处理程序。
- class GlobalNamedMessenger(first_available_dim=None)[source]¶
基础类:
pyro.contrib.funsor.handlers.named_messenger.NamedMessenger任何使用
to_funsor()和to_data()原语来分配DimType.GLOBAL或DimType.VISIBLE维度的新效果处理程序的基类。作为不应被
MarkovMessenger回收的维度的手动“作用域”:全局维度将被视为活动状态,直到最初分配它们的GlobalNamedMessenger最内层退出。
- class StackFrame(name_to_dim, dim_to_name, history=1, keep=False)[source]¶
基础类:
object整数位置维度和名称之间的一致双向映射。 可以像字典一样查询(
value = frame[key],frame[key] = value)。
- class DimStack[source]¶
基础类:
object用于跟踪名称和维度之间映射的全局状态。
替换了板
_DimAllocator, 枚举_EnumAllocator,MarkovMessenger中的stack,EnumMessenger中的_param_dims和_value_dims,以及msg['infer']中的dim_to_symbol- MAX_DIM = -25¶
- DEFAULT_FIRST_DIM = -5¶
- property global_frame¶
- property local_frame¶
- property current_write_env¶
- property current_read_env¶
收集所有必要的框架以计算全名 <–> 维度映射,并在计算中的任何点解释Funsor输入或批量形状。
推理算法¶
- class ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]¶
-
- loss_and_grads(model, guide, *args, **kwargs)[source]¶
参见
pyro.infer.traceenum_elbo.TraceEnum_ELBO.loss_and_grads()
- class Jit_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]¶
- class Trace_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]¶
- class JitTrace_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[源代码]¶
基础类:
pyro.contrib.funsor.infer.elbo.Jit_ELBO,pyro.contrib.funsor.infer.trace_elbo.Trace_ELBO
- class TraceMarkovEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]¶
基础类:
pyro.contrib.funsor.infer.elbo.ELBO- differentiable_loss(model, guide, *args, **kwargs)[source]¶
参见
pyro.infer.traceenum_elbo.TraceEnum_ELBO.differentiable_loss()
- class TraceEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]¶
基础类:
pyro.contrib.funsor.infer.elbo.ELBO- differentiable_loss(model, guide, *args, **kwargs)[source]¶
参见
pyro.infer.traceenum_elbo.TraceEnum_ELBO.differentiable_loss()
- class JitTraceEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]¶
基础类:
pyro.contrib.funsor.infer.elbo.Jit_ELBO,pyro.contrib.funsor.infer.traceenum_elbo.TraceEnum_ELBO
- class JitTraceMarkovEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]¶
基础类:
pyro.contrib.funsor.infer.elbo.Jit_ELBO,pyro.contrib.funsor.infer.traceenum_elbo.TraceMarkovEnum_ELBO
- class TraceTMC_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]¶
基础类:
pyro.contrib.funsor.infer.elbo.ELBO- differentiable_loss(model, guide, *args, **kwargs)[source]¶
参见
pyro.infer.tracetmc_elbo.TraceTMC_ELBO.differentiable_loss()
- class JitTraceTMC_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[源代码]¶
基础类:
pyro.contrib.funsor.infer.elbo.Jit_ELBO,pyro.contrib.funsor.infer.tracetmc_elbo.TraceTMC_ELBO