• Docs >
  • Compiling Exported Programs with Torch-TensorRT
Shortcuts

使用Torch-TensorRT编译导出的程序

Pytorch 2.1 引入了 torch.export API,可以将 Pytorch 程序中的图导出为 ExportedProgram 对象。Torch-TensorRT dynamo 前端编译这些 ExportedProgram 对象,并使用 TensorRT 对其进行优化。以下是 dynamo 前端的简单用法

import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224), dtype=torch.float32).cuda()]
exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs) # Output is a torch.fx.GraphModule
trt_gm(*inputs)

注意

torch_tensorrt.dynamo.compile 是用户与 Torch-TensorRT dynamo 前端交互的主要 API。模型的输入类型应为 ExportedProgram(理想情况下是 torch.export.exporttorch_tensorrt.dynamo.trace 的输出(在下面的部分中讨论)),输出类型是一个 torch.fx.GraphModule 对象。

可自定义设置

用户有许多选项可以自定义他们的设置以优化TensorRT。 一些常用的选项如下:

  • inputs - 对于静态形状,这可以是torch张量或torch_tensorrt.Input对象的列表。对于动态形状,这应该是torch_tensorrt.Input对象的列表。

  • enabled_precisions - TensorRT 构建器在优化期间可以使用的精度集合。

  • truncate_long_and_double - 将长整型和双精度值分别截断为整型和浮点型。

  • torch_executed_ops - 被强制由Torch执行的运算符。

  • min_block_size - 作为TensorRT段执行所需的最小连续操作符数量。

完整的选项列表可以在这里找到

注意

我们目前在Dynamo中不支持INT精度。目前在我们的Torchscript IR中支持此功能。我们计划在下一个版本中为Dynamo实现类似的支持。

内部机制

在底层,torch_tensorrt.dynamo.compile 对图执行以下操作。

  • 降低 - 应用降低传递来添加/删除操作符以实现最佳转换。

  • 分区 - 根据min_block_sizetorch_executed_ops字段将图划分为Pytorch和TensorRT段。

  • 转换 - 在此阶段,Pytorch操作被转换为TensorRT操作。

  • 优化 - 转换后,我们构建TensorRT引擎并将其嵌入到pytorch图中。

追踪

torch_tensorrt.dynamo.trace 可用于追踪 Pytorch 图并生成 ExportedProgram。 这内部会执行一些操作符的分解以进行下游优化。 然后可以将 ExportedProgramtorch_tensorrt.dynamo.compile API 一起使用。 如果你的模型中有动态输入形状,你可以使用这个 torch_tensorrt.dynamo.trace 来导出具有动态形状的模型。 或者,你也可以直接使用 torch.export with constraints

import torch
import torch_tensorrt

inputs = [torch_tensorrt.Input(min_shape=(1, 3, 224, 224),
                              opt_shape=(4, 3, 224, 224),
                              max_shape=(8, 3, 224, 224),
                              dtype=torch.float32)]
model = MyModel().eval()
exp_program = torch_tensorrt.dynamo.trace(model, inputs)