自动名称生成

pyro.contrib.autoname 模块提供了自动生成唯一且语义上有意义的样本站点名称的工具。

scope(fn=None, prefix=None, inner=None)[source]
Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • prefix – 一个字符串,用于添加到样本名称前(如果提供了 fn,则为可选)

  • inner – 用于确定重复名称计数器出现位置的开关

Returns

fn 使用 ScopeMessenger 装饰

scope 在 Pyro 示例站点中,将前缀加上 / 添加到名称前。 它的工作方式类似于 TensorFlow 的 name_scopevariable_scope, 并且可以用作上下文管理器、装饰器或高阶函数。

scope 对于将组合模型与指南或数据对齐非常有用。

示例:

>>> @scope(prefix="a")
... def model():
...     return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()

示例:

>>> def model():
...     with scope(prefix="a"):
...         return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()

作用域按预期组合,外部作用域在名称中出现在内部作用域之前:

>>> @scope(prefix="b")
... def model():
...     with scope(prefix="a"):
...         return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "b/a/x" in poutine.trace(model).get_trace()

当用作装饰器或高阶函数时,如果没有提供用户指定的前缀,scope 将使用输入函数的名称作为前缀。

示例:

>>> @scope
... def model():
...     return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "model/x" in poutine.trace(model).get_trace()
name_count(fn=None)[source]

name_count 是一个非常简单的自动命名方案,它只是在执行过程中多次出现的任何名称后附加一个后缀 “__” 加上一个计数器。只有重复的名称实例才会获得后缀;第一个实例不会被修改。

示例:

>>> @name_count
... def model():
...     for i in range(3):
...         pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "x" in poutine.trace(model).get_trace()
>>> assert "x__1" in poutine.trace(model).get_trace()
>>> assert "x__2" in poutine.trace(model).get_trace()

name_count 还与 scope() 组合使用,通过添加后缀来复制作用域入口:

示例:

>>> @name_count
... def model():
...     for i in range(3):
...         with pyro.contrib.autoname.scope(prefix="a"):
...             pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()
>>> assert "a__1/x" in poutine.trace(model).get_trace()
>>> assert "a__2/x" in poutine.trace(model).get_trace()

示例:

>>> @name_count
... def model():
...     with pyro.contrib.autoname.scope(prefix="a"):
...         for i in range(3):
...             pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()
>>> assert "a/x__1" in poutine.trace(model).get_trace()
>>> assert "a/x__2" in poutine.trace(model).get_trace()
autoname(fn=None, name=None)[source]

AutonameMessenger 的便捷包装器

为随机变量分配唯一的名称。

  1. 对于新变量,如果给定了声明名称,则使用其声明名称,否则使用分布名称:

    sample("x", dist.Bernoulli ... )  # -> x
    sample(dist.Bernoulli ... )  # -> Bernoulli
    
  2. 对于重复的变量名,将计数器作为后缀附加:

    sample(dist.Bernoulli ... )  # -> Bernoulli
    sample(dist.Bernoulli ... )  # -> Bernoulli1
    sample(dist.Bernoulli ... )  # -> Bernoulli2
    
  3. 函数和迭代器可以用作名称作用域:

    @autoname
    def f1():
        sample(dist.Bernoulli ... )
    
    @autoname
    def f2():
        f1()  # -> f2/f1/Bernoulli
        f1()  # -> f2/f1__1/Bernoulli
        sample(dist.Bernoulli ... )  # -> f2/Bernoulli
    
    @autoname(name="model")
    def f3():
        for i in autoname(range(3), name="time"):
            # model/time/Bernoulli .. model/time__1/Bernoulli .. model/time__2/Bernoulli
            sample(dist.Bernoulli ... )
            # model/time/f1/Bernoulli .. model/time__1/f1/Bernoulli .. model/time__2/f1/Bernoulli
            f1()
    
  4. 或者可以使用with语句添加作用域:

    def f4():
        with autoname(name="prefix"):
            f1()  # -> prefix/f1/Bernoulli
            f1()  # -> prefix/f1__1/Bernoulli
            sample(dist.Bernoulli ... )  # -> prefix/Bernoulli
    
sample(*args)[source]
sample(name: str, fn, *args, **kwargs)
sample(fn: pyro.distributions.distribution.Distribution, *args, **kwargs)

命名的数据结构

pyro.contrib.named 模块是 Pyro 之上的一个薄语法层。它允许 Pyro 模型被编写成看起来像操作 Python 数据结构的程序,例如 latent.x.sample_(...),而不是像 x = pyro.sample("x", ...) 这样带有字符串标签语句的程序。

本模块提供了三种容器数据结构 named.Objectnamed.Listnamed.Dict。这些数据结构旨在相互嵌套。它们共同跟踪每个数据结构中每个数据块的地址,以便该地址可以用作 Pyro 站点。例如:

>>> state = named.Object("state")
>>> print(str(state))
state

>>> z = state.x.y.z  # z is just a placeholder.
>>> print(str(z))
state.x.y.z

>>> state.xs = named.List()  # Create a contained list.
>>> x0 = state.xs.add()
>>> print(str(x0))
state.xs[0]

>>> state.ys = named.Dict()
>>> foo = state.ys['foo']
>>> print(str(foo))
state.ys['foo']

这些地址现在可以在sampleobserveparam语句中使用。这些命名的数据结构甚至提供了别名Pyro语句的原地方法。例如:

>>> state = named.Object("state")
>>> loc = state.loc.param_(torch.zeros(1, requires_grad=True))
>>> scale = state.scale.param_(torch.ones(1, requires_grad=True))
>>> z = state.z.sample_(dist.Normal(loc, scale))
>>> obs = state.x.sample_(dist.Normal(loc, scale), obs=z)

有关如何在模型代码中使用这些内容的更深入示例,请参见 Tree DataMixture 示例。

作者: Fritz Obermeyer, Alexander Rush

class Object(name)[source]

基础类:object

用于保存不可变潜在状态的对象。

此对象既可以作为嵌套潜在状态的容器,也可以作为通过named.sample、named.observe或named.param语句替换为张量的占位符。当用作占位符时,Object对象取代了普通pyro.sample语句中的字符串。

Parameters

name (str) – 对象的名称。

示例:

state = named.Object("state")
state.x = 0
state.ys = named.List()
state.zs = named.Dict()
state.a.b.c.d.e.f.g = 0  # Creates a chain of named.Objects.

警告

此数据结构是一次写入的:可以添加数据但不能更改或删除。尝试更改此数据结构可能会导致静默错误。

sample_(fn: pyro.distributions.torch_distribution.TorchDistributionMixin, *args, obs: Optional[torch.Tensor] = None, obs_mask: Optional[torch.BoolTensor] = None, infer: Optional[pyro.poutine.runtime.InferDict] = None, **kwargs) torch.Tensor

调用随机函数 fn,根据 name 和封闭上下文(例如推理算法)产生额外的副作用。有关讨论,请参见 Introduction to Pyro

Parameters
  • name – 样本的名称

  • fn – 分布类或函数

  • obs – 观测数据(可选;仅在推理上下文中使用)可选地在kwargs中指定

  • obs_mask (boolTensor) – 可选的布尔张量掩码,形状与 fn.batch_shape 可广播。如果提供,具有 mask=True 的事件将根据 obs 进行条件处理,其余事件将通过采样进行填补。这会引入一个名为 name + "_unobserved" 的潜在样本站点,应由指南使用。

  • infer (dict) – 可选的推理参数字典,通过kwargs指定。详情请参阅推理文档。

Returns

样本

param_(init_tensor: Optional[Union[torch.Tensor, Callable[[], torch.Tensor]]] = None, constraint: torch.distributions.constraints.Constraint = Real(), event_dim: Optional[int] = None) torch.Tensor

将变量保存为参数存储中的参数。 要与参数存储交互或写入磁盘, 请参阅参数

Parameters
  • name (str) – 参数的名称

  • init_tensor (torch.Tensor可调用) – 初始张量或返回张量的惰性可调用对象。 对于大型张量,编写例如 lambda: torch.randn(100000) 可能更便宜,它只会在初始语句中求值。

  • 约束 (torch.distributions.constraints.Constraint) – torch 约束,默认为 constraints.real

  • event_dim (int) – (可选)与批处理无关的最右侧维度的数量。此维度左侧的维度将被视为批处理维度;如果参数语句位于子采样的plate中,则参数的相应批处理维度将相应地进行子采样。如果未指定,所有维度将被视为事件维度,并且不会执行子采样。

Returns

一个受约束的参数。可以通过pyro.param(...).unconstrained()访问底层未受约束的参数,其中.unconstrained是一个弱引用属性。

Return type

torch.Tensor

class List(name=None)[source]

基础类: list

类似列表的对象,用于保存不可变的潜在状态。

这必须在构造时给定一个名称:

latent = named.List("root")

或者必须立即存储在named.Object中:

latent = named.Object("root")
latent.xs = named.List()  # Must be bound to a Object before use.

警告

此数据结构是只写的:可以添加数据但不能更改或删除。尝试更改此数据结构可能会导致静默错误。

add()[source]

追加一个新的命名对象。

Returns

在末尾的一个新的潜在对象

Return type

named.Object

class Dict(name=None)[source]

基础类:dict

类似字典的对象,用于保存不可变的潜在状态。

这必须在构造时给定一个名称:

latent = named.Dict("root")

或者必须立即存储在 named.Object 中:

latent = named.Object("root")
latent.xs = named.Dict()  # Must be bound to a Object before use.

警告

此数据结构是只写的:可以添加数据但不能更改或删除。尝试更改此数据结构可能会导致静默错误。

作用域

pyro.contrib.autoname.scoping 包含了 pyro.contrib.autoname.scope() 的实现,这是一个用于自动为采样站点的名称添加有语义意义前缀的工具。

class NameCountMessenger[source]

基础类: pyro.poutine.messenger.Messenger

NameCountMessengerpyro.contrib.autoname.name_count() 的实现

class ScopeMessenger(prefix=None, inner=None)[source]

基础类: pyro.poutine.messenger.Messenger

ScopeMessengerpyro.contrib.autoname.scope() 的实现

scope(fn=None, prefix=None, inner=None)[source]
Parameters
  • fn – 一个随机函数(包含Pyro原始调用的可调用对象)

  • prefix – 一个字符串,用于添加到样本名称前(如果提供了 fn,则为可选)

  • inner – 用于确定重复名称计数器出现位置的开关

Returns

fn 使用 ScopeMessenger 装饰

scope 在 Pyro 示例站点中,将前缀加上 / 添加到名称前。 它的工作方式类似于 TensorFlow 的 name_scopevariable_scope, 并且可以用作上下文管理器、装饰器或高阶函数。

scope 对于将组合模型与指南或数据对齐非常有用。

示例:

>>> @scope(prefix="a")
... def model():
...     return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()

示例:

>>> def model():
...     with scope(prefix="a"):
...         return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()

作用域按预期组合,外部作用域在名称中出现在内部作用域之前:

>>> @scope(prefix="b")
... def model():
...     with scope(prefix="a"):
...         return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "b/a/x" in poutine.trace(model).get_trace()

当用作装饰器或高阶函数时,如果没有提供用户指定的前缀,scope 将使用输入函数的名称作为前缀。

示例:

>>> @scope
... def model():
...     return pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "model/x" in poutine.trace(model).get_trace()
name_count(fn=None)[source]

name_count 是一个非常简单的自动命名方案,它只是在执行过程中多次出现的任何名称后附加一个后缀 “__” 加上一个计数器。只有重复的名称实例才会获得后缀;第一个实例不会被修改。

示例:

>>> @name_count
... def model():
...     for i in range(3):
...         pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "x" in poutine.trace(model).get_trace()
>>> assert "x__1" in poutine.trace(model).get_trace()
>>> assert "x__2" in poutine.trace(model).get_trace()

name_count 还与 scope() 组合使用,通过添加后缀来复制作用域入口:

示例:

>>> @name_count
... def model():
...     for i in range(3):
...         with pyro.contrib.autoname.scope(prefix="a"):
...             pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()
>>> assert "a__1/x" in poutine.trace(model).get_trace()
>>> assert "a__2/x" in poutine.trace(model).get_trace()

示例:

>>> @name_count
... def model():
...     with pyro.contrib.autoname.scope(prefix="a"):
...         for i in range(3):
...             pyro.sample("x", dist.Bernoulli(0.5))
...
>>> assert "a/x" in poutine.trace(model).get_trace()
>>> assert "a/x__1" in poutine.trace(model).get_trace()
>>> assert "a/x__2" in poutine.trace(model).get_trace()