Shortcuts

torch.nn.modules.container 的源代码

import warnings
from collections import OrderedDict, abc as container_abcs
from itertools import chain, islice
import operator

import torch
from .module import Module
from ..parameter import Parameter
from torch._jit_internal import _copy_to_script_wrapper

from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union
from typing_extensions import Self

__all__ = ['Container', 'Sequential', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict']

T = TypeVar('T', bound=Module)


# 从torch.nn.modules.module复制,为ModuleList自定义__repr__所需
def _addindent(s_, numSpaces):
    s = s_.split('\n')
    # 单行内容不做任何处理
    if len(s) == 1:
        return s_
    first = s.pop(0)
    s = [(numSpaces * ' ') + line for line in s]
    s = '\n'.join(s)
    s = first + '\n' + s
    return s


class Container(Module):

    def __init__(self, **kwargs: Any) -> None:
        super().__init__()
        # DeprecationWarning默认被忽略 <叹气>
        warnings.warn("nn.Container已弃用。其所有功能现在都已在nn.Module中实现。请改为继承该类。")
        for key, value in kwargs.items():
            self.add_module(key, value)


[docs]class Sequential(Module): r"""一个顺序容器。 模块将按照它们在构造函数中传递的顺序添加到其中。或者,可以传递一个模块的``OrderedDict``。``Sequential``的``forward()``方法接受任何输入并将其传递给第一个模块。然后它将输出“链”到输入,依次传递给每个后续模块,最后返回最后一个模块的输出。 ``Sequential``相对于手动调用一系列模块的优势在于,它允许将整个容器视为单个模块,从而对``Sequential``执行的转换将应用于它存储的每个模块(每个模块都是``Sequential``的注册子模块)。 ``Sequential``和``torch.nn.ModuleList``有什么区别?``ModuleList``正是它听起来像的那样——一个用于存储``Module``的列表!另一方面,``Sequential``中的层是以级联方式连接的。 示例:: # 使用Sequential创建一个小模型。当`model`运行时,输入将首先传递给`Conv2d(1,20,5)`。`Conv2d(1,20,5)`的输出将作为输入传递给第一个`ReLU`;第一个`ReLU`的输出将成为`Conv2d(20,64,5)`的输入。最后,`Conv2d(20,64,5)`的输出将作为输入传递给第二个`ReLU` model = nn.Sequential( nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU() ) # 使用Sequential和OrderedDict。这与上面的代码功能相同 model = nn.Sequential(OrderedDict([ ('conv1', nn.Conv2d(1,20,5)), ('relu1', nn.ReLU()), ('conv2', nn.Conv2d(20,64,5)), ('relu2', nn.ReLU()) ])) """ _modules: Dict[str, Module] # type: ignore[assignment] @overload def __init__(self, *args: Module) -> None: ... @overload def __init__(self, arg: 'OrderedDict[str, Module]') -> None: ... def __init__(self, *args): super().__init__() if len(args) == 1 and isinstance(args[0], OrderedDict): for key, module in args[0].items(): self.add_module(key, module) else: for idx, module in enumerate(args): self.add_module(str(idx), module) def _get_item_by_idx(self, iterator, idx) -> T: # type: ignore[misc, type-var] """获取迭代器的第idx项。""" size = len(self) idx = operator.index(idx) if not -size <= idx < <