修补批量归一化¶
发生了什么?¶
Batch Norm 需要对 running_mean 和 running_var 进行就地更新,其大小与输入相同。
Functorch 不支持对常规张量进行就地更新,该张量接收批量张量(即
regular.add_(batched) 是不允许的)。因此,当对单个模块的输入批次进行 vmapping 时,
我们最终会遇到这个错误
如何修复¶
最好的支持方式之一是将BatchNorm切换为GroupNorm。选项1和2支持这一点
所有这些选项都假设您不需要运行统计信息。如果您使用的是一个模块,这意味着假设您不会在评估模式下使用批量归一化。如果您有一个在评估模式下使用vmap运行批量归一化的用例,请提交问题
选项 1:更改 BatchNorm¶
如果你想将BatchNorm更改为GroupNorm,只需将所有使用BatchNorm的地方替换为:
BatchNorm2d(C, G, track_running_stats=False)
这里 C 与原始 BatchNorm 中的 C 相同。G 是要将 C 分成的组数。因此,C % G == 0,作为备用方案,您可以设置 C == G,这意味着每个通道将单独处理。
如果你必须使用BatchNorm并且你自己构建了模块,你可以将模块更改为不使用运行统计数据。换句话说,在任何有BatchNorm模块的地方,将track_running_stats标志设置为False
BatchNorm2d(64, track_running_stats=False)
选项 2:torchvision 参数¶
一些torchvision模型,如resnet和regnet,可以接受一个norm_layer参数。这些参数通常默认设置为BatchNorm2d。
相反,你可以将其设置为GroupNorm。
import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, c))
这里,再次,c % g == 0 所以作为备用,设置 g = c。
如果你依赖于BatchNorm,请确保使用一个不使用运行统计数据的版本
import torchvision
from functools import partial
torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))
选项 3:functorch 的修补¶
functorch 增加了一些功能,允许快速、就地修补模块,使其不使用运行状态。更改规范层更为脆弱,因此我们没有提供此功能。如果你有一个网络,希望 BatchNorm 不使用运行状态,你可以运行 replace_all_batch_norm_modules_ 来就地更新模块,使其不使用运行状态
from torch.func import replace_all_batch_norm_modules_
replace_all_batch_norm_modules_(net)
选项 4:评估模式¶
当在评估模式下运行时,running_mean 和 running_var 将不会被更新。因此,vmap 可以支持此模式
model.eval()
vmap(model)(x)
model.train()