基本元素

get_param_store() pyro.params.param_store.ParamStoreDict[source]

返回全局的ParamStoreDict

clear_param_store() None[source]

清除全局的ParamStoreDict

这在REPL环境中尤其有用。我们建议在每次训练循环之前调用此函数(以避免从过去的模型中泄露参数),并在每个单元测试之前调用(以避免在测试之间泄露参数)。

param(name: str, init_tensor: Optional[Union[torch.Tensor, Callable[[], torch.Tensor]]] = None, constraint: torch.distributions.constraints.Constraint = Real(), event_dim: Optional[int] = None) torch.Tensor[source]

将变量保存为参数存储中的参数。 要与参数存储交互或写入磁盘, 请参阅参数

Parameters
  • name (str) – 参数的名称

  • init_tensor (torch.Tensor可调用) – 初始张量或返回张量的惰性可调用对象。 对于大型张量,编写例如 lambda: torch.randn(100000) 可能更便宜,它只会在初始语句中求值。

  • 约束 (torch.distributions.constraints.Constraint) – torch 约束,默认为 constraints.real

  • event_dim (int) – (可选)与批处理无关的最右侧维度的数量。此维度左侧的维度将被视为批处理维度;如果参数语句位于子采样的plate中,则参数的相应批处理维度将相应地进行子采样。如果未指定,所有维度将被视为事件维度,并且不会执行子采样。

Returns

一个受约束的参数。可以通过pyro.param(...).unconstrained()访问底层未受约束的参数,其中.unconstrained是一个弱引用属性。

Return type

torch.Tensor

sample(name: str, fn: pyro.distributions.torch_distribution.TorchDistributionMixin, *args, obs: Optional[torch.Tensor] = None, obs_mask: Optional[torch.BoolTensor] = None, infer: Optional[pyro.poutine.runtime.InferDict] = None, **kwargs) torch.Tensor[source]

调用随机函数 fn,根据 name 和封闭上下文(例如推理算法)产生额外的副作用。有关讨论,请参见 Introduction to Pyro

Parameters
  • name – 样本的名称

  • fn – 分布类或函数

  • obs – 观测数据(可选;仅在推理上下文中使用)可选地在kwargs中指定

  • obs_mask (boolTensor) – 可选的布尔张量掩码,形状与 fn.batch_shape 可广播。如果提供,具有 mask=True 的事件将根据 obs 进行条件处理,其余事件将通过采样进行填补。这会引入一个名为 name + "_unobserved" 的潜在样本站点,应由指南使用。

  • infer (dict) – 可选的推理参数字典,通过kwargs指定。详情请参阅推理文档。

Returns

样本

factor(name: str, log_factor: torch.Tensor, *, has_rsample: Optional[bool] = None) None[source]

因子语句,用于向概率模型添加任意对数概率因子。

警告

在指南中使用因子语句时,您需要指定因子语句是来自完全重新参数化的采样(例如,重新参数化变量变换的雅可比行列式)还是来自非重新参数化的采样(例如,离散样本)。对于完全重新参数化的情况,设置has_rsample=True;对于非重新参数化的情况,设置has_rsample=False。这仅在指南中需要,而不是在模型中。

Parameters
  • name (str) – 简单样本的名称

  • log_factor (torch.Tensor) – 一个可能是批处理的对数概率因子。

  • has_rsample (bool) – 是否log_factor来自完全重新参数化的分布。在模型中使用时默认为False,但在指南中使用时必须指定。

deterministic(name: str, value: torch.Tensor, event_dim: Optional[int] = None) torch.Tensor[source]

确定性语句,用于向跟踪中添加一个名为name且值为valueDelta站点。当我们想要记录完全由其父节点确定的值时,这非常有用。例如:

x = pyro.sample("x", dist.Normal(0, 1))
x2 = pyro.deterministic("x2", x ** 2)

注意

该站点不影响模型密度。目前这转换为一个sample()语句,但未来可能会改变。

Parameters
  • name (str) – 站点的名称。

  • value (torch.Tensor) – 站点的值。

  • event_dim (int) – 可选的事件维度,默认为 value.ndim

subsample(data: torch.Tensor, event_dim: int) torch.Tensor[source]

子采样语句,用于基于封闭的plate对数据张量进行子采样。

这通常在自动执行子采样时对model()的参数调用,通过传递subsamplesubsample_size关键字参数由plate自动执行。例如,以下是等价的:

# Version 1. using indexing
def model(data):
    with pyro.plate("data", len(data), subsample_size=10, dim=-data.dim()) as ind:
        data = data[ind]
        # ...

# Version 2. using pyro.subsample()
def model(data):
    with pyro.plate("data", len(data), subsample_size=10, dim=-data.dim()):
        data = pyro.subsample(data, event_dim=0)
        # ...
Parameters
  • data (Tensor) – 一个批处理数据的张量。

  • event_dim (int) – 数据张量的事件维度。左侧的维度被视为批次维度。

Returns

data 的子采样版本

Return type

Tensor

class plate(name: str, size: Optional[int] = None, subsample_size: Optional[int] = None, subsample: Optional[torch.Tensor] = None, dim: Optional[int] = None, use_cuda: Optional[bool] = None, device: Optional[str] = None)[source]

基础类:pyro.poutine.plate_messenger.PlateMessenger

构建条件独立的变量序列。

plate 可以顺序使用作为生成器,也可以并行使用作为上下文管理器(以前分别是 irangeiarange)。

顺序的 plate 类似于 range(),因为它生成一系列的值。

向量化的 plate 类似于 torch.arange(),因为它生成一个索引数组,其他张量可以通过这些索引进行访问。platetorch.arange() 的不同之处在于,它还告知推理算法,被索引的变量是条件独立的。为了实现这一点,plate 是作为上下文管理器提供的,而不是一个函数,用户必须保证在 plate 上下文中的所有计算都是条件独立的:

with pyro.plate("name", size) as ind:
    # ...do conditionally independent stuff with ind...

此外,plate 可以通过对索引进行子采样并通知推理算法来缩放各种计算值,从而利用条件独立性假设。这通常用于对数据进行小批量子采样:

with pyro.plate("data", len(data), subsample_size=100) as ind:
    batch = data[ind]
    assert len(batch) == 100

默认情况下,subsample_size=False 这只会生成一个 torch.arange(0, size)。如果 0 < subsample_size <= size 这会生成一个 大小为 subsample_size 的随机批次索引,并在此上下文中将所有 对数似然项缩放为 size/batch_size

警告

这只有在所有计算在上下文中条件独立时才是正确的。

Parameters
  • name (str) – 一个唯一的名称,用于帮助推理算法在模型和指南之间匹配 plate 站点。

  • size (int) – 可选的下采样集合的大小 (类似于内置 range 中的 stop)。

  • subsample_size (int) – 用于子采样的最小批次的大小。 默认为 size

  • subsample(支持len()的任何内容)– 用户自定义子采样方案的可选自定义子样本。如果指定,则subsample_size将被设置为len(subsample)

  • dim (int) – 一个可选的维度,用于此独立性索引。 如果指定,dim 应为负数,即应从右侧索引。如果未指定,dim 将设置为所有封闭 plate 上下文左侧的最右侧维度。

  • use_cuda (bool) – 已弃用,请改用 device 参数。 可选的布尔值,指定是否使用 cuda 张量进行 subsamplelog_prob。默认为 torch.Tensor.is_cuda

  • device (str) – 可选关键字,指定将subsamplelog_prob的结果放置在哪个设备上。默认情况下,结果会放置在默认张量的同一设备上。

Returns

一个可重用的上下文管理器,生成一个一维的 torch.Tensor 索引。

示例

>>> # This version declares sequential independence and subsamples data:
>>> for i in pyro.plate('data', 100, subsample_size=10):
...     if z[i]:  # Control flow in this example prevents vectorization.
...         obs = pyro.sample(f'obs_{i}', dist.Normal(loc, scale),
...                           obs=data[i])
>>> # This version declares vectorized independence:
>>> with pyro.plate('data'):
...     obs = pyro.sample('obs', dist.Normal(loc, scale), obs=data)
>>> # This version subsamples data in vectorized way:
>>> with pyro.plate('data', 100, subsample_size=10) as ind:
...     obs = pyro.sample('obs', dist.Normal(loc, scale), obs=data[ind])
>>> # This wraps a user-defined subsampling method for use in pyro:
>>> ind = torch.randint(0, 100, (10,)).long() # custom subsample
>>> with pyro.plate('data', 100, subsample=ind):
...     obs = pyro.sample('obs', dist.Normal(loc, scale), obs=data[ind])
>>> # This reuses two different independence contexts.
>>> x_axis = pyro.plate('outer', 320, dim=-1)
>>> y_axis = pyro.plate('inner', 200, dim=-2)
>>> with x_axis:
...     x_noise = pyro.sample("x_noise", dist.Normal(loc, scale))
...     assert x_noise.shape == (320,)
>>> with y_axis:
...     y_noise = pyro.sample("y_noise", dist.Normal(loc, scale))
...     assert y_noise.shape == (200, 1)
>>> with x_axis, y_axis:
...     xy_noise = pyro.sample("xy_noise", dist.Normal(loc, scale))
...     assert xy_noise.shape == (200, 320)

请参阅SVI Part II以获取更详细的讨论。

plate_stack(prefix: str, sizes: Sequence[int], rightmost_dim: int = - 1) Iterator[None][源代码]

创建一个连续的plate堆栈,其维度为:

rightmost_dim - len(sizes), ..., rightmost_dim
Parameters
  • prefix (str) – 板块的名称前缀。

  • sizes (可迭代对象) – 一个包含板块大小的可迭代对象。

  • rightmost_dim (int) – 从右边开始计数的最右边的维度。

module(name: str, nn_module: torch.nn.modules.module.Module, update_module_params: bool = False) torch.nn.modules.module.Module[源代码]

torch.nn.Module的所有参数注册到Pyro的param_store中。结合ParamStoreDictsave()load()功能,这允许用户保存和加载模块。

注意

考虑使用PyroModule,这是一个较新的替代方案,相比pyro.module(),它在以下方面有更好的支持:即时编译(jitting)、在C++中服务以及将参数转换为随机变量。详情请参阅模块教程

Parameters
  • name (str) – 模块的名称

  • nn_module (torch.nn.Module) – 要注册到Pyro的模块

  • update_module_params – 确定是否用ParamStore中找到的值覆盖PyTorch模块中的参数(如果有的话)。默认为False

Returns

torch.nn.Module

random_module(name, nn_module, prior, *args, **kwargs)[source]

警告

random_module 原语已被弃用,并将在未来的版本中移除。请改用 PyroModule 来从 torch.nn.Module 实例创建贝叶斯模块。有关示例,请参阅 贝叶斯回归教程

已弃用 在模块 nn_module 的参数上放置先验。 返回一个可调用的 nn.Module 分布,调用时 返回一个采样的 nn.Module

Parameters
  • name (str) – pyro模块的名称

  • nn_module (torch.nn.Module) – 要注册到pyro的模块

  • prior – pyro 分布、随机函数或带有参数名称作为键和相应分布/随机函数作为值的 python 字典。

Returns

一个可调用的函数,返回一个采样模块

barrier(data: torch.Tensor) torch.Tensor[source]

实验性 确保data中的所有值都是具体的,而不是惰性的funsor值。这与pyro.poutine.collapse()结合使用时非常有用。

enable_validation(is_validate: bool = True) None[source]

启用或禁用Pyro中的验证检查。验证检查提供了有用的警告和错误,例如NaN检查、验证分布参数和支持值、检测ELBO和MCMC的错误使用。由于其中一些检查可能很昂贵,您可能希望禁用成熟模型的验证以加快推理速度。

默认行为模仿了Python的assert语句:默认情况下验证是开启的,但如果Python以优化模式运行(通过python -O),则会被禁用。同样地,默认行为取决于Python的全局__debug__值,通过pyro.enable_validation(__debug__)来实现。

在JIT编译期间,所有支持PyTorch JIT的推理算法暂时禁用验证。我们建议使用非JIT编译的推理算法开发模型以便于调试,然后在模型正确后选择性地切换到JIT编译的推理算法。

Parameters

is_validate (bool) – (可选;默认为 True) 是否启用验证检查。

validation_enabled(is_validate: bool = True) Iterator[None][source]

上下文管理器,在临时启用/禁用验证检查时非常有用。

Parameters

is_validate (bool) – (可选;默认为 True) 临时验证检查覆盖。

trace(fn=None, ignore_warnings=False, jit_options=None)[source]

torch.jit.trace() 的惰性替代方案,适用于调用 pyro.param() 的 Pyro 函数。

实际的编译产物存储在输出的compiled属性中。在此属性上调用诊断方法。

示例:

def model(x):
    scale = pyro.param("scale", torch.tensor(0.5), constraint=constraints.positive)
    return pyro.sample("y", dist.Normal(x, scale))

@pyro.ops.jit.trace
def model_log_prob_fn(x, y):
    cond_model = pyro.condition(model, data={"y": y})
    tr = pyro.poutine.trace(cond_model).get_trace(x)
    return tr.log_prob_sum()
Parameters
  • fn (可调用) – 要跟踪的函数。

  • ignore_warnins (bool) – 是否忽略jit警告。

  • jit_options (dict) – 可选的字典选项,传递给 torch.jit.trace(),例如 {"optimize": False}