基于Funsor的Pyro

基本元素

clear_param_store() None[source]

清除全局的ParamStoreDict

这在REPL环境中尤其有用。我们建议在每次训练循环之前调用此函数(以避免从过去的模型中泄露参数),并在每个单元测试之前调用(以避免在测试之间泄露参数)。

condition(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Union[Dict[str, torch.Tensor], Trace]) Union[pyro.poutine.condition_messenger.ConditionMessenger, Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]][source]

ConditionMessenger的便捷包装器

给定一个包含一些样本语句的随机函数和一个名称的观察字典, 将那些名称处的样本语句更改为具有这些值的观察。

考虑以下Pyro程序:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

要观察站点 z 的值,我们可以写

>>> conditioned_model = pyro.poutine.condition(model, data={"z": torch.tensor(1.)})

这相当于在model中将obs=value作为关键字参数添加到pyro.sample(“z”, …)中。

Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • data – 一个字典或一个 Trace

Returns

使用ConditionMessenger装饰的随机函数

deterministic(name: str, value: torch.Tensor, event_dim: Optional[int] = None) torch.Tensor[source]

确定性语句,用于向跟踪中添加一个名为name且值为valueDelta站点。当我们想要记录完全由其父节点确定的值时,这非常有用。例如:

x = pyro.sample("x", dist.Normal(0, 1))
x2 = pyro.deterministic("x2", x ** 2)

注意

该站点不影响模型密度。目前这转换为一个sample()语句,但未来可能会改变。

Parameters
  • name (str) – 站点的名称。

  • value (torch.Tensor) – 站点的值。

  • event_dim (int) – 可选的事件维度,默认为 value.ndim

do(fn: Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T], data: Dict[str, Union[torch.Tensor, numbers.Number]]) Union[pyro.poutine.do_messenger.DoMessenger, Callable[[pyro.poutine.handlers._P], pyro.poutine.handlers._T]][source]

DoMessenger 的便捷包装器

给定一个带有一些样本语句的随机函数和一个名称值的字典, 将这些站点的返回值设置为这些值, 就像它们被硬编码为这些值一样, 并引入具有相同名称的新样本站点, 这些站点的值不会传播。

condition()自由组合,以表示潜在结果的反事实分布。有关更多细节和理论,请参见单世界干预图[1]。

考虑以下Pyro程序:

>>> def model(x):
...     s = pyro.param("s", torch.tensor(0.5))
...     z = pyro.sample("z", dist.Normal(x, s))
...     return z ** 2

要为站点 z 干预一个值,我们可以写

>>> intervened_model = pyro.poutine.do(model, data={"z": torch.tensor(1.)})

这相当于将z = pyro.sample(“z”, …)替换为 z = torch.tensor(1.) 并引入一个新的采样点pyro.sample(“z”, …),其值在其他地方未使用。

参考文献

[1] Single World Intervention Graphs: A Primer,

托马斯·理查森,詹姆斯·罗宾斯

Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • data – 一个将样本站点名称映射到干预措施的 dict

Returns

使用DoMessenger装饰的随机函数

enable_validation(is_validate: bool = True) None[source]

启用或禁用Pyro中的验证检查。验证检查提供了有用的警告和错误,例如NaN检查、验证分布参数和支持值、检测ELBO和MCMC的错误使用。由于其中一些检查可能很昂贵,您可能希望禁用成熟模型的验证以加快推理速度。

默认行为模仿了Python的assert语句:默认情况下验证是开启的,但如果Python以优化模式运行(通过python -O),则会被禁用。同样地,默认行为取决于Python的全局__debug__值,通过pyro.enable_validation(__debug__)来实现。

在JIT编译期间,所有支持PyTorch JIT的推理算法暂时禁用验证。我们建议使用非JIT编译的推理算法开发模型以便于调试,然后在模型正确后选择性地切换到JIT编译的推理算法。

Parameters

is_validate (bool) – (可选;默认为 True) 是否启用验证检查。

factor(name: str, log_factor: torch.Tensor, *, has_rsample: Optional[bool] = None) None[source]

因子语句,用于向概率模型添加任意对数概率因子。

警告

在指南中使用因子语句时,您需要指定因子语句是来自完全重新参数化的采样(例如,重新参数化变量变换的雅可比行列式)还是来自非重新参数化的采样(例如,离散样本)。对于完全重新参数化的情况,设置has_rsample=True;对于非重新参数化的情况,设置has_rsample=False。这仅在指南中需要,而不是在模型中。

Parameters
  • name (str) – 简单样本的名称

  • log_factor (torch.Tensor) – 一个可能是批处理的对数概率因子。

  • has_rsample (bool) – 是否log_factor来自完全重新参数化的分布。在模型中使用时默认为False,但在指南中使用时必须指定。

get_param_store() pyro.params.param_store.ParamStoreDict[source]

返回全局的ParamStoreDict

markov(fn=None, history=1, keep=False)[源代码]

MarkovMessenger的便捷包装器

用于与Pyro的位置批处理维度一致的funsors之间转换的处理程序。

Parameters
  • 历史 (int) – 从当前上下文中可见的先前上下文的数量。默认为1。如果为零,这类似于 pyro.plate

  • keep (bool) – 如果为真,帧是可重放的。这在分支时很重要:如果 keep=True,同一级别的相邻分支可以相互依赖;如果 keep=False,相邻分支是独立的(取决于它们的共同祖先)。

module(name: str, nn_module: torch.nn.modules.module.Module, update_module_params: bool = False) torch.nn.modules.module.Module[源代码]

torch.nn.Module的所有参数注册到Pyro的param_store中。结合ParamStoreDictsave()load()功能,这允许用户保存和加载模块。

注意

考虑使用PyroModule,这是一个较新的替代方案,相比pyro.module(),它在以下方面有更好的支持:即时编译(jitting)、在C++中服务以及将参数转换为随机变量。详情请参阅模块教程

Parameters
  • name (str) – 模块的名称

  • nn_module (torch.nn.Module) – 要注册到Pyro的模块

  • update_module_params – 确定是否用ParamStore中找到的值覆盖PyTorch模块中的参数(如果有的话)。默认为False

Returns

torch.nn.Module

param(name: str, init_tensor: Optional[Union[torch.Tensor, Callable[[], torch.Tensor]]] = None, constraint: torch.distributions.constraints.Constraint = Real(), event_dim: Optional[int] = None) torch.Tensor[source]

将变量保存为参数存储中的参数。 要与参数存储交互或写入磁盘, 请参阅参数

Parameters
  • name (str) – 参数的名称

  • init_tensor (torch.Tensor可调用) – 初始张量或返回张量的惰性可调用对象。 对于大型张量,编写例如 lambda: torch.randn(100000) 可能更便宜,它只会在初始语句中求值。

  • 约束 (torch.distributions.constraints.Constraint) – torch 约束,默认为 constraints.real

  • event_dim (int) – (可选)与批处理无关的最右侧维度的数量。此维度左侧的维度将被视为批处理维度;如果参数语句位于子采样的plate中,则参数的相应批处理维度将相应地进行子采样。如果未指定,所有维度将被视为事件维度,并且不会执行子采样。

Returns

一个受约束的参数。可以通过pyro.param(...).unconstrained()访问底层未受约束的参数,其中.unconstrained是一个弱引用属性。

Return type

torch.Tensor

random_module(name, nn_module, prior, *args, **kwargs)[source]

警告

random_module 原语已被弃用,并将在未来的版本中移除。请改用 PyroModule 来从 torch.nn.Module 实例创建贝叶斯模块。有关示例,请参阅 贝叶斯回归教程

已弃用 在模块 nn_module 的参数上放置先验。 返回一个可调用的 nn.Module 分布,调用时 返回一个采样的 nn.Module

Parameters
  • name (str) – pyro模块的名称

  • nn_module (torch.nn.Module) – 要注册到pyro的模块

  • prior – pyro 分布、随机函数或带有参数名称作为键和相应分布/随机函数作为值的 python 字典。

Returns

一个可调用的函数,返回一个采样模块

sample(name: str, fn: pyro.distributions.torch_distribution.TorchDistributionMixin, *args, obs: Optional[torch.Tensor] = None, obs_mask: Optional[torch.BoolTensor] = None, infer: Optional[pyro.poutine.runtime.InferDict] = None, **kwargs) torch.Tensor[source]

调用随机函数 fn,根据 name 和封闭上下文(例如推理算法)产生额外的副作用。有关讨论,请参见 Introduction to Pyro

Parameters
  • name – 样本的名称

  • fn – 分布类或函数

  • obs – 观测数据(可选;仅在推理上下文中使用)可选地在kwargs中指定

  • obs_mask (boolTensor) – 可选的布尔张量掩码,形状与 fn.batch_shape 可广播。如果提供,具有 mask=True 的事件将根据 obs 进行条件处理,其余事件将通过采样进行填补。这会引入一个名为 name + "_unobserved" 的潜在样本站点,应由指南使用。

  • infer (dict) – 可选的推理参数字典,通过kwargs指定。详情请参阅推理文档。

Returns

样本

set_rng_seed(rng_seed)[source]

设置torchtorch.cuda(如果可用)的种子。

Parameters

rng_seed (int) – 种子值。

subsample(data: torch.Tensor, event_dim: int) torch.Tensor[source]

子采样语句,用于基于封闭的plate对数据张量进行子采样。

这通常在自动通过plate进行子采样时,通过传递subsamplesubsample_size关键字参数来调用model()的参数。例如,以下是等价的:

# Version 1. using indexing
def model(data):
    with pyro.plate("data", len(data), subsample_size=10, dim=-data.dim()) as ind:
        data = data[ind]
        # ...

# Version 2. using pyro.subsample()
def model(data):
    with pyro.plate("data", len(data), subsample_size=10, dim=-data.dim()):
        data = pyro.subsample(data, event_dim=0)
        # ...
Parameters
  • data (Tensor) – 一个批处理数据的张量。

  • event_dim (int) – 数据张量的事件维度。左侧的维度被视为批次维度。

Returns

data 的子采样版本

Return type

Tensor

to_data(x, name_to_dim=None, dim_type=DimType.LOCAL)[source]
to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL)[source]
vectorized_markov(fn=None, name=None, size=None, dim=None, history=1)[source]

VectorizedMarkovMessenger 的便捷封装

为变量设计的马尔可夫链构造,旨在使用并行扫描算法高效消除马尔可夫维度。在允许的情况下,vectorized_markov 可以与 markov 互换使用。

for循环生成int和一维torch.Tensor索引: (0, ..., history-1, torch.arange(0, size-history), ..., torch.arange(history, size))int索引用于初始化马尔可夫链,而torch.Tensor索引 用于构建向量化的转移概率,以便通过并行扫描算法进行高效消除。

history==0 时,vectorized_markov 的行为类似于 plate

在for循环运行后,Markov变量被识别,然后step信息被构建并添加到跟踪中。step通知推理算法哪些变量属于Markov链。

data = torch.ones(3, dtype=torch.float)

def model(data, vectorized=True):

    init = pyro.param("init", lambda: torch.rand(3), constraint=constraints.simplex)
    trans = pyro.param("trans", lambda: torch.rand((3, 3)), constraint=constraints.simplex)
    locs = pyro.param("locs", lambda: torch.rand(3,))

    markov_chain = \
        pyro.vectorized_markov(name="time", size=len(data), dim=-1) if vectorized \
        else pyro.markov(range(len(data)))
    for i in markov_chain:
        x_curr = pyro.sample("x_{}".format(i), dist.Categorical(
            init if isinstance(i, int) and i < 1 else trans[x_prev]),

        pyro.sample("y_{}".format(i),
                    dist.Normal(Vindex(locs)[..., x_curr], 1.),
                    obs=data[i])
        x_prev = x_curr

#  trace.nodes["time"]["value"]
#  frozenset({('x_0', 'x_slice(0, 2, None)', 'x_slice(1, 3, None)')})
#
#  pyro.vectorized_markov trace
#  ...
#  Sample Sites:
#      locs dist               | 3
#          value               | 3
#       log_prob               |
#       x_0 dist               |
#          value     3 1 1 1 1 |
#       log_prob     3 1 1 1 1 |
#       y_0 dist     3 1 1 1 1 |
#          value               |
#       log_prob     3 1 1 1 1 |
#  x_slice(1, 3, None) dist   3 1 1 1 1 2 |
#          value 3 1 1 1 1 1 1 |
#       log_prob 3 3 1 1 1 1 2 |
#  y_slice(1, 3, None) dist 3 1 1 1 1 1 2 |
#          value             2 |
#       log_prob 3 1 1 1 1 1 2 |
#
#  pyro.markov trace
#  ...
#  Sample Sites:
#      locs dist             | 3
#          value             | 3
#       log_prob             |
#       x_0 dist             |
#          value   3 1 1 1 1 |
#       log_prob   3 1 1 1 1 |
#       y_0 dist   3 1 1 1 1 |
#          value             |
#       log_prob   3 1 1 1 1 |
#       x_1 dist   3 1 1 1 1 |
#          value 3 1 1 1 1 1 |
#       log_prob 3 3 1 1 1 1 |
#       y_1 dist 3 1 1 1 1 1 |
#          value             |
#       log_prob 3 1 1 1 1 1 |
#       x_2 dist 3 1 1 1 1 1 |
#          value   3 1 1 1 1 |
#       log_prob 3 3 1 1 1 1 |
#       y_2 dist   3 1 1 1 1 |
#          value             |
#       log_prob   3 1 1 1 1 |

警告

这仅在每个分支只有一个马尔可夫维度时有效。

Parameters
  • name (str) – 一个唯一的马尔可夫维度名称,用于帮助推理算法消除马尔可夫链中的变量。

  • size (int) – 马尔可夫链的长度(大小)。

  • dim (int) – 一个可选的维度,用于此马尔可夫维度。 如果指定,dim 应为负数,即应从右侧索引。如果未指定,dim 将设置为所有封闭 plate 上下文左侧的最右侧维度。

  • 历史 (int) – 马尔可夫链的记忆(顺序)。也是从当前上下文中可见的先前上下文的数量。默认为1。 如果为零,这类似于 plate

Returns

返回int和一维torch.Tensor索引: (0, ..., history-1, torch.arange(size-history), ..., torch.arange(history, size))

效果处理器

enum(fn=None, first_available_dim=None)[source]

EnumMessenger 的便捷包装器

此版本的EnumMessenger使用to_data()为每个离散样本站点分配一个新的枚举维度。

markov(fn=None, history=1, keep=False)[源代码]

MarkovMessenger的便捷包装器

用于与Pyro的位置批处理维度一致的funsors之间转换的处理程序。

Parameters
  • 历史 (int) – 从当前上下文中可见的先前上下文的数量。默认为1。如果为零,这类似于 pyro.plate

  • keep (bool) – 如果为真,帧是可重放的。这在分支时很重要:如果 keep=True,同一级别的相邻分支可以相互依赖;如果 keep=False,相邻分支是独立的(取决于它们的共同祖先)。

named(fn=None, first_available_dim=None)[source]

NamedMessenger 的便捷包装器

用于to_funsor()to_data()原语的基础效果处理类。 任何在内部调用这些原语或包装调用这些原语的代码的效果处理程序都应继承自NamedMessenger

这种设计确保了在处理器退出时全局名称-维度映射会被重置,而不是可能持续到整个程序终止。

plate(fn=None, name=None, size=None, subsample_size=None, subsample=None, dim=None, use_cuda=None, device=None)[source]

PlateMessenger的便捷包装器

将新的IndepMessenger实现与现有的pyro.poutine.BroadcastMessenger结合。最终应该可以替代pyro.plate

replay(fn=None, trace=None, params=None)[source]

ReplayMessenger 的便捷包装器

这个版本的 ReplayMessenger 几乎与原始版本相同, 除了它在重放的funsor值上调用 to_data()。 这可能会导致不同的解包形状,但应该会产生正确的分配。

trace(fn=None, graph_type=None, param_only=None, pack_online=True)[source]

TraceMessenger的便捷包装器

设置 pack_online=True 会在线打包而不是事后打包,一旦所有分布和值可用,立即将它们转换为 Funsors。

设置 pack_online=False 计算执行后打包所需的信息。 每个样本站点都带有一个 dim_to_name 字典的注释, 可以直接传递给 to_funsor()

vectorized_markov(fn=None, name=None, size=None, dim=None, history=1)[source]

VectorizedMarkovMessenger 的便捷封装

为变量设计的马尔可夫链构造,旨在使用并行扫描算法高效消除马尔可夫维度。在允许的情况下,vectorized_markov 可以与 markov 互换使用。

for循环生成int和一维torch.Tensor索引: (0, ..., history-1, torch.arange(0, size-history), ..., torch.arange(history, size))int索引用于初始化马尔可夫链,而torch.Tensor索引 用于构建向量化的转移概率,以便通过并行扫描算法进行高效消除。

history==0 时,vectorized_markov 的行为类似于 plate

在for循环运行后,Markov变量被识别,然后step信息被构建并添加到跟踪中。step通知推理算法哪些变量属于Markov链。

data = torch.ones(3, dtype=torch.float)

def model(data, vectorized=True):

    init = pyro.param("init", lambda: torch.rand(3), constraint=constraints.simplex)
    trans = pyro.param("trans", lambda: torch.rand((3, 3)), constraint=constraints.simplex)
    locs = pyro.param("locs", lambda: torch.rand(3,))

    markov_chain = \
        pyro.vectorized_markov(name="time", size=len(data), dim=-1) if vectorized \
        else pyro.markov(range(len(data)))
    for i in markov_chain:
        x_curr = pyro.sample("x_{}".format(i), dist.Categorical(
            init if isinstance(i, int) and i < 1 else trans[x_prev]),

        pyro.sample("y_{}".format(i),
                    dist.Normal(Vindex(locs)[..., x_curr], 1.),
                    obs=data[i])
        x_prev = x_curr

#  trace.nodes["time"]["value"]
#  frozenset({('x_0', 'x_slice(0, 2, None)', 'x_slice(1, 3, None)')})
#
#  pyro.vectorized_markov trace
#  ...
#  Sample Sites:
#      locs dist               | 3
#          value               | 3
#       log_prob               |
#       x_0 dist               |
#          value     3 1 1 1 1 |
#       log_prob     3 1 1 1 1 |
#       y_0 dist     3 1 1 1 1 |
#          value               |
#       log_prob     3 1 1 1 1 |
#  x_slice(1, 3, None) dist   3 1 1 1 1 2 |
#          value 3 1 1 1 1 1 1 |
#       log_prob 3 3 1 1 1 1 2 |
#  y_slice(1, 3, None) dist 3 1 1 1 1 1 2 |
#          value             2 |
#       log_prob 3 1 1 1 1 1 2 |
#
#  pyro.markov trace
#  ...
#  Sample Sites:
#      locs dist             | 3
#          value             | 3
#       log_prob             |
#       x_0 dist             |
#          value   3 1 1 1 1 |
#       log_prob   3 1 1 1 1 |
#       y_0 dist   3 1 1 1 1 |
#          value             |
#       log_prob   3 1 1 1 1 |
#       x_1 dist   3 1 1 1 1 |
#          value 3 1 1 1 1 1 |
#       log_prob 3 3 1 1 1 1 |
#       y_1 dist 3 1 1 1 1 1 |
#          value             |
#       log_prob 3 1 1 1 1 1 |
#       x_2 dist 3 1 1 1 1 1 |
#          value   3 1 1 1 1 |
#       log_prob 3 3 1 1 1 1 |
#       y_2 dist   3 1 1 1 1 |
#          value             |
#       log_prob   3 1 1 1 1 |

警告

这仅在每个分支只有一个马尔可夫维度时有效。

Parameters
  • name (str) – 一个唯一的马尔可夫维度名称,用于帮助推理算法消除马尔可夫链中的变量。

  • size (int) – 马尔可夫链的长度(大小)。

  • dim (int) – 一个可选的维度,用于此马尔可夫维度。 如果指定,dim 应为负数,即应从右侧索引。如果未指定,dim 将设置为所有封闭 plate 上下文左侧的最右侧维度。

  • 历史 (int) – 马尔可夫链的记忆(顺序)。也是从当前上下文中可见的先前上下文的数量。默认为1。 如果为零,这类似于 plate

Returns

返回int和一维torch.Tensor索引: (0, ..., history-1, torch.arange(size-history), ..., torch.arange(history, size))

class NamedMessenger(first_available_dim=None)[源代码]

基础类:pyro.poutine.reentrant_messenger.ReentrantMessenger

用于to_funsor()to_data()原语的基础效果处理程序类。 任何在内部调用这些原语或包装执行这些操作的代码的效果处理程序都应继承自NamedMessenger

这种设计确保了在处理器退出时全局名称-维度映射会被重置,而不是可能持续到整个程序终止。

class MarkovMessenger(history=1, keep=False)[源代码]

基础类:pyro.contrib.funsor.handlers.named_messenger.NamedMessenger

用于与Pyro的位置批处理维度一致的funsors之间转换的处理程序。

Parameters
  • 历史 (int) – 从当前上下文中可见的先前上下文的数量。默认为1。如果为零,这类似于 pyro.plate

  • keep (bool) – 如果为真,帧是可重放的。这在分支时很重要:如果 keep=True,同一级别的相邻分支可以相互依赖;如果 keep=False,相邻分支是独立的(取决于它们的共同祖先)。

class GlobalNamedMessenger(first_available_dim=None)[source]

基础类:pyro.contrib.funsor.handlers.named_messenger.NamedMessenger

任何使用to_funsor()to_data()原语来分配DimType.GLOBALDimType.VISIBLE维度的新效果处理程序的基类。

作为不应被MarkovMessenger回收的维度的手动“作用域”:全局维度将被视为活动状态,直到最初分配它们的GlobalNamedMessenger最内层退出。

to_funsor(x, output=None, dim_to_name=None, dim_type=DimType.LOCAL)[source]
to_data(x, name_to_dim=None, dim_type=DimType.LOCAL)[source]
class StackFrame(name_to_dim, dim_to_name, history=1, keep=False)[source]

基础类:object

整数位置维度和名称之间的一致双向映射。 可以像字典一样查询(value = frame[key], frame[key] = value)。

class DimType(value)[source]

基础类:enum.Enum

列举了分配维度的可能类型

LOCAL = 0
GLOBAL = 1
VISIBLE = 2
class DimRequest(value, dim_type)

基础类:tuple

dim_type

字段编号1的别名

value

字段编号 0 的别名

class DimStack[source]

基础类:object

用于跟踪名称和维度之间映射的全局状态。

替换了板_DimAllocator, 枚举_EnumAllocatorMarkovMessenger中的stackEnumMessenger中的_param_dims_value_dims,以及msg['infer']中的dim_to_symbol

MAX_DIM = -25
DEFAULT_FIRST_DIM = -5
set_first_available_dim(dim)[source]
push_global(frame)[source]
pop_global()[source]
push_iter(frame)[source]
pop_iter()[来源]
push_local(frame)[来源]
pop_local()[来源]
property global_frame
property local_frame
property current_write_env
property current_read_env

收集所有必要的框架以计算全名 <–> 维度映射,并在计算中的任何点解释Funsor输入或批量形状。

allocate(key_to_value_request)[source]
names_from_batch_shape(batch_shape, dim_type=DimType.LOCAL)[源代码]

推理算法

class ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]

基础类:pyro.infer.elbo.ELBO

differentiable_loss(model, guide, *args, **kwargs)[source]
loss(model, guide, *args, **kwargs)[source]

参见 pyro.infer.traceenum_elbo.TraceEnum_ELBO.loss()

loss_and_grads(model, guide, *args, **kwargs)[source]

参见 pyro.infer.traceenum_elbo.TraceEnum_ELBO.loss_and_grads()

class Jit_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]

基础类:pyro.contrib.funsor.infer.elbo.ELBO

differentiable_loss(model, guide, *args, **kwargs)[source]
class Trace_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]

基础类:pyro.contrib.funsor.infer.elbo.ELBO

differentiable_loss(model, guide, *args, **kwargs)[来源]

参见 pyro.infer.trace_elbo.Trace_ELBO.differentiable_loss()

class JitTrace_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[源代码]

基础类: pyro.contrib.funsor.infer.elbo.Jit_ELBO, pyro.contrib.funsor.infer.trace_elbo.Trace_ELBO

apply_optimizer(x)[source]
terms_from_trace(tr)[source]

辅助函数,用于从执行轨迹中提取elbo组件。

class TraceMarkovEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]

基础类:pyro.contrib.funsor.infer.elbo.ELBO

differentiable_loss(model, guide, *args, **kwargs)[source]

参见 pyro.infer.traceenum_elbo.TraceEnum_ELBO.differentiable_loss()

class TraceEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]

基础类:pyro.contrib.funsor.infer.elbo.ELBO

differentiable_loss(model, guide, *args, **kwargs)[source]

参见 pyro.infer.traceenum_elbo.TraceEnum_ELBO.differentiable_loss()

class JitTraceEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]

基础类: pyro.contrib.funsor.infer.elbo.Jit_ELBO, pyro.contrib.funsor.infer.traceenum_elbo.TraceEnum_ELBO

class JitTraceMarkovEnum_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]

基础类: pyro.contrib.funsor.infer.elbo.Jit_ELBO, pyro.contrib.funsor.infer.traceenum_elbo.TraceMarkovEnum_ELBO

class TraceTMC_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[source]

基础类:pyro.contrib.funsor.infer.elbo.ELBO

differentiable_loss(model, guide, *args, **kwargs)[source]

参见 pyro.infer.tracetmc_elbo.TraceTMC_ELBO.differentiable_loss()

class JitTraceTMC_ELBO(num_particles=1, max_plate_nesting=inf, max_iarange_nesting=None, vectorize_particles=False, strict_enumeration_warning=True, ignore_jit_warnings=False, jit_options=None, retain_graph=None, tail_adaptive_beta=- 1.0)[源代码]

基础类:pyro.contrib.funsor.infer.elbo.Jit_ELBO, pyro.contrib.funsor.infer.tracetmc_elbo.TraceTMC_ELBO

infer_discrete(model, first_available_dim=None, temperature=1)[source]