修补批量归一化¶
发生了什么?¶
Batch Norm 需要对与输入大小相同的 running_mean 和 running_var 进行就地更新。
Functorch 不支持对接收批量张量的常规张量进行就地更新(即
regular.add_(batched) 是不允许的)。因此,当对一批输入进行 vmap 操作以应用于单个模块时,
我们会遇到这个错误
如何修复¶
所有这些选项都假设您不需要运行统计数据。如果您正在使用一个模块,这意味着假设您不会在评估模式下使用批量归一化。如果您有一个用例涉及在评估模式下使用vmap运行批量归一化,请提交一个问题
选项1:更改BatchNorm¶
如果你自己构建了模块,你可以更改模块以不使用运行统计信息。换句话说,在任何有BatchNorm模块的地方,将track_running_stats标志设置为False
BatchNorm2d(64, track_running_stats=False)
选项2:torchvision参数¶
一些torchvision模型,如resnet和regnet,可以接受一个norm_layer参数。这些参数通常默认设置为BatchNorm2d。相反,你可以将其设置为不使用运行统计数据的BatchNorm。
import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))
选项3:functorch的修补¶
functorch 添加了一些功能,允许快速、原地修补模块。如果你有一个想要更改的网络,你可以运行 replace_all_batch_norm_modules_ 来原地更新模块,使其不使用运行统计信息。
from functorch.experimental import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)