保存使用Torch-TensorRT编译的模型¶
可以使用torch_tensorrt.save API保存使用Torch-TensorRT编译的模型。
Dynamo IR¶
Torch-TensorRT 的 ir=dynamo 编译的默认输出类型是 torch.fx.GraphModule 对象。 我们可以通过指定 output_format 标志将此对象保存为 TorchScript (torch.jit.ScriptModule) 或 ExportedProgram (torch.export.ExportedProgram) 格式。 以下是 output_format 将接受的选项
exported_program : 这是默认选项。我们首先对图模块进行转换,然后使用torch.export.save来保存模块。
torchscript : 我们通过torch.jit.trace追踪图模块,并通过torch.jit.save保存它。
a) 导出的程序¶
这是一个使用示例
import torch
import torch_tensorrt
model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
# trt_ep is a torch.fx.GraphModule object
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs)
# Later, you can load it and run inference
model = torch.export.load("trt.ep").module()
model(*inputs)
b) Torchscript¶
import torch
import torch_tensorrt
model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
# trt_gm is a torch.fx.GraphModule object
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs)
# Later, you can load it and run inference
model = torch.jit.load("trt.ts").cuda()
model(*inputs)
Torchscript IR¶
在 Torch-TensorRT 1.X 版本中,使用 Torch-TensorRT 编译和运行推理的主要方式是使用 Torchscript IR。 对于 ir=ts,这种行为在 2.X 版本中保持不变。
import torch
import torch_tensorrt
model = MyModel().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs) # Output is a ScriptModule object
torch.jit.save(trt_ts, "trt_model.ts")
# Later, you can load it and run inference
model = torch.jit.load("trt_model.ts").cuda()
model(*inputs)
加载模型¶
我们可以直接使用PyTorch中的torch.jit.load和torch.export.load API加载torchscript或exported_program模型。 另外,我们提供了一个轻量级的封装torch_tensorrt.load(file_path),它可以加载上述任何一种模型类型。
这是一个使用示例
import torch
import torch_tensorrt
# file_path can be trt.ep or trt.ts file obtained via saving the model (refer to the above section)
inputs = [torch.randn((1, 3, 224, 224)).cuda()]
model = torch_tensorrt.load(<file_path>).module()
model(*inputs)