torch.nn.parameter 的源代码
import torch
from torch._C import _disabled_torch_function_impl
from collections import OrderedDict
# 元类,用于结合 _TensorMeta 和 Parameter 的实例检查覆盖。
class _ParameterMeta(torch._C._TensorMeta):
# 使 `isinstance(t, Parameter)` 对具有 _is_param 标志的自定义张量实例返回 True。
def __instancecheck__(self, instance):
return super().__instancecheck__(instance) or (
isinstance(instance, torch.Tensor) and getattr(instance, '_is_param', False))
[docs]class Parameter(torch.Tensor, metaclass=_ParameterMeta):
r"""一种被视为模块参数的张量。
参数是 :class:`~torch.Tensor` 的子类,当与 :class:`Module` 一起使用时,它们有一个
非常特殊的属性——当它们被分配为模块属性时,它们会自动添加到模块的参数列表中,并且会出现在例如 :meth:`~Module.parameters` 迭代器中。
分配一个张量不会有这样的效果。这是因为人们可能希望在模型中缓存一些临时状态,比如 RNN 的最后一个隐藏状态。
如果没有 :class:`Parameter` 这样的类,这些临时状态也会被注册。
参数:
data (Tensor): 参数张量。
requires_grad (bool, 可选): 参数是否需要梯度。注意,torch.no_grad() 上下文不会影响
参数创建的默认行为——在 :class:`~no_grad` 模式下,参数仍然会有 `requires_grad=True`。
有关更多详细信息,请参阅 :ref:`locally-disable-grad-doc`。默认值: `True`
"""
def __new__(cls, data=None, requires_grad=True):
if data is None:
data = torch.empty(0)
if type(data) is torch.Tensor or type(data) is Parameter:
# 为了便于维护向后兼容性,保留此路径以用于标准张量。
# 最终(tm),我们应该更改标准张量的行为以匹配。
return torch.Tensor._make_subclass(cls, data, requires_grad)
# 自定义张量的路径:在实例上设置标志以指示参数性质。
t = data.detach().requires_grad_(requires_grad)
if type(t) is not type(data):
raise RuntimeError(f"从类型为 {type(data).__name__} 的实例创建参数 "
"需要 detach() 返回相同类型的实例,但返回了 "
f"类型 {type(t).__name__}。要使用该类型作为 "
"参数,请更正由其 __torch_dispatch__() 实现定义的 detach() 语义。")
t._is_param = True
return t
# 注意:下面的 3 个方法仅适用于标准张量。自定义张量类型的参数
# 仍然被视为该自定义张量类型,并且不会为它们调用这些方法。
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
else:
result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad)
memo[id(self)] = result
return result
def __repr__(self):
return 'Parameter containing:\n' + super().__repr__()
def __reduce_ex__(self, proto):
state = torch._utils._get_obj_state(self)
# 参见注释 [不要序列化钩子]
hooks = OrderedDict()
if not state:
return (
torch._utils._rebuild_parameter,
(self.data, self.requires_grad, hooks)
)
return (
torch._utils._rebuild_parameter_with_state,
(self.data, self.requires_grad, hooks, state)
)
__torch_function__ = _disabled_torch_function_impl
class UninitializedTensorMixin:
_allowed_methods = [
torch.Tensor.__hash__,
torch.Tensor.size,
torch.Tensor.copy_,
torch.Tensor.is_complex,
torch.Tensor.is_floating_point,
torch.Tensor.half,
torch.Tensor.float,