ModuleDict¶
- class torch.nn.ModuleDict(modules=None)[源代码]¶
在字典中保存子模块。
ModuleDict
可以像普通的 Python 字典一样进行索引, 但它包含的模块会被正确注册,并且可以通过所有Module
方法可见。ModuleDict
是一个 有序 字典,遵循插入的顺序,以及
在
update()
中,合并的顺序OrderedDict
、dict
(从 Python 3.6 开始)或另一个ModuleDict
(传递给update()
的参数)。
请注意,使用其他无序映射类型(例如,Python 3.6版本之前的普通
dict
)的update()
不会保留合并映射的顺序。- Parameters
模块(可迭代对象,可选)——一个映射(字典),键为字符串,值为模块, 或者是一个包含键值对的可迭代对象,键为字符串,值为模块
示例:
class MyModule(nn.Module): def __init__(self): super().__init__() self.choices = nn.ModuleDict({ 'conv': nn.Conv2d(10, 10, 3), 'pool': nn.MaxPool2d(3) }) self.activations = nn.ModuleDict([ ['lrelu', nn.LeakyReLU()], ['prelu', nn.PReLU()] ]) def forward(self, x, choice, act): x = self.choices[choice](x) x = self.activations[act](x) return x
- update(modules)[源代码]¶
使用映射中的键值对更新
ModuleDict
,覆盖现有键。注意
如果
modules
是一个OrderedDict
,一个ModuleDict
,或者 一个键值对的迭代器,其中新元素的顺序将被保留。