torch.jit.trace_module¶
- torch.jit.trace_module(mod, inputs, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_inputs_is_kwarg=False, _store_inputs=True)[源代码]¶
跟踪一个模块并返回一个可执行的
ScriptModule
,该模块将使用即时编译进行优化。当一个模块传递给
torch.jit.trace
时,只有forward
方法会被运行和追踪。使用trace_module
,你可以指定一个方法名称到示例输入的字典来进行追踪(参见下面的inputs
参数)。有关跟踪的更多信息,请参阅
torch.jit.trace
。- Parameters
mod (torch.nn.Module) – 一个包含方法的
torch.nn.Module
,这些方法的名称在inputs
中指定。给定的方法将被编译为单个ScriptModule的一部分。输入 (字典) – 一个字典,包含按
mod
中的方法名称索引的样本输入。 在跟踪时,输入将被传递给名称与输入键对应的方法。{ 'forward' : example_forward_input, 'method2': example_method2_input}
- Keyword Arguments
check_trace (
bool
, 可选) – 检查相同的输入通过跟踪的代码是否产生相同的输出。默认值:True
。如果你希望禁用此功能,例如,如果你的网络包含非确定性操作,或者你确定尽管检查器失败,网络仍然是正确的。check_inputs (列表 的 字典,可选) – 用于检查跟踪结果与预期结果的输入参数列表。每个元组相当于一组输入参数,这些参数将在
inputs
中指定。为了获得最佳结果,请传入一组具有代表性的检查输入,这些输入代表了您期望网络看到的形状和类型的输入空间。如果未指定,则使用原始的inputs
进行检查check_tolerance (float, 可选) – 在检查器程序中使用的浮点数比较容差。 这可以用于在已知原因(如算子融合)导致数值结果出现偏差时,放宽检查器的严格性。
example_inputs_is_kwarg (
bool
, 可选) – 此参数指示示例输入是否为关键字参数的打包。默认值:False
。
- Returns
一个包含单个
forward
方法的ScriptModule
对象,该方法包含 traced 代码。 当func
是一个torch.nn.Module
时,返回的ScriptModule
将具有与func
相同的子模块和参数集。
示例(跟踪具有多个方法的模块):
```python import torch import torch.nn as nn class Net(nn.Module): def __init__(self): super().__init__() self.conv = nn.Conv2d(1, 1, 3) def forward(self, x): return self.conv(x) def weighted_kernel_sum(self, weight): return weight * self.conv.weight n = Net() example_weight = torch.rand(1, 1, 3, 3) example_forward_input = torch.rand(1, 1, 3, 3) # 跟踪特定方法并使用单个 `forward` 方法构建 `ScriptModule` module = torch.jit.trace(n.forward, example_forward_input) # 跟踪模块(隐式跟踪 `forward`)并使用单个 `forward` 方法构建 `ScriptModule` module = torch.jit.trace(n, example_forward_input) # 跟踪模块上的特定方法(在 `inputs` 中指定),构建具有 `forward` 和 `weighted_kernel_sum` 方法的 `ScriptModule` inputs = {'forward' : example_forward_input, 'weighted_kernel_sum' : example_weight} module = torch.jit.trace_module(n, inputs) ```