torch.nn.utils.stateless 的源代码
import contextlib
import warnings
from collections import defaultdict
from typing import Any, Dict, Iterator, Optional, Set, Tuple, Union
import torch
from torch import Tensor
from torch.nn.utils._named_member_accessor import NamedMemberAccessor
__all__ = ["functional_call"]
def _untie_named_tensors_map(
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
) -> Dict[str, Tensor]:
"""
将模块中的所有绑定张量解绑到parameters_and_buffers中。
此函数返回一个新的untied_parameters_and_buffers字典,并保持原始的untied_parameters_and_buffers字典不变。它为模块中的绑定张量添加新的(缺失的)键到untied_parameters_and_buffers中。新键的值是原始parameters_and_buffers字典中用户给定的值。
如果对同一个绑定张量有多个用户给定的值,则会引发错误。
例如,如果模块有两个绑定的权重self.foo和self.tied_foo,用户传递{'foo': foo_value, ...},这将返回{'foo': foo_value, 'tied_foo': foo_value, ...}。如果用户传递{'foo': foo_value, 'tied_foo': tied_foo_value, ...},则会引发错误。如果用户传递{'foo': foo_value, 'tied_foo': foo_value, ...},则不会引发错误。
参数:
module (torch.nn.Module): 用于确定哪些张量是绑定的模块。
parameters_and_buffers (Dict[str, Tensor]): 用于重新参数化模块的{name: tensor}映射。
返回:
一个新的untied版本的parameters_and_buffers字典。
引发:
ValueError: 如果对同一个绑定张量有多个用户给定的值。
"""
# 模块中所有张量(包括绑定的)的{name: tensor}映射。
all_named_tensors: Dict[str, Tensor] = {}
all_named_tensors.update(module.named_parameters(remove_duplicate=False))
all_named_tensors.update(module.named_buffers(remove_duplicate=False))
# 模块中所有张量名称的{tensor: set(all_tied_names)}映射。
tensor_to_tied_names_map: Dict[Tensor, Set[str]] = defaultdict(set)
for name, tensor in all_named_tensors.items():
tensor_to_tied_names_map[tensor].add(name)
# 模块中所有张量名称的{tied_name: set(all_tied_names)}映射。
# 如果名称未绑定,则不会出现在此映射中。
tied_names_map: Dict[str, Set[str]] = {}
for tied_names in tensor_to_tied_names_map.values():
if len(tied_names) > 1:
for tied_name in tied_names:
tied_names_map[tied_name] = tied_names
# 确保用户没有为同一个绑定张量传递多个值。
given_names = set(parameters_and_buffers.keys())
given_names_for_tied_tensors = given_names.intersection(tied_names_map.keys())
for given_name in given_names_for_tied_tensors:
tied_names = tied_names_map[given_name]
if (
# 检测是否存在多个键为同一个绑定张量。
len(tied_names.intersection(given_names_for_tied_tensors)) > 1
# 只有在用户为同一个绑定张量传递多个值时才引发错误。
# 如果所有给定的值相同,则不引发。
and len({parameters_and_buffers[tied_name] for tied_name in tied_names})
!= 1
):
raise ValueError(
f"functional_call 为键 {sorted(tied_names)} 获取了多个值,"
f"这些键是绑定的。请考虑使用 tie_weights=False"
)
# 解绑给定的命名张量映射
# 创建一个副本以避免修改原始字典
untied_parameters_and_buffers = parameters_and_buffers.copy()
for given_name in given_names_for_tied_tensors:
for tied_name in tied_names_map[given_name]:
untied_parameters_and_buffers[tied_name] = parameters_and_buffers[
given_name
]
return untied_parameters_and_buffers
@contextlib.contextmanager
def _reparametrize_module(
module: "torch.nn.Module",
parameters_and_buffers: Dict[str, Tensor],
*,
tie_weights: bool = False,
strict: bool = False,
) -> Iterator[None]:
if tie_weights:
untied_parameters_and_buffers = _untie_named_tensors_map(
module, parameters_and_buffers
)
else:
untied_parameters_and_buffers = parameters_and_buffers
accessor = NamedMemberAccessor(module)
if strict:
missing_keys,</