torch.nn.utils.parametrize 的源代码
import torch
from torch.nn.modules.container import ModuleList, ModuleDict, Module
from torch.nn.parameter import Parameter
from torch import Tensor
import collections
import copyreg
from copy import deepcopy
from contextlib import contextmanager
from typing import Union, Optional, Dict, Tuple, Sequence
__all__ = ['cached', 'ParametrizationList', 'register_parametrization', 'is_parametrized', 'remove_parametrizations',
'type_before_parametrizations', 'transfer_parametrizations_and_params']
_cache_enabled = 0
_cache: Dict[Tuple[int, str], Optional[Tensor]] = {}
[docs]@contextmanager
def cached():
r"""上下文管理器,启用通过 :func:`register_parametrization` 注册的参数化中的缓存系统。
当此上下文管理器处于活动状态时,参数化对象的值在第一次需要时计算并缓存。离开上下文管理器时,缓存的值将被丢弃。
这在正向传递中多次使用参数化参数时很有用。例如,当参数化RNN的循环核或共享权重时。
激活缓存的最简单方法是将神经网络的正向传递包装起来:
.. code-block:: python
import torch.nn.utils.parametrize as P
...
with P.cached():
output = model(inputs)
在训练和评估中。也可以包装使用参数化张量多次的模块部分。例如,具有参数化循环核的RNN的循环:
.. code-block:: python
with P.cached():
for x in xs:
out_rnn = self.rnn_cell(x, out_rnn)
"""
global _cache
global _cache_enabled
_cache_enabled += 1
try:
yield
finally:
_cache_enabled -= 1
if not _cache_enabled:
_cache = {}
def _register_parameter_or_buffer(module, name, X):
if isinstance(X, Parameter):
module.register_parameter(name, X)
else:
module.register_buffer(name, X)
[docs]class ParametrizationList(ModuleList):
r"""一个顺序容器,用于保存和管理参数化 :class:`torch.nn.Module` 的原始参数或缓冲区。
当 ``module[tensor_name]`` 通过 :func:`register_parametrization` 参数化时,它是 ``module.parametrizations[tensor_name]`` 的类型。
如果第一个注册的参数化具有返回一个张量的 ``right_inverse``,或者没有 ``right_inverse``(在这种情况下我们假设 ``right_inverse`` 是恒等函数),
它将持有名为 ``original`` 的张量。
如果它有一个返回多个张量的 ``right_inverse``,这些将被注册为 ``original0``, ``original1``, ...
.. 警告::
该类由 :func:`register_parametrization` 内部使用。为了完整性,这里进行了文档化。用户不应实例化它。
参数:
modules (sequence): 表示参数化的模块序列
original (Parameter or Tensor): 被参数化的参数或缓冲区
unsafe (bool): 一个布尔标志,表示参数化是否可能改变张量的 dtype 和形状。默认值: `False`
警告: 在注册时不会检查参数化的一致性。启用此标志需自行承担风险。
"""
original: Tensor
unsafe: bool
def __init__(
self, modules: Sequence[Module], original: Union[Tensor, Parameter], unsafe: bool = False
) -> None:
# 我们需要这个,因为我们需要以不同的方式处理第一个参数化
# 除非从外部使用此类的实例,否则这应该永远不会抛出
if len(modules) == 0:
raise ValueError("ParametrizationList 需要一个或多个模块。")
super().__init__(modules)
self.unsafe = unsafe
# 简单来说:
# module.weight 必须保持其 dtype 和形状。
# 此外,如果没有 right_inverse 或 right_inverse 返回一个张量,
# 这应该与原始张量具有相同的 dtype
#
# 我们检查以下不变量是否成立:
# X = module.weight
# Y = param.right_inverse(X)
# assert isinstance(Y, Tensor) or
# (isinstance(Y, collections.abc.Sequence) and all(isinstance(t, Tensor) for t in Y))
# Z = param(Y) if isinstance(Y, Tensor) else param(*Y)
# # 一致性检查
# assert X.dtype == Z.dtype and X.shape == Z.shape
# # 如果它有一个输入,这允许能够使用 set_ 将数据移动到/从原始张量而不改变其 id(这是优化器用来跟踪参数的)
# if isinstance(Y, Tensor)
# assert X.dtype == Y.dtype
# 下面我们使用 original = X, new = Y
original_shape = original.shape
original_dtype = original.dtype
# 计算 new
with torch.no_grad():
new = original
for module in reversed(self): # type: ignore[call-overload]
if hasattr(module, "right_inverse"):
try:
new = module.right_inverse(new)
except NotImplementedError:
pass
# 否则,或如果它抛出,我们假设 right_inverse 是恒等函数
if not isinstance(new, Tensor) and not isinstance(new, collections.abc.Sequence):
raise <span class="