torch.nn.modules.container 的源代码
import warnings
from collections import OrderedDict, abc as container_abcs
from itertools import chain, islice
import operator
import torch
from .module import Module
from ..parameter import Parameter
from torch._jit_internal import _copy_to_script_wrapper
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union
from typing_extensions import Self
__all__ = ['Container', 'Sequential', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict']
T = TypeVar('T', bound=Module)
# 从torch.nn.modules.module复制,为ModuleList自定义__repr__所需
def _addindent(s_, numSpaces):
s = s_.split('\n')
# 单行内容不做任何处理
if len(s) == 1:
return s_
first = s.pop(0)
s = [(numSpaces * ' ') + line for line in s]
s = '\n'.join(s)
s = first + '\n' + s
return s
class Container(Module):
def __init__(self, **kwargs: Any) -> None:
super().__init__()
# DeprecationWarning默认被忽略 <叹气>
warnings.warn("nn.Container已弃用。其所有功能现在都已在nn.Module中实现。请改为继承该类。")
for key, value in kwargs.items():
self.add_module(key, value)
[docs]class Sequential(Module):
r"""一个顺序容器。
模块将按照它们在构造函数中传递的顺序添加到其中。或者,可以传递一个模块的``OrderedDict``。``Sequential``的``forward()``方法接受任何输入并将其传递给第一个模块。然后它将输出“链”到输入,依次传递给每个后续模块,最后返回最后一个模块的输出。
``Sequential``相对于手动调用一系列模块的优势在于,它允许将整个容器视为单个模块,从而对``Sequential``执行的转换将应用于它存储的每个模块(每个模块都是``Sequential``的注册子模块)。
``Sequential``和``torch.nn.ModuleList``有什么区别?``ModuleList``正是它听起来像的那样——一个用于存储``Module``的列表!另一方面,``Sequential``中的层是以级联方式连接的。
示例::
# 使用Sequential创建一个小模型。当`model`运行时,输入将首先传递给`Conv2d(1,20,5)`。`Conv2d(1,20,5)`的输出将作为输入传递给第一个`ReLU`;第一个`ReLU`的输出将成为`Conv2d(20,64,5)`的输入。最后,`Conv2d(20,64,5)`的输出将作为输入传递给第二个`ReLU`
model = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)
# 使用Sequential和OrderedDict。这与上面的代码功能相同
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1,20,5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(20,64,5)),
('relu2', nn.ReLU())
]))
"""
_modules: Dict[str, Module] # type: ignore[assignment]
@overload
def __init__(self, *args: Module) -> None:
...
@overload
def __init__(self, arg: 'OrderedDict[str, Module]') -> None:
...
def __init__(self, *args):
super().__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
self.add_module(key, module)
else:
for idx, module in enumerate(args):
self.add_module(str(idx), module)
def _get_item_by_idx(self, iterator, idx) -> T: # type: ignore[misc, type-var]
"""获取迭代器的第idx项。"""
size = len(self)
idx = operator.index(idx)
if not -size <= idx < <