• Docs >
  • Using Torch-TensorRT TorchScript Frontend Directly From PyTorch
Shortcuts

直接从PyTorch使用Torch-TensorRT TorchScript前端

您现在可以直接从PyTorch API访问TensorRT。使用此功能的过程与在Python中使用Torch-TensorRT中描述的编译工作流非常相似。

首先将torch_tensorrt加载到您的应用程序中。

import torch
import torch_tensorrt

然后给定一个TorchScript模块,你可以使用torch._C._jit_to_backend("tensorrt", ...) API将其编译为TensorRT。

import torchvision.models as models

model = models.mobilenet_v2(pretrained=True)
script_model = torch.jit.script(model)

与Torch-TensorRT中的compile API不同,它假设您正在尝试编译模块的forward函数,或者convert_method_to_trt_engine将指定函数转换为TensorRT引擎,后端API将接受一个字典,该字典将函数名称映射到编译规范对象,这些对象包装了您将提供给compile的相同类型的字典。有关编译规范字典的更多信息,请查看Torch-TensorRT的TensorRTCompileSpec API文档。

spec = {
    "forward": torch_tensorrt.ts.TensorRTCompileSpec(
        **{
            "inputs": [torch_tensorrt.Input([1, 3, 300, 300])],
            "enabled_precisions": {torch.float, torch.half},
            "refit": False,
            "debug": False,
            "device": {
                "device_type": torch_tensorrt.DeviceType.GPU,
                "gpu_id": 0,
                "dla_core": 0,
                "allow_gpu_fallback": True,
            },
            "capability": torch_tensorrt.EngineCapability.default,
            "num_avg_timing_iters": 1,
        }
    )
}

现在要使用Torch-TensorRT进行编译,请将目标模块对象和规范字典提供给torch._C._jit_to_backend("tensorrt", ...)

trt_model = torch._C._jit_to_backend("tensorrt", script_model, spec)

要显式调用您想要运行的方法的函数(与在标准PyTorch中可以直接调用模块本身的方式不同)

input = torch.randn((1, 3, 300, 300)).to("cuda").to(torch.half)
print(trt_model.forward(input))