Shortcuts

torch.nn.utils.stateless 的源代码

import contextlib
import warnings
from collections import defaultdict
from typing import Any, Dict, Iterator, Optional, Set, Tuple, Union

import torch
from torch import Tensor
from torch.nn.utils._named_member_accessor import NamedMemberAccessor

__all__ = ["functional_call"]


def _untie_named_tensors_map(
    module: "torch.nn.Module",
    parameters_and_buffers: Dict[str, Tensor],
) -> Dict[str, Tensor]:
    """
    将模块中的所有绑定张量解绑到parameters_and_buffers中。

    此函数返回一个新的untied_parameters_and_buffers字典,并保持原始的untied_parameters_and_buffers字典不变。它为模块中的绑定张量添加新的(缺失的)键到untied_parameters_and_buffers中。新键的值是原始parameters_and_buffers字典中用户给定的值。

    如果对同一个绑定张量有多个用户给定的值,则会引发错误。

    例如,如果模块有两个绑定的权重self.foo和self.tied_foo,用户传递{'foo': foo_value, ...},这将返回{'foo': foo_value, 'tied_foo': foo_value, ...}。如果用户传递{'foo': foo_value, 'tied_foo': tied_foo_value, ...},则会引发错误。如果用户传递{'foo': foo_value, 'tied_foo': foo_value, ...},则不会引发错误。

    参数:
        module (torch.nn.Module): 用于确定哪些张量是绑定的模块。
        parameters_and_buffers (Dict[str, Tensor]): 用于重新参数化模块的{name: tensor}映射。

    返回:
        一个新的untied版本的parameters_and_buffers字典。

    引发:
        ValueError: 如果对同一个绑定张量有多个用户给定的值。
    """
    # 模块中所有张量(包括绑定的)的{name: tensor}映射。
    all_named_tensors: Dict[str, Tensor] = {}
    all_named_tensors.update(module.named_parameters(remove_duplicate=False))
    all_named_tensors.update(module.named_buffers(remove_duplicate=False))

    # 模块中所有张量名称的{tensor: set(all_tied_names)}映射。
    tensor_to_tied_names_map: Dict[Tensor, Set[str]] = defaultdict(set)
    for name, tensor in all_named_tensors.items():
        tensor_to_tied_names_map[tensor].add(name)

    # 模块中所有张量名称的{tied_name: set(all_tied_names)}映射。
    # 如果名称未绑定,则不会出现在此映射中。
    tied_names_map: Dict[str, Set[str]] = {}
    for tied_names in tensor_to_tied_names_map.values():
        if len(tied_names) > 1:
            for tied_name in tied_names:
                tied_names_map[tied_name] = tied_names

    # 确保用户没有为同一个绑定张量传递多个值。
    given_names = set(parameters_and_buffers.keys())
    given_names_for_tied_tensors = given_names.intersection(tied_names_map.keys())
    for given_name in given_names_for_tied_tensors:
        tied_names = tied_names_map[given_name]
        if (
            # 检测是否存在多个键为同一个绑定张量。
            len(tied_names.intersection(given_names_for_tied_tensors)) > 1
            # 只有在用户为同一个绑定张量传递多个值时才引发错误。
            # 如果所有给定的值相同,则不引发。
            and len({parameters_and_buffers[tied_name] for tied_name in tied_names})
            != 1
        ):
            raise ValueError(
                f"functional_call 为键 {sorted(tied_names)} 获取了多个值,"
                f"这些键是绑定的。请考虑使用 tie_weights=False"
            )

    # 解绑给定的命名张量映射
    # 创建一个副本以避免修改原始字典
    untied_parameters_and_buffers = parameters_and_buffers.copy()
    for given_name in given_names_for_tied_tensors:
        for tied_name in tied_names_map[given_name]:
            untied_parameters_and_buffers[tied_name] = parameters_and_buffers[
                given_name
            ]
    return untied_parameters_and_buffers


@contextlib.contextmanager
def _reparametrize_module(
    module: "torch.nn.Module",
    parameters_and_buffers: Dict[str, Tensor],
    *,
    tie_weights: bool = False,
    strict: bool = False,
) -> Iterator[None]:
    if tie_weights:
        untied_parameters_and_buffers = _untie_named_tensors_map(
            module, parameters_and_buffers
        )
    else:
        untied_parameters_and_buffers = parameters_and_buffers

    accessor = NamedMemberAccessor(module)
    if strict:
        missing_keys,</