Shortcuts

torch.jit.trace

torch.jit.trace(func, example_inputs=None, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_kwarg_inputs=None, _store_inputs=True)[源代码]

跟踪一个函数并返回一个可执行的或ScriptFunction,该函数将使用即时编译进行优化。

Tracing 适用于仅操作于 Tensor 以及列表、字典和包含 Tensor 的元组。

使用 torch.jit.tracetorch.jit.trace_module,你可以将现有的模块或 Python 函数转换为 TorchScript ScriptFunctionScriptModule。你必须提供示例输入,我们将运行该函数,记录对所有张量执行的操作。

  • 生成的独立函数的录制结果为 ScriptFunction

  • 生成的 nn.Module.forwardnn.Module 的录制结果是 ScriptModule

此模块还包含原始模块所具有的任何参数。

警告

仅当函数和模块不依赖于数据(例如,不对张量中的数据进行条件判断)且没有任何未跟踪的外部依赖项(例如,执行输入/输出或访问全局变量)时,跟踪才能正确记录这些函数和模块。跟踪仅记录在给定函数在给定张量上运行时执行的操作。因此,返回的ScriptModule将在任何输入上始终运行相同的跟踪图。当您的模块预期根据输入和/或模块状态运行不同的操作集时,这具有一些重要的含义。例如,

  • 跟踪不会记录任何控制流,如if语句或循环。 当这种控制流在你的模块中是恒定的,这是可以的, 并且它通常会内联控制流决策。但有时控制流实际上是模型本身的一部分。例如, 循环网络是一个循环,遍历输入序列的(可能是动态的)长度。

  • 在返回的 ScriptModule 中,在 训练评估 模式下行为不同的操作将始终表现得好像它处于跟踪期间的模式,无论 ScriptModule 处于哪种模式。

在这些情况下,追踪(tracing)并不合适,而 脚本化(scripting) 是一个更好的选择。如果你追踪这样的模型,可能会在后续调用模型时静默地得到不正确的结果。追踪器会在执行可能导致生成不正确追踪的操作时尝试发出警告。

Parameters

func (可调用对象torch.nn.Module) – 一个Python函数或torch.nn.Module 将使用example_inputs运行。func的参数和返回值必须是张量或包含张量的(可能嵌套的)元组。当传递一个模块时,torch.jit.trace只会运行并追踪forward方法(详见torch.jit.trace)。

Keyword Arguments
  • example_inputs (tupletorch.TensorNone, 可选) – 一个示例输入的元组,在跟踪函数时将传递给函数。 默认值: None。 这个参数或 example_kwarg_inputs 应该被指定。 生成的跟踪可以在不同类型和形状的输入上运行,假设跟踪的操作支持这些类型和形状。 example_inputs 也可以是单个张量,在这种情况下,它会自动包装在元组中。 当值为 None 时,example_kwarg_inputs 应该被指定。

  • check_trace (bool, 可选) – 检查相同的输入通过跟踪的代码是否产生相同的输出。默认值: True。如果你希望禁用此功能,例如,如果你的网络包含非确定性操作,或者你确定尽管检查器失败,网络仍然是正确的。

  • check_inputs (列表元组可选) – 用于检查跟踪结果与预期结果的输入参数列表。每个元组相当于一组输入参数,这些参数将在 example_inputs 中指定。为了获得最佳结果,请传入一组代表您期望网络看到的形状和类型空间的检查输入。如果未指定,则使用原始的 example_inputs 进行检查

  • check_tolerance (float, 可选) – 在检查器过程中使用的浮点数比较容差。这可以用于在结果由于已知原因(如算子融合)而数值上出现偏差时放宽检查器的严格性。

  • 严格 (bool, 可选) – 以严格模式运行跟踪器与否 (默认: True)。仅当你希望跟踪器记录你的可变容器类型(目前为list/dict) 并且你确定你在问题中使用的容器是一个常量结构且不会被用作 控制流(if, for)条件时,才关闭此选项。

  • example_kwarg_inputs (字典, 可选) – 此参数是一组示例输入的关键字参数包,将在跟踪函数时传递给函数。默认值:None。此参数或example_inputs应指定。字典将通过被跟踪函数的参数名称进行解包。如果字典的键与被跟踪函数的参数名称不匹配,将引发运行时异常。

Returns

如果 funcnn.Modulenn.Moduleforward 方法,trace 返回一个包含跟踪代码的 ScriptModule 对象,该对象具有一个 forward 方法。返回的 ScriptModule 将具有与原始 nn.Module 相同的子模块和参数集。如果 func 是一个独立函数,trace 返回 ScriptFunction

示例(跟踪函数):

import torch

def foo(x, y):
    return 2 * x + y

# 使用提供的输入运行 `foo` 并记录张量操作
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))

# `traced_foo` 现在可以使用 TorchScript 解释器运行,或者保存
# 并在没有 Python 的环境中加载

示例(跟踪现有模块):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

n = Net()
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)

# 跟踪一个特定方法并构建带有单个`forward`方法的`ScriptModule`
# 带有单个`forward`方法的`ScriptModule`
module = torch.jit.trace(n.forward, example_forward_input)

# 跟踪一个模块(隐式跟踪`forward`)并构建一个
# 带有单个`forward`方法的`ScriptModule`
module = torch.jit.trace(n, example_forward_input)
优云智算