PJRT 运行时¶
PyTorch/XLA 已经从基于 TensorFlow 的 XRT 运行时迁移到了 PJRT 运行时,该运行时由 JAX 使用。
如果您在使用PJRT时遇到错误,请在GitHub上提交一个问题,并使用runtime标签。
PyTorch/XLA r2.1 中的新功能:
PJRT 在 PyTorch/XLA r2.1 中是稳定的!
公共运行时API已从
torch_xla.experimental.pjrt移至torch_xla.runtime。pjrt://初始化方法已重命名为xla://,并且它由torch_xla.distributed.xla_backend注册。为了兼容性,此版本中仍然可以使用之前的
torch_xla.experimental.*名称。
torchrun现在支持使用init_method='xla://'。通过PJRT C API为XPU和Neuron提供的新插件。
PyTorch/XLA r2.0 中的新功能:
如果您没有传入任何其他运行时配置,PJRT将默认配置。如果您继续设置XRT配置(
XRT_TPU_CONFIG),此更改不会产生影响在
libtpu中新的TPU运行时实现将性能提高了多达30%。新的
xm.rendezvous实现,可扩展到数千个TPU核心[实验性]
torch.distributed支持 TPU v2 和 v3,包括pjrt://init_method
太长不看¶
要使用PJRT预览运行时,请将
PJRT_DEVICE环境变量设置为CPU、TPU或CUDA在XRT中,所有分布式工作负载都是多进程的,每个设备一个进程。在PJRT中的TPU v2和v3上,工作负载是多进程和多线程的(4个进程,每个进程2个线程),因此您的工作负载应该是线程安全的。有关更多信息,请参见TPU v2/v3上的多线程和API指南的多进程部分。需要记住的关键差异:
要以线程安全的方式初始化模型,可以在初始化后跨副本广播参数 (
torch_xla.experimental.pjrt.broadcast_master_param) 或者从公共检查点加载每个副本的参数。对于其他随机数生成,尽可能使用
torch.Generator。 全局的torch随机数生成器不是线程安全的,即使你在所有副本中设置了相同的torch.manual_seed。要使用
torch.distributed,请导入torch_xla.experimental.pjrt_backend并使用xla://init_method。这些步骤对于GPU和TPU v4是可选的。
从XRT到PJRT的示例差异:
import os
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_backend
+import torch_xla.runtime as xr
def _mp_fn(index):
device = xm.xla_device()
- dist.init_process_group('xla', rank=xr.global_ordinal(), world_size=xr.world_size())
+ dist.init_process_group('xla', init_method='xla://')
torch.manual_seed(42)
model = nn.Linear(128, 10).to(device)
+ # Optional for TPU v4 and GPU
+ xm.broadcast_master_param(model)
model = DDP(model, gradient_as_bucket_view=True)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=.001)
for i in range(10):
data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
xm.mark_step()
# Print mean parameters so we can confirm they're the same across replicas
print([p.mean() for p in model.parameters()])
if __name__ == '__main__':
- os.environ['XRT_TPU_CONFIG'] = 'localservice;0;localhost:51011'
- os.environ['MASTER_ADDR'] = 'localhost'
- os.environ['MASTER_PORT'] = '12355'
+ # Recommended: set PJRT_DEVICE to your local device type
+ os.environ['PJRT_DEVICE'] = 'TPU'
torch_xla.launch(_mp_fn)
好处¶
简单的运行时配置:只需将
PJRT_DEVICE设置为TPU、CPU或CUDA,然后开始使用XLA!或者,让PJRT根据您的环境自动选择设备。性能提升:减少gRPC的开销意味着更快的端到端执行。在TorchBench 2.0上,我们观察到在TPU v4上的训练时间提升了超过35%。
简单的pod执行:只需将您的代码复制到每个TPU工作节点,并使用
gcloud compute tpus tpuvm ssh --worker=all同时执行它们。更好的扩展性:消除了XRT对参数大小的限制,并支持多达2048个TPU芯片。
快速开始¶
要开始使用PJRT与PyTorch/XLA,您只需设置PJRT_DEVICE环境变量。如果您正在使用TPU v2或v3,请继续阅读以了解TPU v2、v3和v4之间的区别。
CPU¶
在任何安装了PyTorch/XLA的机器上,您可以像这样在CPU上运行我们的MNIST示例:
PJRT_DEVICE=CPU python3 xla/test/test_train_mp_mnist.py --fake_data
TPU¶
要创建一个安装了PyTorch/XLA r2.0的新TPU:
gcloud alpha compute tpus tpu-vm create $USER-pjrt --accelerator-type=v4-8 --version=tpu-vm-v4-pt-2.0 --zone=us-central2-b --project=$PROJECT
在 v4-8 上,您可以像这样运行我们的 ResNet50 示例:
git clone --depth=1 --branch r2.0 https://github.com/pytorch/xla.git
PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
默认情况下,PJRT 将使用所有 TPU 芯片。要仅使用一个 TPU 芯片,请配置
TPU_PROCESS_BOUNDS 和 TPU_VISIBLE_CHIPS:
TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_CHIPS=0 PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1
Pods¶
在TPU Pods上,使用gcloud并行在每个TPU上运行您的命令:
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="git clone --depth=1 --branch r1.13 https://github.com/pytorch/xla.git"
gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=256 --num_epochs=1"
Docker¶
你也可以使用 Docker 在预装了 PyTorch/XLA 的容器中运行你的工作负载:
export DOCKER_IMAGE=gcr.io/...
# Optional: authenticate docker if your image is in a private GCP repository
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo gcloud auth configure-docker"
# Run your workload
gcloud compute tpus tpu-vm ssh $USER-pjrt --zone=us-central2-b --project=$PROJECT --worker=all --command "sudo docker run --rm --privileged --net=host -e PJRT_DEVICE=TPU $DOCKER_IMAGE python pytorch/xla/test/test_train_mp_imagenet.py --fake_data"
请注意,docker run 需要主机的特权访问权限(--privileged)才能将TPU设备暴露给容器。目前,TPU pods上的Docker仅支持主机网络--net=host。有关更多信息,请参阅Cloud TPU文档。
GPU¶
单节点GPU训练¶
要使用GPU与PJRT,只需设置PJRT_DEVICE=CUDA并配置GPU_NUM_DEVICES为主机上的设备数量。例如:
PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1
你也可以使用 torchrun 来启动单节点多GPU训练。例如,
PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
在上面的例子中,--nnodes 表示要使用多少台机器(物理机或虚拟机)(因为我们进行的是单节点训练,所以是1)。--nproc-per-node 表示要使用多少个GPU设备。
多节点GPU训练¶
请注意,此功能仅适用于cuda 12+。类似于PyTorch使用多节点训练的方式,您可以运行以下命令:
PJRT_DEVICE=CUDA torchrun \
--nnodes=${NUMBER_GPU_VM} \
--node_rank=${CURRENT_NODE_RANK} \
--nproc_per_node=${NUMBER_LOCAL_GPU_DEVICES} \
--rdzv_endpoint=<internal_ip_address:port> multinode_training.py
--nnodes: 要使用多少台GPU机器。--node_rank: 当前GPU机器的索引。值可以是0, 1, …, ${NUMBER_GPU_VM}-1。--nproc_per_node: 当前机器上要使用的GPU设备数量。–rdzv_endpoint: 这是具有 node_rank==0 的 GPU 机器的端点,格式为 host:port`。``host
将是 内部 IP 地址。 端口port` 可以是机器上任何可用的端口。对于单节点训练/推理,可以省略此参数。
例如,如果你想在2台GPU机器上训练:machine_0和machine_1,在第一台GPU机器machine_0上,运行
# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=0 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
在第二台GPU机器上,运行
# PJRT_DEVICE=CUDA torchrun \
--nnodes=2 \
--node_rank=1 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
上述两个命令的区别在于--node_rank和可能的--nproc_per_node,如果你想在每台机器上使用不同数量的GPU设备。其余部分完全相同。有关torchrun的更多信息,请参阅此页面。
与XRT的差异¶
虽然在大多数情况下,我们期望PJRT和XRT从最终用户的角度来看能够基本互换(尤其是在TPU v4上),但有一些细微的差异需要牢记。重要的是,XRT是围绕TPU节点架构设计的,因此即使在TPU虚拟机上,它也会始终生成一个客户端和一个服务器进程。因此,每批输入都会因为将数据序列化和反序列化以通过网络发送而产生额外的延迟。
PJRT 直接使用本地设备,没有中间服务器进程。在默认配置中,PJRT 将为每个 TPU 芯片创建一个进程,或为每个 TPU 主机创建 4 个进程。有关 TPU 架构的更多信息,请参阅 Cloud TPU 文档。
对于受限于开销的工作负载,性能提升是可能的。
在XRT下,服务器进程是唯一与TPU设备交互的进程,客户端进程无法直接访问TPU设备。在分析单主机TPU(例如v3-8或v4-8)时,通常会看到8个设备跟踪(每个TPU核心一个)。使用PJRT时,每个进程有一个芯片,该进程的配置文件将仅显示2个TPU核心。
出于同样的原因,分析在带有XRT的TPU Pods上不起作用,因为服务器进程独立于用户的模型代码运行。PJRT没有这种限制,因此可以在TPU Pod中每个进程分析2个TPU核心。
PJRT 仅支持 TPU VM 架构,我们没有计划支持 TPU Node 架构与 PJRT。
使用PJRT,运行时配置显著简化。
xla_dist不需要运行TPU Pod工作负载。相反,将您的代码复制到每个TPU主机 ([gcloud compute tpus tpu-vm scp](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/scp)) 并在每个主机上并行运行代码(例如[gcloud compute tpus tpu-vm ssh --workers=all --command="PJRT_DEVICE=TPU python run.py"](https://cloud.google.com/sdk/gcloud/reference/alpha/compute/tpus/tpu-vm/ssh))xm.rendezvous已使用 XLA 原生的集体通信重新实现,以提高大型 TPU pod 的稳定性。更多详情请见下文。
TPU v2/v3上的多线程¶
在TPU v2和v3上,分布式工作负载始终以多线程方式运行,因为每个TPU核心会暴露两个TPU核心作为设备,并且一次只能有一个进程打开一个TPU芯片。在其默认配置中,xmp.spawn会自动生成尽可能多的进程(每个TPU主机4个),并为每个进程创建两个线程(每个TPU核心一个)。
注意:在TPU v4上,每个TPU芯片表示为一个PyTorch设备,因此分布式工作负载将在4个进程中运行,每个进程只有一个线程。这与XRT的行为相同。
在大多数情况下,这不会对您现有的代码进行重大更改。
在大多数情况下,您需要做的主要更改是模型初始化。
因为torch的全局随机数生成器在线程之间共享,所以即使您在每个副本中将torch.manual_seed设置为相同的值,结果也会因线程和运行而异。为了在副本之间获得一致的参数,可以使用torch_xla.experimental.pjrt.broadcast_master_param将一个副本的参数广播到所有其他副本,或者从公共检查点加载每个副本的参数。
xm.rendezvous的更改¶
PyTorch/XLA r2.0 中的新功能
使用XRT,工作节点0运行一个网格主服务,所有工作节点上的所有进程都通过gRPC连接到该服务。实际上,我们发现,在拥有数千个芯片的TPU集群上运行单个网格主进程是不可靠的,因为工作节点0的入站连接数量过多。单个客户端进程超时可能导致失败,并迫使整个工作负载重新启动。
因此,我们使用原生的XLA集体通信重新实现了xm.rendezvous,这在大型TPU集群上更加稳定且经过充分测试。与XRT实现相比,这带来了两个新的限制:
因为有效载荷必须成为XLA图的一部分,所以在数据传输前后都会调用
xm.mark_step。在模型代码中间调用xm.rendezvous可能会强制进行不必要的编译。因为XLA不允许在部分工作者上运行集体操作,所有工作者都必须参与
rendezvous。
如果你需要xm.rendezvous的旧行为(即在不改变XLA图和/或同步一部分工作节点的情况下进行数据通信),
考虑使用
``torch.distributed.barrier` <https://pytorch.org/docs/stable/distributed.html#torch.distributed.barrier>`_
或
``torch.distributed.all_gather_object` <https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather_object>`_
与gloo进程组一起使用。如果你也在使用xla torch.distributed
后端,你可以使用torch.new_group来创建一个gloo子组。参见PyTorch文档中的这个
例子。请记住这些限制:
torch.distributed在 TPU v2/v3 上并未完全支持。只有部分使用xla后端实现的操作可用,并且在多线程环境中,gloo可能无法按预期工作。在我们的实验中,
gloo无法很好地扩展到数千个TPU芯片,因此预计这种替代方案在大规模使用PJRT时不如使用xm.rendezvous可靠。
PJRT 和 torch.distributed¶
PyTorch/XLA r2.0 中的新功能
当使用PJRT与torch.distributed和
[torch.nn.parallel.DistributedDataParallel](https://github.com/pytorch/xla/blob/master/docs/ddp.md)
时,我们强烈推荐使用新的xla:// init_method,它通过查询运行时自动
找到副本ID、世界大小和主IP。例如:
import torch
import torch_xla
import torch.distributed as dist
import torch_xla.core.xla_model as xm
from torch_xla.experimental import pjrt
# Required for `xla://` init_method and `xla` backend
import torch_xla.distributed.xla_backend
def _all_gather(index: int):
# No need to pass in `rank` or `world_size`
dist.init_process_group('xla', init_method='xla://')
t = torch.tensor([index], dtype=torch.int32, device=xm.xla_device())
output = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
dist.all_gather(output, t)
xm.mark_step()
print(output)
if __name__ == '__main__':
torch_xla.launch(_all_gather)
注意:虽然在TPU v4上不需要xla://初始化方法,但仍然建议使用。如果使用env://,则必须将MASTER_ADDR设置为具有设备0的IP主机,这不总是工作节点0。xla://初始化方法会自动找到此IP。
注意:对于TPU v2/v3,您仍然需要导入
torch_xla.experimental.pjrt_backend,因为在
torch.distributed中对TPU v2/v3的支持仍然是实验性的。
有关在 PyTorch/XLA 上使用 DistributedDataParallel 的更多信息,请参阅
``ddp.md` <./ddp.md>`_ 在 TPU V4 上的内容。要查看一个结合使用 DDP 和 PJRT 的示例,
请在 TPU 上运行以下 示例脚本:
PJRT_DEVICE=TPU python xla/test/test_train_mp_mnist.py --ddp --pjrt_distributed --fake_data --num_epochs 1
性能¶
TorchBench 显示,与 XRT 相比,使用 PJRT 在各个任务中的平均训练时间有所改善,在 TPU v4-8 上平均提高了超过 35%。不同任务和模型类型的改善幅度差异显著,范围从 0% 到 175%。下图显示了按任务细分的改善情况:
新的TPU运行时¶
PyTorch/XLA r2.0 中的新功能
PyTorch/XLA r2.0 版本引入了对 PJRT 插件 API 的支持,用于访问 libtpu 中基于 TFRT 的新 TPU 运行时。当设置 PJRT_DEVICE=TPU 时,这现在是默认的运行时。1.13 版本中使用的基于 StreamExecutor 的旧版 TPU 运行时在 2.0 版本中仍然可以通过 PJRT_DEVICE=TPU_LEGACY 使用,但它将在未来的版本中被移除。如果您遇到仅在 TPU 上发生而不在 TPU_LEGACY 上发生的问题,请在 GitHub 上提交问题。
在大多数情况下,我们预计两种运行时的性能相似,但在某些情况下,新运行时的速度可能快达30%。下图显示了按任务的细分:
注意:此图表中显示的改进也包含在PJRT与XRT的比较中。