Shortcuts

torch.nn.parameter 的源代码

import torch
from torch._C import _disabled_torch_function_impl
from collections import OrderedDict

# 元类,用于结合 _TensorMeta 和 Parameter 的实例检查覆盖。
class _ParameterMeta(torch._C._TensorMeta):
    # 使 `isinstance(t, Parameter)` 对具有 _is_param 标志的自定义张量实例返回 True。
    def __instancecheck__(self, instance):
        return super().__instancecheck__(instance) or (
            isinstance(instance, torch.Tensor) and getattr(instance, '_is_param', False))


[docs]class Parameter(torch.Tensor, metaclass=_ParameterMeta): r"""一种被视为模块参数的张量。 参数是 :class:`~torch.Tensor` 的子类,当与 :class:`Module` 一起使用时,它们有一个 非常特殊的属性——当它们被分配为模块属性时,它们会自动添加到模块的参数列表中,并且会出现在例如 :meth:`~Module.parameters` 迭代器中。 分配一个张量不会有这样的效果。这是因为人们可能希望在模型中缓存一些临时状态,比如 RNN 的最后一个隐藏状态。 如果没有 :class:`Parameter` 这样的类,这些临时状态也会被注册。 参数: data (Tensor): 参数张量。 requires_grad (bool, 可选): 参数是否需要梯度。注意,torch.no_grad() 上下文不会影响 参数创建的默认行为——在 :class:`~no_grad` 模式下,参数仍然会有 `requires_grad=True`。 有关更多详细信息,请参阅 :ref:`locally-disable-grad-doc`。默认值: `True` """ def __new__(cls, data=None, requires_grad=True): if data is None: data = torch.empty(0) if type(data) is torch.Tensor or type(data) is Parameter: # 为了便于维护向后兼容性,保留此路径以用于标准张量。 # 最终(tm),我们应该更改标准张量的行为以匹配。 return torch.Tensor._make_subclass(cls, data, requires_grad) # 自定义张量的路径:在实例上设置标志以指示参数性质。 t = data.detach().requires_grad_(requires_grad) if type(t) is not type(data): raise RuntimeError(f"从类型为 {type(data).__name__} 的实例创建参数 " "需要 detach() 返回相同类型的实例,但返回了 " f"类型 {type(t).__name__}。要使用该类型作为 " "参数,请更正由其 __torch_dispatch__() 实现定义的 detach() 语义。") t._is_param = True return t # 注意:下面的 3 个方法仅适用于标准张量。自定义张量类型的参数 # 仍然被视为该自定义张量类型,并且不会为它们调用这些方法。 def __deepcopy__(self, memo): if id(self) in memo: return memo[id(self)] else: result = type(self)(self.data.clone(memory_format=torch.preserve_format), self.requires_grad) memo[id(self)] = result return result def __repr__(self): return 'Parameter containing:\n' + super().__repr__() def __reduce_ex__(self, proto): state = torch._utils._get_obj_state(self) # 参见注释 [不要序列化钩子] hooks = OrderedDict() if not state: return ( torch._utils._rebuild_parameter, (self.data, self.requires_grad, hooks) ) return ( torch._utils._rebuild_parameter_with_state, (self.data, self.requires_grad, hooks, state) ) __torch_function__ = _disabled_torch_function_impl
class UninitializedTensorMixin: _allowed_methods = [ torch.Tensor.__hash__, torch.Tensor.size, torch.Tensor.copy_, torch.Tensor.is_complex, torch.Tensor.is_floating_point, torch.Tensor.half, torch.Tensor.float,