急切模式 + 编译API¶
在本文档中,我们将介绍如何使用PyTorch/XLA的新实验性eager模式与compile API。目标是使PyTorch/XLA的体验更接近原生PyTorch,并使开发过程更加简便。
背景¶
目前 PyTorch/XLA 默认在 LazyTensor 跟踪模式下运行。在以下代码中
import torch
import torch_xla
import torchvision
device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
input = torch.randn(64, 3, 224, 224).to(device)
# model tracing
res = model(input)
# model execution, same as `xm.mark_step`
torch_xla.sync()
实际的模型编译和设备执行发生在调用torch_xla.sync时。这种方法有多个缺点。
用户常常对框架何时进行跟踪和何时执行感到困惑。
非核心模型代码(例如数据预处理)通常会生成一些小的待执行任务,这些任务会泄漏到主图(步骤函数)中并导致重新编译。整个图的重新编译通常非常昂贵。
很难调试何时/为何发生重新编译。
为了缓解上述问题,我们希望引入带有急切和编译的新用户体验。
基本用法¶
import torch
import torch_xla
import torchvision
# Run ops eagerly by default
torch_xla.experimental.eager_mode(True)
device = torch_xla.device()
model = torchvision.models.resnet18().to(device)
# Mark the function to be compiled
compiled_model = torch_xla.compile(model)
input = torch.randn(64, 3, 224, 224).to(device)
# Compilation and execution happens right away.
res = compiled_model(input)
请注意
目前用户需要手动启用eager模式,通过
torch_xla.experimental.eager_mode(True)。想要编译的代码区域应该用
torch_xla.compile包裹。
torch_xla.compile 的实现实际上非常简单,它在进入目标函数时禁用急切模式并开始跟踪。当目标函数返回时,它将调用 torch_xla.sync() 并重新启用急切模式。你可以预期使用 eager + compile API 与现有的 mark_step/sync 方法相比具有相同的性能。
推理¶
torch_xla.experimental.eager_mode(True)
compiled_model = torch.compile(model, backend="openxla")
建议使用torch.compile而不是torch_xla.compile进行推理,以减少跟踪开销。
训练¶
torch_xla.experimental.eager_mode(True)
def step_fn(model, data, target, loss_fn, optimizer):
optimizer.zero_grad()
logits = model(data)
loss = loss_fn(logits, target)
loss.backward()
optimizer.step()
return loss
step_fn = torch_xla.compile(step_fn)
在训练中,我们要求用户重构step_fn,因为通常最好将模型的前向、后向和优化器一起编译。长期目标是也使用torch.compile进行训练,但目前我们建议用户使用torch_xla.compile(出于性能原因)。
基准测试¶
我在一个v4-8芯片上使用假数据运行了一个2层仅解码器模型训练(基本上就是一个llama2),共进行了300步。以下是我观察到的数字。
| token/s | |
| Tracing mode(base line) | 147 |
| Eager mode | 65 |
| Eager + torch_xla compile | 147 |
Eager模式可以实现仅解码器模型的完全编译模型性能的约45%。我用来测试的训练器可以在这里和这里找到。请注意,Eager模式的性能非常依赖于模型。当我尝试运行resnet50时,Eager模式的性能约为编译模式的1%。我们不期望用户使用Eager模式来执行主要的训练循环。Eager模式旨在用于处理训练/推理逻辑的非核心部分(数据预处理、随机数生成等)或调试。