PyTorch XLA 中的 TorchDynamo(torch.compile) 集成¶
TorchDynamo 是一个 Python 级别的 JIT 编译器,旨在加速未经修改的 PyTorch 程序。它为编译器后端提供了一个干净的 API 来接入,其最大的特点是在 Python 字节码执行之前动态修改它。在 pytorch/xla 2.0 版本中,PyTorch/XLA 为 TorchDynamo 提供了一个实验性的后端,支持推理和训练。
XLA桥接器的工作方式是,当Dynamo识别出模型模式时,它将提供一个TorchFX图,而PyTorch/XLA将使用现有的Lazy Tensor技术来编译FX图并返回编译后的函数。
集成¶
目前支持PyTorch/XLA和Dynamo的方法是通过在torch.compile中添加backend='openxla'参数。例如:
import torch
import torch_xla.core.xla_model as xm
def add(a, b):
a_xla = a.to(xm.xla_device())
b_xla = b.to(xm.xla_device())
return a_xla + b_xla
compiled_code = torch.compile(add, backend='openxla')
print(compiled_code(torch.randn(10), torch.randn(10)))
推理¶
这里是一个使用torch.compile运行resnet18的小代码示例
import torch
import torchvision
import torch_xla.core.xla_model as xm
def eval_model(loader):
device = xm.xla_device()
xla_resnet18 = torchvision.models.resnet18().to(device)
xla_resnet18.eval()
dynamo_resnet18 = torch.compile(
xla_resnet18, backend='openxla')
for data, _ in loader:
with torch.no_grad():
output = dynamo_resnet18(data)
使用torch.compile,您将看到PyTorch/XLA仅在初始化时对resent18模型进行一次跟踪,并在每次调用dynamo_resnet18时执行编译后的二进制文件,而不是每次都跟踪模型。以下是在Cloud TPU v4-8上使用torch bench比较Dynamo和Lazy的推理速度分析。
resnet18 | 2.59 resnet50 | 2.64 resnext50_32x4d | 1.91 alexnet | 1.28 mobilenet_v2 | 18.62 mnasnet1_0 | 2.68 vgg16 | 1.33 BERT_pytorch | 7.49 squeezenet1_1 | 2.29 timm_vision_transformer | 3.52 几何平均 | 3.04
训练¶
PyTorch/XLA 也支持使用 Dynamo 进行训练,但这仍处于实验阶段,我们正在与 PyTorch 编译器团队合作,不断迭代实现。以下是一个使用 torch.compile 训练 resnet18 的示例。
import torch
import torchvision
import torch_xla.core.xla_model as xm
def train_model(model, data, target, optimizer):
loss_fn = torch.nn.CrossEntropyLoss()
pred = model(data)
loss = loss_fn(pred, target)
loss.backward()
optimizer.step()
return pred
def train_model_main(loader):
device = xm.xla_device()
xla_resnet18 = torchvision.models.resnet18().to(device)
xla_resnet18.train()
dynamo_train_model = torch.compile(
train_model, backend='openxla')
for data, target in loader:
xla_optimizer = optim.SGD(data, lr=0.1, weight_decay=1e-2)
output = dynamo_train_model(xla_resnet18, data, target, xla_optimizer)
我们预计每个训练步骤将提取并执行3个图,而不是使用Lazy tensor时的每个训练步骤1个图。以下是在Cloud TPU v4-8上使用torch bench比较Dynamo和Lazy的训练速度分析。
resnet50 | 1.33 resnet18 | 1.33 BERT_pytorch | 3.07 resnext50_32x4d | 1.43 alexnet | 1.12 mobilenet_v2 | 1.4 mnasnet1_0 | 1.19 vgg16 | 0.81 timm_vision_transformer | 1.87 squeezenet1_1 | 1.41 几何平均 | 1.41
注意: 我们运行每个模型的前向传播(fwd)和反向传播(bwd)一步,然后收集端到端(e2e)时间。在现实世界中,我们会在每次训练任务中运行多个步骤,这可以轻松隐藏执行中的追踪成本(因为它是异步的)。在这种情况下,Lazy Tensor 的性能会更好。
功能差距¶
有一个我们想要指出的差距,这阻碍了我们在更大规模的模型上使用TorchDynamo。
TorchDynamo 会将前向和后向追踪为单独的图。对于 PyTorch/XLA 来说,让 XLA 编译器将整个步骤视为一个图以最佳优化速度非常重要。每次设备执行都有一个固定的开销,这使得每个训练步骤执行多个图不太理想。
与Lazy Tensor相比,这种差距使得它在实际训练用例中效率较低,尤其是在训练过程中,追踪成本可能与执行重叠。
要点¶
TorchDynamo 为编译器后端提供了一种非常有前景的方式,可以隐藏用户的复杂性,并轻松以图形格式检索建模代码。与 PyTorch/XLA 传统的 Lazy Tensor 提取图形的方式相比,TorchDynamo 可以跳过每次迭代的图形跟踪,从而提供更好的推理响应时间。
大多数由PyTorch/XLA支持的模型,在使用新的dynamo-xla桥进行推理时,已经看到了显著的加速。我们的社区正在努力扩大支持的模型集。关于上述提到的训练功能差距,PyTorch/XLA社区非常兴奋地计划在即将到来的开发工作中改进训练差距。团队继续大力投资于TorchDynamo,并与上游合作,以完善训练故事。