参数

Pyro中的参数基本上是围绕PyTorch张量的薄包装器,这些张量带有唯一的名称。 因此,参数是Pyro中的主要状态对象。用户通常通过Pyro原语pyro.param与参数交互。 参数在随机变分推断中扮演着核心角色,它们用于表示参数化模型和指南中参数的点估计。

参数存储

class StateDict[source]

基础:typing_extensions.TypedDict

params: Dict[str, torch.Tensor]
constraints: Dict[str, torch.distributions.constraints.Constraint]
class ParamStoreDict[source]

基础类:object

Pyro中的全局参数存储。这基本上是一个键值存储。 典型用户主要通过原始pyro.param与ParamStore进行交互。

请参阅介绍以获取更多讨论,以及SVI第一部分以查看一些示例。

在使用Pyro中的参数时,需要记住以下几点:

  • 参数必须分配唯一的名称

  • init_tensor 参数在 pyro.param 中仅用于首次将给定(命名的)参数注册到 Pyro 时。

  • 出于这个原因,用户可能需要在REPL中使用clear()方法以获得所需的行为。也可以通过pyro.clear_param_store()调用此方法。

  • 在PyTorch nn.Module中注册的参数的内部名称会以模块的Pyro名称作为前缀。因此,用户可以拥有两个不同的模块,每个模块都包含一个名为weight的参数。相比之下,用户只能拥有一个名为weight的顶级参数(在任何模块之外)。

  • 参数可以使用saveload从磁盘保存和加载。

  • 通常,参数与受约束的不受约束的值相关联。例如,在内部,一个被约束为正数的参数在日志空间中表示为一个不受约束的张量。

clear() None[source]

清除参数存储

items() Iterator[Tuple[str, torch.Tensor]][source]

遍历 (name, constrained_param) 对。请注意,constrained_param 是在受约束的(即面向用户的)空间中。

keys() KeysView[str][source]

遍历参数名称。

values() Iterator[torch.Tensor][source]

遍历受约束的参数值。

setdefault(name: str, init_constrained_value: Union[torch.Tensor, Callable[[], torch.Tensor]], constraint: torch.distributions.constraints.Constraint = Real()) torch.Tensor[source]

ParamStoreDict中检索一个受约束的参数值(如果存在),否则设置初始值。请注意,这比dict.setdefault()稍微复杂一些。

如果参数已经存在,init_constrained_tensor 将被忽略。为了避免昂贵的 init_constrained_tensor 创建,你可以将其包装在一个 lambda 中,该 lambda 仅在参数不存在时才会被评估:

param_store.get("foo", lambda: (0.001 * torch.randn(1000, 1000)).exp(),
                constraint=constraints.positive)
Parameters
  • name (str) – 参数名称

  • init_constrained_value (torch.Tensor返回 torch.Tensor 的可调用对象) – 初始约束值

  • 约束 (Constraint) – torch 约束对象

Returns

约束参数值

Return type

torch.Tensor

named_parameters() ItemsView[str, torch.Tensor][source]

返回一个迭代器,遍历ParamStore中的每个参数的(name, unconstrained_value)元组。请注意,如果参数受到约束,unconstrained_value是在约束隐式使用的无约束空间中的值。

get_all_param_names() KeysView[str][source]
replace_param(param_name: str, new_param: torch.Tensor, old_param: torch.Tensor) None[source]
get_param(name: str, init_tensor: Optional[torch.Tensor] = None, constraint: torch.distributions.constraints.Constraint = Real(), event_dim: Optional[int] = None) torch.Tensor[source]

根据名称获取参数。如果该参数在ParamStore中尚不存在,它将被创建并存储。 Pyro原语pyro.param会调用此方法。

Parameters
Returns

参数

Return type

torch.Tensor

match(name: str) Dict[str, torch.Tensor][source]

获取所有匹配正则表达式的参数。参数必须存在。

Parameters

name (str) – 正则表达式

Returns

带有键参数名称和值torch张量的字典

param_name(p: torch.Tensor) Optional[str][source]

从参数中获取参数名称

Parameters

p – 参数

Returns

参数名称

get_state() pyro.params.param_store.StateDict[source]

获取ParamStore状态。

set_state(state: pyro.params.param_store.StateDict) None[source]

使用之前get_state()调用的状态来设置ParamStore状态

save(filename: str) None[source]

将参数保存到文件

Parameters

文件名 (str) – 保存到的文件名

load(filename: str, map_location: Optional[Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str]]] = None) None[source]

从文件加载参数

注意

如果对从磁盘加载的参数使用pyro.module(),请确保设置update_module_params标志:

pyro.get_param_store().load('saved_params.save')
pyro.module('module', nn, update_module_params=True)
Parameters
  • filename (str) – 要加载的文件名

  • map_location (function, torch.device, string字典) – 指定如何重新映射存储位置

scope(state: Optional[pyro.params.param_store.StateDict] = None) Iterator[pyro.params.param_store.StateDict][source]

用于在同一进程中使用多个参数存储的上下文管理器。

这是对get_state()clear()set_state()的一个轻量级封装。对于内存空间有限的大型模型,您可能希望手动使用save()clear()load()

示例用法:

param_store = pyro.get_param_store()

# Train multiple models, while avoiding param name conflicts.
with param_store.scope() as scope1:
    # ...Train one model,guide pair...
with param_store.scope() as scope2:
    # ...Train another model,guide pair...

# Now evaluate each, still avoiding name conflicts.
with param_store.scope(scope1):  # loads the first model's scope
   # ...evaluate the first model...
with param_store.scope(scope2):  # loads the second model's scope
   # ...evaluate the second model...
param_with_module_name(pyro_name: str, param_name: str) str[source]
module_from_param_with_module_name(param_name: str) str[source]
user_param_name(param_name: str) str[source]
normalize_param_name(name: str) str[source]