自动名称生成¶
pyro.contrib.autoname
模块提供了自动生成唯一且语义上有意义的样本站点名称的工具。
- scope(fn=None, prefix=None, inner=None)[source]¶
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
prefix – 一个字符串,用于添加到样本名称前(如果提供了
fn
,则为可选)inner – 用于确定重复名称计数器出现位置的开关
- Returns
fn
使用ScopeMessenger
装饰
scope
在 Pyro 示例站点中,将前缀加上/
添加到名称前。 它的工作方式类似于 TensorFlow 的name_scope
和variable_scope
, 并且可以用作上下文管理器、装饰器或高阶函数。scope
对于将组合模型与指南或数据对齐非常有用。示例:
>>> @scope(prefix="a") ... def model(): ... return pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "a/x" in poutine.trace(model).get_trace()
示例:
>>> def model(): ... with scope(prefix="a"): ... return pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "a/x" in poutine.trace(model).get_trace()
作用域按预期组合,外部作用域在名称中出现在内部作用域之前:
>>> @scope(prefix="b") ... def model(): ... with scope(prefix="a"): ... return pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "b/a/x" in poutine.trace(model).get_trace()
当用作装饰器或高阶函数时,如果没有提供用户指定的前缀,
scope
将使用输入函数的名称作为前缀。示例:
>>> @scope ... def model(): ... return pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "model/x" in poutine.trace(model).get_trace()
- name_count(fn=None)[source]¶
name_count
是一个非常简单的自动命名方案,它只是在执行过程中多次出现的任何名称后附加一个后缀 “__” 加上一个计数器。只有重复的名称实例才会获得后缀;第一个实例不会被修改。示例:
>>> @name_count ... def model(): ... for i in range(3): ... pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "x" in poutine.trace(model).get_trace() >>> assert "x__1" in poutine.trace(model).get_trace() >>> assert "x__2" in poutine.trace(model).get_trace()
name_count
还与scope()
组合使用,通过添加后缀来复制作用域入口:示例:
>>> @name_count ... def model(): ... for i in range(3): ... with pyro.contrib.autoname.scope(prefix="a"): ... pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "a/x" in poutine.trace(model).get_trace() >>> assert "a__1/x" in poutine.trace(model).get_trace() >>> assert "a__2/x" in poutine.trace(model).get_trace()
示例:
>>> @name_count ... def model(): ... with pyro.contrib.autoname.scope(prefix="a"): ... for i in range(3): ... pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "a/x" in poutine.trace(model).get_trace() >>> assert "a/x__1" in poutine.trace(model).get_trace() >>> assert "a/x__2" in poutine.trace(model).get_trace()
- autoname(fn=None, name=None)[source]¶
AutonameMessenger
的便捷包装器为随机变量分配唯一的名称。
对于新变量,如果给定了声明名称,则使用其声明名称,否则使用分布名称:
sample("x", dist.Bernoulli ... ) # -> x sample(dist.Bernoulli ... ) # -> Bernoulli
对于重复的变量名,将计数器作为后缀附加:
sample(dist.Bernoulli ... ) # -> Bernoulli sample(dist.Bernoulli ... ) # -> Bernoulli1 sample(dist.Bernoulli ... ) # -> Bernoulli2
函数和迭代器可以用作名称作用域:
@autoname def f1(): sample(dist.Bernoulli ... ) @autoname def f2(): f1() # -> f2/f1/Bernoulli f1() # -> f2/f1__1/Bernoulli sample(dist.Bernoulli ... ) # -> f2/Bernoulli @autoname(name="model") def f3(): for i in autoname(range(3), name="time"): # model/time/Bernoulli .. model/time__1/Bernoulli .. model/time__2/Bernoulli sample(dist.Bernoulli ... ) # model/time/f1/Bernoulli .. model/time__1/f1/Bernoulli .. model/time__2/f1/Bernoulli f1()
或者可以使用with语句添加作用域:
def f4(): with autoname(name="prefix"): f1() # -> prefix/f1/Bernoulli f1() # -> prefix/f1__1/Bernoulli sample(dist.Bernoulli ... ) # -> prefix/Bernoulli
- sample(*args)[source]¶
- sample(name: str, fn, *args, **kwargs)
- sample(fn: pyro.distributions.distribution.Distribution, *args, **kwargs)
命名的数据结构¶
pyro.contrib.named
模块是 Pyro 之上的一个薄语法层。它允许 Pyro 模型被编写成看起来像操作 Python 数据结构的程序,例如 latent.x.sample_(...)
,而不是像 x = pyro.sample("x", ...)
这样带有字符串标签语句的程序。
本模块提供了三种容器数据结构 named.Object
、
named.List
和 named.Dict
。这些数据结构旨在相互嵌套。它们共同跟踪每个数据结构中每个数据块的地址,以便该地址可以用作 Pyro 站点。例如:
>>> state = named.Object("state")
>>> print(str(state))
state
>>> z = state.x.y.z # z is just a placeholder.
>>> print(str(z))
state.x.y.z
>>> state.xs = named.List() # Create a contained list.
>>> x0 = state.xs.add()
>>> print(str(x0))
state.xs[0]
>>> state.ys = named.Dict()
>>> foo = state.ys['foo']
>>> print(str(foo))
state.ys['foo']
这些地址现在可以在sample
、observe
和param
语句中使用。这些命名的数据结构甚至提供了别名Pyro语句的原地方法。例如:
>>> state = named.Object("state")
>>> loc = state.loc.param_(torch.zeros(1, requires_grad=True))
>>> scale = state.scale.param_(torch.ones(1, requires_grad=True))
>>> z = state.z.sample_(dist.Normal(loc, scale))
>>> obs = state.x.sample_(dist.Normal(loc, scale), obs=z)
有关如何在模型代码中使用这些内容的更深入示例,请参见 Tree Data 和 Mixture 示例。
作者: Fritz Obermeyer, Alexander Rush
- class Object(name)[source]¶
基础类:
object
用于保存不可变潜在状态的对象。
此对象既可以作为嵌套潜在状态的容器,也可以作为通过named.sample、named.observe或named.param语句替换为张量的占位符。当用作占位符时,Object对象取代了普通pyro.sample语句中的字符串。
- Parameters
name (str) – 对象的名称。
示例:
state = named.Object("state") state.x = 0 state.ys = named.List() state.zs = named.Dict() state.a.b.c.d.e.f.g = 0 # Creates a chain of named.Objects.
警告
此数据结构是一次写入的:可以添加数据但不能更改或删除。尝试更改此数据结构可能会导致静默错误。
- sample_(fn: pyro.distributions.torch_distribution.TorchDistributionMixin, *args, obs: Optional[torch.Tensor] = None, obs_mask: Optional[torch.BoolTensor] = None, infer: Optional[pyro.poutine.runtime.InferDict] = None, **kwargs) torch.Tensor ¶
调用随机函数
fn
,根据name
和封闭上下文(例如推理算法)产生额外的副作用。有关讨论,请参见 Introduction to Pyro。
- param_(init_tensor: Optional[Union[torch.Tensor, Callable[[], torch.Tensor]]] = None, constraint: torch.distributions.constraints.Constraint = Real(), event_dim: Optional[int] = None) torch.Tensor ¶
将变量保存为参数存储中的参数。 要与参数存储交互或写入磁盘, 请参阅参数。
- Parameters
name (str) – 参数的名称
init_tensor (torch.Tensor 或 可调用) – 初始张量或返回张量的惰性可调用对象。 对于大型张量,编写例如
lambda: torch.randn(100000)
可能更便宜,它只会在初始语句中求值。约束 (torch.distributions.constraints.Constraint) – torch 约束,默认为
constraints.real
。event_dim (int) – (可选)与批处理无关的最右侧维度的数量。此维度左侧的维度将被视为批处理维度;如果参数语句位于子采样的plate中,则参数的相应批处理维度将相应地进行子采样。如果未指定,所有维度将被视为事件维度,并且不会执行子采样。
- Returns
一个受约束的参数。可以通过
pyro.param(...).unconstrained()
访问底层未受约束的参数,其中.unconstrained
是一个弱引用属性。- Return type
- class List(name=None)[source]¶
基础类:
list
类似列表的对象,用于保存不可变的潜在状态。
这必须在构造时给定一个名称:
latent = named.List("root")
或者必须立即存储在
named.Object
中:latent = named.Object("root") latent.xs = named.List() # Must be bound to a Object before use.
警告
此数据结构是只写的:可以添加数据但不能更改或删除。尝试更改此数据结构可能会导致静默错误。
作用域¶
pyro.contrib.autoname.scoping
包含了 pyro.contrib.autoname.scope()
的实现,这是一个用于自动为采样站点的名称添加有语义意义前缀的工具。
- class NameCountMessenger[source]¶
基础类:
pyro.poutine.messenger.Messenger
NameCountMessenger
是pyro.contrib.autoname.name_count()
的实现
- class ScopeMessenger(prefix=None, inner=None)[source]¶
基础类:
pyro.poutine.messenger.Messenger
ScopeMessenger
是pyro.contrib.autoname.scope()
的实现
- scope(fn=None, prefix=None, inner=None)[source]¶
- Parameters
fn – 一个随机函数(包含Pyro原始调用的可调用对象)
prefix – 一个字符串,用于添加到样本名称前(如果提供了
fn
,则为可选)inner – 用于确定重复名称计数器出现位置的开关
- Returns
fn
使用ScopeMessenger
装饰
scope
在 Pyro 示例站点中,将前缀加上/
添加到名称前。 它的工作方式类似于 TensorFlow 的name_scope
和variable_scope
, 并且可以用作上下文管理器、装饰器或高阶函数。scope
对于将组合模型与指南或数据对齐非常有用。示例:
>>> @scope(prefix="a") ... def model(): ... return pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "a/x" in poutine.trace(model).get_trace()
示例:
>>> def model(): ... with scope(prefix="a"): ... return pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "a/x" in poutine.trace(model).get_trace()
作用域按预期组合,外部作用域在名称中出现在内部作用域之前:
>>> @scope(prefix="b") ... def model(): ... with scope(prefix="a"): ... return pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "b/a/x" in poutine.trace(model).get_trace()
当用作装饰器或高阶函数时,如果没有提供用户指定的前缀,
scope
将使用输入函数的名称作为前缀。示例:
>>> @scope ... def model(): ... return pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "model/x" in poutine.trace(model).get_trace()
- name_count(fn=None)[source]¶
name_count
是一个非常简单的自动命名方案,它只是在执行过程中多次出现的任何名称后附加一个后缀 “__” 加上一个计数器。只有重复的名称实例才会获得后缀;第一个实例不会被修改。示例:
>>> @name_count ... def model(): ... for i in range(3): ... pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "x" in poutine.trace(model).get_trace() >>> assert "x__1" in poutine.trace(model).get_trace() >>> assert "x__2" in poutine.trace(model).get_trace()
name_count
还与scope()
组合使用,通过添加后缀来复制作用域入口:示例:
>>> @name_count ... def model(): ... for i in range(3): ... with pyro.contrib.autoname.scope(prefix="a"): ... pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "a/x" in poutine.trace(model).get_trace() >>> assert "a__1/x" in poutine.trace(model).get_trace() >>> assert "a__2/x" in poutine.trace(model).get_trace()
示例:
>>> @name_count ... def model(): ... with pyro.contrib.autoname.scope(prefix="a"): ... for i in range(3): ... pyro.sample("x", dist.Bernoulli(0.5)) ... >>> assert "a/x" in poutine.trace(model).get_trace() >>> assert "a/x__1" in poutine.trace(model).get_trace() >>> assert "a/x__2" in poutine.trace(model).get_trace()