Shortcuts

torch.jit.freeze

torch.jit.freeze(mod, preserved_attrs=None, optimize_numerics=True)[源代码]

冻结 ScriptModule、内联子模块和属性为常量。

冻结一个 ScriptModule 将会克隆它并尝试将克隆模块的子模块、参数和属性内联为 TorchScript IR 图中的常量。默认情况下,forward 将被保留,以及在 preserved_attrs 中指定的属性和方法。此外,任何在保留方法中修改的属性也将被保留。

冻结功能目前仅接受处于评估模式下的ScriptModules。

冻结应用了通用的优化,这将加速您的模型,无论是在哪种机器上。 为了进一步使用服务器特定的设置进行优化,请在冻结后运行optimize_for_inference

Parameters
  • mod (ScriptModule) – 要冻结的模块

  • preserved_attrs (可选[列表[str]]) – 除了forward方法外,需要保留的属性列表。 在保留方法中修改的属性也将被保留。

  • optimize_numerics (bool) – 如果True,将运行一组不严格保留数值的优化过程。优化细节可以在torch.jit.run_frozen_optimizations中找到。

Returns

冻结的 ScriptModule

示例(冻结一个带有参数的简单模块):

    def forward(self, input):
        output = self.weight.mm(input)
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3).eval())
frozen_module = torch.jit.freeze(scripted_module)
# 参数已被移除并作为常量内联到图中
assert len(list(frozen_module.named_parameters())) == 0
# 查看编译后的图作为Python代码
print(frozen_module.code)

示例(冻结模块并保留属性)

    def forward(self, input):
        self.modified_tensor += 1
        return input + self.modified_tensor

scripted_module = torch.jit.script(MyModule2().eval())
frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
# 我们已经手动保留了 `version`,因此它仍然存在于冻结模块中并且可以被修改
assert frozen_module.version == 1
frozen_module.version = 2
# `modified_tensor` 在 forward 中被检测为被修改,因此冻结会保留它以保持模型语义
# 它以保持模型语义
assert frozen_module(torch.tensor(1)) == torch.tensor(12)
# 现在我们已经运行了一次,下一次的结果将增加一
assert frozen_module(torch.tensor(1)) == torch.tensor(13)

注意

冻结子模块属性也是支持的: frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=[“submodule.version”])

注意

如果你不确定为什么某个属性没有被内联为常量,你可以运行dump_alias_db在frozen_module.forward.graph上,看看冻结是否检测到该属性正在被修改。

注意

因为冻结使权重成为常量并移除了模块层次结构,to 和其他 nn.Module 方法来操作设备或数据类型不再起作用。作为一种解决方法,您可以通过在 torch.jit.load 中指定 map_location 来重新映射设备,但是特定于设备的逻辑可能已经嵌入到模型中。

优云智算