Shortcuts

在XLA设备上使用PyTorch

PyTorch 在 XLA 设备(如 TPU)上运行,使用 torch_xla 包。本文档描述了 如何在这些设备上运行您的模型。

创建一个XLA张量

PyTorch/XLA 向 PyTorch 添加了一个新的 xla 设备类型。这种设备类型的工作方式与其他 PyTorch 设备类型相同。例如,以下是创建和打印 XLA 张量的方法:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)

这段代码应该看起来很熟悉。PyTorch/XLA 使用与常规 PyTorch 相同的接口,并添加了一些额外的功能。导入 torch_xla 会初始化 PyTorch/XLA,而 xm.xla_device() 会返回当前的 XLA 设备。根据你的环境,这可能是 CPU 或 TPU。

XLA 张量是 PyTorch 张量

PyTorch 操作可以在 XLA 张量上执行,就像在 CPU 或 CUDA 张量上一样。

例如,XLA张量可以相加:

t0 = torch.randn(2, 2, device=xm.xla_device())
t1 = torch.randn(2, 2, device=xm.xla_device())
print(t0 + t1)

或者矩阵相乘:

print(t0.mm(t1))

或与神经网络模块一起使用:

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20).to(xm.xla_device())
l_out = linear(l_in)
print(l_out)

与其他设备类型一样,XLA张量只能与同一设备上的其他XLA张量一起工作。因此,像这样的代码

l_in = torch.randn(10, device=xm.xla_device())
linear = torch.nn.Linear(10, 20)
l_out = linear(l_in)
print(l_out)
# Input tensor is not an XLA tensor: torch.FloatTensor

会抛出错误,因为torch.nn.Linear模块在CPU上。

在XLA设备上运行模型

构建一个新的PyTorch网络或将现有的网络转换为在XLA设备上运行只需要几行XLA特定的代码。以下代码片段突出显示了在单个设备和多个设备上使用XLA多进程运行时这些代码行。

在单个XLA设备上运行

以下代码片段展示了在单个XLA设备上进行网络训练的过程:

import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

for data, target in train_loader:
  optimizer.zero_grad()
  data = data.to(device)
  target = target.to(device)
  output = model(data)
  loss = loss_fn(output, target)
  loss.backward()

  optimizer.step()
  xm.mark_step()

这段代码展示了将模型切换到XLA上运行是多么容易。模型定义、数据加载器、优化器和训练循环可以在任何设备上工作。唯一的XLA特定代码是几行获取XLA设备并标记步骤的代码。在每个训练迭代结束时调用xm.mark_step()会导致XLA执行其当前图并更新模型的参数。有关XLA如何创建图和运行操作的更多信息,请参见XLA Tensor Deep Dive

在多XLA设备上使用多进程运行

PyTorch/XLA 通过在多台 XLA 设备上运行,使得加速训练变得容易。以下代码片段展示了如何操作:

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def _mp_fn(index):
  device = xm.xla_device()
  mp_device_loader = pl.MpDeviceLoader(train_loader, device)

  model = MNIST().train().to(device)
  loss_fn = nn.NLLLoss()
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

  for data, target in mp_device_loader:
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)

if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())

这个多设备代码片段与前一个单设备代码片段有三个不同之处。让我们逐一来看。

  • torch_xla.launch()

    • 创建每个运行XLA设备的进程。

    • 此函数是多线程生成的包装器,允许用户也使用torchrun命令行运行脚本。每个进程只能访问分配给当前进程的设备。例如,在TPU v4-8上,将生成4个进程,每个进程将拥有一个TPU设备。

    • 请注意,如果您在每个进程上打印xm.xla_device(),您将在所有设备上看到xla:0。这是因为每个进程只能看到一个设备。这并不意味着多进程没有起作用。唯一的执行是在TPU v2和TPU v3上使用PJRT运行时,因为会有#devices/2个进程,每个进程将有2个线程(查看此文档以获取更多详细信息)。

  • MpDeviceLoader

    • 将训练数据加载到每个设备上。

    • MpDeviceLoader 可以包装在 torch 数据加载器上。它可以预加载数据到设备,并将数据加载与设备执行重叠,以提高性能。

    • MpDeviceLoader 也会为你每 batches_per_execution(默认为1)批次生成时调用 xm.mark_step

  • xm.optimizer_step(optimizer)

    • 整合设备之间的梯度并发出XLA设备步骤计算。

    • 它基本上是一个all_reduce_gradients + optimizer.step() + mark_step,并返回被减少的损失。

模型定义、优化器定义和训练循环保持不变。

注意: 需要注意的是,在使用多进程时,用户只能从torch_xla.launch()的目标函数(或调用堆栈中任何以torch_xla.launch()为父函数的函数)内部开始检索和访问XLA设备。

查看 完整的多进程示例 以了解更多关于使用多进程在多个XLA设备上训练网络的信息。

在TPU Pods上运行

不同加速器的多主机设置可能非常不同。本文档将讨论多主机训练中与设备无关的部分,并以TPU + PJRT运行时(目前在1.13和2.x版本中可用)为例。

在开始之前,请查看我们的用户指南这里,它将解释一些Google Cloud的基础知识,如如何使用gcloud命令以及如何设置您的项目。您还可以查看这里以获取所有Cloud TPU的操作指南。本文档将重点介绍设置的PyTorch/XLA视角。

假设您在上面的部分中有一个train_mnist_xla.py的mnist示例。如果是单主机多设备训练,您将通过ssh连接到TPUVM并运行如下命令

PJRT_DEVICE=TPU python3 train_mnist_xla.py

现在为了在TPU v4-16(有2个主机,每个主机有4个TPU设备)上运行相同的模型,你需要

  • 确保每个主机都能访问训练脚本和训练数据。这通常通过使用gcloud scp命令或gcloud ssh命令将训练脚本复制到所有主机来完成。

  • 同时在所有主机上运行相同的训练命令。

gcloud alpha compute tpus tpu-vm ssh $USER-pjrt --zone=$ZONE --project=$PROJECT --worker=all --command="PJRT_DEVICE=TPU python3 train_mnist_xla.py"

上面的 gcloud ssh 命令将同时 SSH 到 TPUVM Pod 中的所有主机并同时运行相同的命令。

注意: 你需要在TPUVM虚拟机之外运行上述gcloud命令。

多进程训练和多主机训练的模型代码和训练脚本是相同的。PyTorch/XLA 和底层基础设施将确保每个设备都了解全局拓扑以及每个设备的本地和全局序号。跨设备通信将在所有设备之间进行,而不仅仅是本地设备。

有关PJRT运行时及其在pod上运行的更多详细信息,请参阅此文档。有关PyTorch/XLA和TPU pod的更多信息以及使用假数据在TPU pod上运行resnet50的完整指南,请参阅此指南

XLA 张量深度解析

使用XLA张量和设备只需要更改几行代码。但即使XLA张量的行为与CPU和CUDA张量非常相似,它们的内部结构是不同的。本节描述了XLA张量的独特之处。

XLA 张量是惰性的

CPU和CUDA张量会立即或急切地启动操作。而XLA张量则是惰性的。它们会将操作记录在图中,直到需要结果为止。像这样延迟执行可以让XLA对其进行优化。例如,多个独立操作的图可能会融合成一个单一的优化操作。

懒执行通常对调用者是不可见的。PyTorch/XLA 自动构建图,将它们发送到 XLA 设备,并在 XLA 设备和 CPU 之间复制数据时进行同步。在采取优化器步骤时插入屏障会显式同步 CPU 和 XLA 设备。有关我们懒张量设计的更多信息,您可以阅读 这篇论文

内存布局

XLA张量的内部数据表示对用户是不透明的。它们不暴露其存储,并且总是看起来是连续的,与CPU和CUDA张量不同。这使得XLA能够调整张量的内存布局以获得更好的性能。

将XLA张量移动到CPU和从CPU移出

XLA 张量可以从 CPU 移动到 XLA 设备,也可以从 XLA 设备移动到 CPU。如果移动了一个视图,那么它所查看的数据也会被复制到另一个设备上,并且视图关系不会被保留。换句话说,一旦数据被复制到另一个设备,它就与之前的设备或该设备上的任何张量没有关系了。再次强调,根据代码的操作方式,理解和适应这种转换可能是重要的。

保存和加载XLA张量

XLA 张量在保存之前应移动到 CPU,如下面的代码片段所示:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)

tensors = (t0.cpu(), t1.cpu())

torch.save(tensors, 'tensors.pt')

tensors = torch.load('tensors.pt')

t0 = tensors[0].to(device)
t1 = tensors[1].to(device)

这允许你将加载的张量放在任何可用的设备上,而不仅仅是它们初始化的设备。

根据上述关于将XLA张量移动到CPU的说明,在处理视图时必须小心。建议不要保存视图,而是在张量加载并移动到目标设备后重新创建它们。

提供了一个实用API,通过处理之前将其移动到CPU来保存数据:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

xm.save(model.state_dict(), path)

在多个设备的情况下,上述API将仅保存主设备序号(0)的数据。

在内存有限而模型参数较大的情况下,提供了一个API来减少主机上的内存占用:

import torch_xla.utils.serialization as xser

xser.save(model.state_dict(), path)

此API一次将XLA张流传输到CPU,减少了使用的主机内存量,但需要匹配的加载API来恢复:

import torch_xla.utils.serialization as xser

state_dict = xser.load(path)
model.load_state_dict(state_dict)

直接保存XLA张量是可能的,但不推荐。XLA张量总是被加载回它们被保存的设备,如果该设备不可用,加载将失败。PyTorch/XLA,像所有PyTorch一样,正在积极开发中,这种行为将来可能会改变。

编译缓存

XLA编译器将跟踪的HLO转换为可在设备上运行的可执行文件。编译可能耗时,并且在HLO在执行之间不发生变化的情况下,编译结果可以持久化到磁盘以供重用,从而显著减少开发迭代时间。

请注意,如果在执行之间HLO发生变化,仍然会发生重新编译。

这目前是一个实验性的选择加入API,必须在执行任何计算之前激活。初始化是通过initialize_cache API完成的:

import torch_xla.runtime as xr
xr.initialize_cache('YOUR_CACHE_PATH', readonly=False)

这将在指定路径初始化一个持久编译缓存。readonly 参数可用于控制工作进程是否能够写入缓存,这在为SPMD工作负载使用共享缓存挂载时非常有用。

如果你想在多进程训练中使用持久编译缓存(使用torch_xla.launchxmp.spawn),你应该为不同的进程使用不同的路径。

def _mp_fn(index):
  # cache init needs to happens inside the mp_fn.
  xr.initialize_cache(f'/tmp/xla_cache_{index}', readonly=False)
  ....

if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())

如果您没有访问index的权限,您可以使用xr.global_ordinal()。查看可运行的示例这里

进一步阅读

更多文档可在 PyTorch/XLA 仓库中找到。更多在 TPU 上运行网络的示例可在 这里找到。

PyTorch/XLA API

torch_xla

torch_xla.device(index: Optional[int] = None) device[source]

返回一个XLA设备的给定实例。

如果启用了SPMD,则返回一个虚拟设备,该设备封装了此进程可用的所有设备。

Parameters

index – 要返回的XLA设备的索引。对应于torch_xla.devices()中的索引。

Returns

一个XLA torch.device

torch_xla.devices() List[device][source]

返回当前进程中所有可用的设备。

Returns

XLA torch.devices 的列表。

torch_xla.device_count() int[source]

返回当前进程中可寻址设备的数量。

torch_xla.sync(wait: bool = False)[source]

启动所有挂起的图操作。

Parameters

wait (bool) – 是否阻塞当前进程直到执行完成。

torch_xla.compile(f: Optional[Callable] = None, full_graph: Optional[bool] = False, name: Optional[str] = None, num_different_graphs_allowed: Optional[int] = None)[source]

使用torch_xla的LazyTensor跟踪模式优化给定的模型/函数。 PyTorch/XLA将使用给定的输入跟踪给定的函数,然后生成 图表来表示在此函数内发生的pytorch操作。这个 图表将由XLA编译并在加速器上执行(由 张量的设备决定)。对于函数的编译区域,将禁用急切模式。

Parameters
  • model (Callable) – 要优化的模块/函数,如果未传递此函数,它将作为上下文管理器。

  • full_graph (Optional[bool]) – 是否应生成单个图。如果设置为True 并且将生成多个图,torch_xla将抛出带有调试信息的错误 并退出。

  • name (可选[name]) – 编译程序的名称。如果未指定,将使用函数 f 的名称。此名称将用于 PT_XLA_DEBUG 消息以及 HLO/IR 转储文件中。

  • num_different_graphs_allowed (可选[python:int]) – 允许的给定模型/函数的不同跟踪图的数量。如果超过此限制,将引发错误。

示例:

# usage 1
@torch_xla.compile()
def foo(x):
  return torch.sin(x) + torch.cos(x)

def foo2(x):
  return torch.sin(x) + torch.cos(x)
# usage 2
compiled_foo2 = torch_xla.compile(foo2)

# usage 3
with torch_xla.compile():
  res = foo2(x)
torch_xla.manual_seed(seed, device=None)[source]

为当前XLA设备设置生成随机数的种子。

Parameters
  • seed (python:integer) – 要设置的状态。

  • device (torch.device, optional) – 需要设置RNG状态的设备。 如果缺失,将设置默认设备种子。

运行时

torch_xla.runtime.device_type() Optional[str][source]

返回当前的PjRt设备类型。

如果没有配置设备,则选择默认设备

Returns

设备的字符串表示。

torch_xla.runtime.local_process_count() int[source]

返回在此主机上运行的进程数。

torch_xla.runtime.local_device_count() int[source]

返回此主机上的设备总数。

假设每个进程具有相同数量的可寻址设备。

torch_xla.runtime.addressable_device_count() int[source]

返回此进程可见的设备数量。

torch_xla.runtime.global_device_count() int[source]

返回所有进程/主机上的设备总数。

torch_xla.runtime.global_runtime_device_count() int[source]

返回所有进程/主机上的运行时设备总数,特别适用于SPMD。

torch_xla.runtime.world_size() int[source]

返回参与作业的进程总数。

torch_xla.runtime.global_ordinal() int[source]

返回此线程在所有进程中的全局序号。

全局序号在范围 [0, global_device_count) 内。全局序号不保证与TPU工作器ID有任何可预测的关系,也不保证在每个主机上是连续的。

torch_xla.runtime.local_ordinal() int[source]

返回此主机内此线程的本地序号。

本地序号在范围 [0, local_device_count) 内。

torch_xla.runtime.get_master_ip() str[source]

检索运行时的主工作节点IP。这将调用特定于后端的发现API。

Returns

主工作节点的IP地址作为字符串。

torch_xla.runtime.use_spmd(auto: Optional[bool] = False)[source]

启用SPMD模式的API。这是启用SPMD的推荐方式。

如果某些张量已经在非SPMD设备上初始化,这将强制进入SPMD模式。这意味着这些张量将在设备之间进行复制。

Parameters

auto (bool) – 是否启用自动分片。阅读 https://github.com/pytorch/xla/blob/master/docs/spmd_advanced.md#auto-sharding 了解更多详情

torch_xla.runtime.is_spmd()[source]

返回是否设置了SPMD以进行执行。

torch_xla.runtime.initialize_cache(path: str, readonly: bool = False)[source]

初始化持久化编译缓存。此API必须在执行任何计算之前调用。

Parameters
  • path (str) – 存储持久缓存的路径。

  • readonly (bool) – 此工作线程是否应具有对缓存的写访问权限。

xla_model

torch_xla.core.xla_model.xla_device(n: Optional[int] = None, devkind: Optional[str] = None) device[source]

返回一个XLA设备的给定实例。

Parameters
  • n (python:int, 可选) – 要返回的特定实例(序号)。如果指定了,将返回特定的XLA设备实例。否则将返回devkind的第一个设备。

  • devkind (string..., optional) – 如果指定,设备类型如 TPUCUDACPU 或自定义 PJRT 设备。已弃用。

Returns

一个带有请求实例的torch.device

torch_xla.core.xla_model.xla_device_hw(device: Union[str, device]) str[source]

返回给定设备的硬件类型。

Parameters

device (stringtorch.device) – 将映射到实际设备的xla设备。

Returns

给定设备的硬件类型的字符串表示。

torch_xla.core.xla_model.is_master_ordinal(local: bool = True) bool[source]

检查当前进程是否是主序数(0)。

Parameters

local (bool) – 是否应检查本地或全局的主序数。 在多主机复制的情况下,只有一个全局主序数 (主机 0,设备 0),而有 NUM_HOSTS 个本地主序数。 默认值:True

Returns

一个布尔值,指示当前进程是否是主序数。

torch_xla.core.xla_model.all_reduce(reduce_type: str, inputs: Union[Tensor, List[Tensor]], scale: float = 1.0, groups: Optional[List[List[int]]] = None, pin_layout: bool = True) Union[Tensor, List[Tensor]][source]

对输入张量执行原地归约操作。

Parameters
  • reduce_type (string) – 其中之一为 xm.REDUCE_SUM, xm.REDUCE_MUL, xm.REDUCE_AND, xm.REDUCE_OR, xm.REDUCE_MINxm.REDUCE_MAX.

  • inputs – 可以是单个torch.Tensor或一个torch.Tensor列表,用于执行全归约操作。

  • scale (python:float) – 在reduce之后应用的默认缩放值。 默认值: 1.0

  • groups (list, optional) –

    一个列表的列表,表示all_reduce()操作的副本组。示例:[[0, 1, 2, 3], [4, 5, 6, 7]]

    定义了两个组,一个包含[0, 1, 2, 3]副本,另一个包含[4, 5, 6, 7]副本。如果None,则只有一个包含所有副本的组。

  • pin_layout (bool, optional) – 是否为此通信操作固定布局。 当参与通信的每个进程的程序略有不同时,布局固定可以防止潜在的数据损坏,但它可能会导致某些xla编译失败。当您看到类似“HloModule具有混合布局约束”的错误消息时,请取消固定布局。

Returns

如果传递了一个单一的torch.Tensor,返回值是一个torch.Tensor,它保存了减少后的值(跨副本)。如果传递了一个列表/元组,此函数会对输入张量执行原地全归约操作,并返回列表/元组本身。

torch_xla.core.xla_model.all_gather(value: Tensor, dim: int = 0, groups: Optional[List[List[int]]] = None, output: Optional[Tensor] = None, pin_layout: bool = True) Tensor[source]

沿给定维度执行全收集操作。

Parameters
  • value (torch.Tensor) – 输入张量。

  • dim (python:int) – 聚集维度。 默认值:0

  • groups (list, optional) –

    一个列表的列表,表示all_gather()操作的副本组。示例:[[0, 1, 2, 3], [4, 5, 6, 7]]

    定义了两个组,一个包含[0, 1, 2, 3]副本,另一个包含[4, 5, 6, 7]副本。如果None,则只有一个包含所有副本的组。

  • 输出 (torch.Tensor) – 可选的输出张量。

  • pin_layout (bool, optional) – 是否为此通信操作固定布局。 当参与通信的每个进程的程序略有不同时,布局固定可以防止潜在的数据损坏,但它可能会导致某些xla编译失败。当您看到类似“HloModule具有混合布局约束”的错误消息时,请取消固定布局。

Returns

一个在dim维度中包含所有参与副本的值的张量。

torch_xla.core.xla_model.all_to_all(value: Tensor, split_dimension: int, concat_dimension: int, split_count: int, groups: Optional[List[List[int]]] = None, pin_layout: bool = True) Tensor[source]

对输入张量执行XLA AllToAll()操作。

参见:https://www.tensorflow.org/xla/operation_semantics#alltoall

Parameters
  • value (torch.Tensor) – 输入张量。

  • split_dimension (python:int) – 应该在其上进行分割的维度。

  • concat_dimension (python:int) – 应该在其上进行连接的维度。

  • split_count (python:int) – 分割计数。

  • groups (list, optional) –

    一个列表的列表,表示all_reduce()操作的副本组。示例:[[0, 1, 2, 3], [4, 5, 6, 7]]

    定义了两个组,一个包含[0, 1, 2, 3]副本,另一个包含[4, 5, 6, 7]副本。如果None,则只有一个包含所有副本的组。

  • pin_layout (bool, optional) – 是否为此通信操作固定布局。 当参与通信的每个进程的程序略有不同时,布局固定可以防止潜在的数据损坏,但它可能会导致某些xla编译失败。当您看到类似“HloModule具有混合布局约束”的错误消息时,请取消固定布局。

Returns

操作all_to_all()的结果torch.Tensor

torch_xla.core.xla_model.add_step_closure(closure: Callable[[...], Any], args: Tuple[Any] = (), run_async: bool = False)[source]

将一个闭包添加到要在步骤结束时运行的列表中。

在模型训练过程中,很多时候需要打印/报告(打印到控制台、发布到tensorboard等)信息,这些信息需要检查中间张量的内容。 在模型代码的不同点检查不同张量的内容需要多次执行,通常会导致性能问题。 添加一个步骤闭包将确保它会在屏障之后运行,此时所有活动的张量已经物化为设备数据。 活动的张量将包括闭包参数捕获的那些张量。 因此,使用add_step_closure()将确保即使多个闭包被排队,需要检查多个张量,也只会执行一次执行。 步骤闭包将按照它们被排队的顺序依次运行。 请注意,即使使用此API优化了执行,也建议每N步限制一次打印/报告事件。

Parameters
  • closure (callable) – 要调用的函数。

  • args (tuple) – 传递给闭包的参数。

  • run_async – 如果为True,则异步运行闭包。

torch_xla.core.xla_model.wait_device_ops(devices: List[str] = [])[source]

等待给定设备上的所有异步操作完成。

Parameters

devices (string..., optional) – 需要等待其异步操作的设备。如果为空,将等待所有本地设备。

torch_xla.core.xla_model.optimizer_step(optimizer: Optimizer, barrier: bool = False, optimizer_args: Dict = {}, groups: Optional[List[List[int]]] = None, pin_layout: bool = True)[source]

运行提供的优化器步骤并同步所有设备上的梯度。

Parameters
  • optimizer (torch.Optimizer) – 需要调用其step()函数的torch.Optimizer实例。step()函数将使用optimizer_args命名参数进行调用。

  • barrier (bool, optional) – 是否应在此API中发出XLA张量屏障。如果使用PyTorch XLA的ParallelLoaderDataParallel支持,则不需要此操作,因为屏障将由XLA数据加载器迭代器的next()调用发出。 默认值:False

  • optimizer_args (dict, optional) – 用于optimizer.step()调用的命名参数字典。

  • groups (list, optional) –

    一个列表的列表,表示all_reduce()操作的副本组。示例:[[0, 1, 2, 3], [4, 5, 6, 7]]

    定义了两个组,一个包含[0, 1, 2, 3]副本,另一个包含[4, 5, 6, 7]副本。如果None,则只有一个包含所有副本的组。

  • pin_layout (bool, optional) – 是否在减少梯度时固定布局。 详情请参见 xm.all_reduce

Returns

optimizer.step()调用返回的值相同。

示例

>>> import torch_xla.core.xla_model as xm
>>> xm.optimizer_step(self.optimizer)
torch_xla.core.xla_model.save(data: Any, file_or_path: Union[str, TextIO], master_only: bool = True, global_master: bool = False)[source]

将输入数据保存到文件中。

保存的数据在保存之前会被转移到PyTorch CPU设备上,因此后续的torch.load()将加载CPU数据。 在处理视图时必须小心。建议不要保存视图,而是在张量加载并移动到目标设备后重新创建它们。

Parameters
  • data – 要保存的输入数据。可以是任何嵌套组合的Python对象(列表、元组、集合、字典等)。

  • file_or_path – 数据保存操作的目标。可以是文件路径或Python文件对象。如果master_onlyFalse,路径或文件对象必须指向不同的目标,否则来自同一主机的所有写入将相互覆盖。

  • master_only (bool, optional) – 是否只有主设备应该保存数据。如果为False,file_or_path参数应为参与复制的每个序数提供不同的文件或路径,否则同一主机上的所有副本将写入相同的位置。默认值:True

  • global_master (bool, optional) – 当 master_onlyTrue 时,此标志 控制是否每个主机的 master(如果 global_masterFalse) 保存内容,或者仅全局 master(序号为 0)保存内容。 默认值:False

示例

>>> import torch_xla.core.xla_model as xm
>>> xm.wait_device_ops() # wait for all pending operations to finish.
>>> xm.save(obj_to_save, path_to_save)
>>> xm.rendezvous('torch_xla.core.xla_model.save') # multi process context only
torch_xla.core.xla_model.rendezvous(tag: str, payload: bytes = b'', replicas: List[int] = []) List[bytes][source]

等待所有网格客户端到达指定的集合点。

注意:PJRT不支持XRT网格服务器,因此这实际上是xla_rendezvous的别名。

Parameters
  • tag (string) – 要加入的集合点的名称。

  • payload (bytes, 可选) – 要发送到集合点的有效载荷。

  • replicas (list, python:int) – 参与集合的副本序号。 空列表表示网格中的所有副本。 默认值:[]

Returns

所有其他核心交换的有效载荷,核心序号i的有效载荷位于返回元组中的位置i

示例

>>> import torch_xla.core.xla_model as xm
>>> xm.rendezvous('example')
torch_xla.core.xla_model.mesh_reduce(tag: str, data, reduce_fn: Callable[[...], Any]) Union[Any, ToXlaTensorArena][source]

执行图外客户端网格简化。

Parameters
  • tag (string) – 要加入的集合点的名称。

  • data – 要减少的数据。reduce_fn 可调用对象将接收一个列表,其中包含来自所有网格客户端进程(每个核心一个)的相同数据的副本。

  • reduce_fn (callable) – 一个接收data类对象列表并返回简化结果的函数。

Returns

减少的值。

示例

>>> import torch_xla.core.xla_model as xm
>>> import numpy as np
>>> accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
torch_xla.core.xla_model.set_rng_state(seed: int, device: Optional[str] = None)[source]

设置随机数生成器的状态。

Parameters
  • seed (python:integer) – 要设置的状态。

  • device (string, optional) – 需要设置RNG状态的设备。 如果缺失,将设置默认设备种子。

torch_xla.core.xla_model.get_rng_state(device: Optional[str] = None) int[source]

获取当前运行的随机数生成器状态。

Parameters

device (string, optional) – 需要获取RNG状态的设备。 如果缺失,将设置默认设备种子。

Returns

RNG状态,作为整数。

torch_xla.core.xla_model.get_memory_info(device: Optional[device] = None) MemoryInfo[source]

检索设备内存使用情况。

Parameters
  • device – Optional[torch.device] 请求内存信息的设备。

  • device. (如果未传递,将使用默认值) –

Returns

MemoryInfo 字典包含给定设备的内存使用情况。

示例

>>> xm.get_memory_info()
{'bytes_used': 290816, 'bytes_limit': 34088157184}
torch_xla.core.xla_model.get_stablehlo(tensors: Optional[List[Tensor]] = None) str[source]

获取计算图的StableHLO字符串格式。

如果tensors不为空,将以tensors为输出的图进行转储。 如果tensors为空,将转储整个计算图。

对于推理图,建议将模型输出传递给tensors。 对于训练图,识别“输出”并不直接。建议使用空的tensors

要在StableHLO中启用源代码行信息,请设置环境变量XLA_HLO_DEBUG=1。

Parameters

tensors (list[torch.Tensor], optional) – 表示StableHLO图输出/根的张量。

Returns

字符串格式的StableHLO模块。

torch_xla.core.xla_model.get_stablehlo_bytecode(tensors: Optional[Tensor] = None) bytes[source]

获取计算图的StableHLO字节码格式。

如果tensors不为空,将以tensors为输出的图进行转储。 如果tensors为空,将转储整个计算图。

对于推理图,建议将模型输出传递给tensors。 对于训练图,识别“输出”并不直接。建议使用空的tensors

Parameters

tensors (list[torch.Tensor], optional) – 表示StableHLO图输出/根的张量。

Returns

字节码格式的StableHLO模块。

分布式

class torch_xla.distributed.parallel_loader.MpDeviceLoader(loader, device, **kwargs)[source]

使用后台数据上传包装现有的PyTorch DataLoader。

此类应仅用于多进程数据并行。它将使用ParallelLoader包装传入的数据加载器,并返回当前设备的per_device_loader。

Parameters
  • loader (torch.utils.data.DataLoader) – 要包装的PyTorch DataLoader。

  • device (torch.device…) – 数据需要发送到的设备。

  • kwargsParallelLoader 构造函数的命名参数。

示例

>>> device = torch_xla.device()
>>> train_device_loader = MpDeviceLoader(train_loader, device)
torch_xla.distributed.xla_multiprocessing.spawn(fn, args=(), nprocs=None, join=True, daemon=False, start_method='spawn')[source]

启用基于多处理的复制。

Parameters
  • fn (callable) – 为每个参与复制的设备调用的函数。该函数将被调用,第一个参数是复制中进程的全局索引,后面是args中传递的参数。

  • args (元组) – fn 的参数。 默认值:空元组

  • nprocs (python:int) – 用于复制的进程/设备的数量。目前,如果指定,可以是1或设备的最大数量。

  • join (bool) – 调用是否应该阻塞等待已生成的进程完成。 默认值:True

  • daemon (bool) – 是否应将生成的进程设置为daemon标志(参见Python多处理API)。 默认值:False

  • start_method (string) – Python multiprocessing 进程创建方法。 默认值:spawn

Returns

torch.multiprocessing.spawn API返回的相同对象。如果nprocs为1,则fn函数将直接调用,并且API将返回None。

spmd

torch_xla.distributed.spmd.mark_sharding(t: Union[Tensor, XLAShardedTensor], mesh: 网格, partition_spec: Tuple[Optional[Union[Tuple, int, str]]]) XLAShardedTensor[source]

使用XLA分区规范对提供的张量进行注释。在内部,它为XLA SpmdPartitioner传递注释相应的XLATensor为分片。

Parameters
  • t (Union[torch.Tensor, XLAShardedTensor]) – 输入张量,用于标注partition_spec。

  • mesh (网格) – 描述逻辑XLA设备拓扑和底层设备ID。

  • partition_spec (Tuple[Tuple, python:int, str, None]) – 一个包含设备网格维度索引或None的元组。每个索引是一个整数,如果网格轴有名称则为字符串,或者是整数或字符串的元组。这指定了每个输入秩是如何分片(索引到网格形状)或复制(None)的。当指定一个元组时,相应的输入张量轴将沿着元组中的所有逻辑轴进行分片。请注意,元组中指定的网格轴的顺序会影响最终的分片结果。

  • dynamo_custom_op (bool) – 如果设置为True,它将调用mark_sharding的dynamo自定义操作变体,以便dynamo能够识别和跟踪它。

示例

>>> import torch_xla.runtime as xr
>>> import torch_xla.distributed.spmd as xs
>>> mesh_shape = (4, 2)
>>> num_devices = xr.global_runtime_device_count()
>>> device_ids = np.array(range(num_devices))
>>> mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
>>> input = torch.randn(8, 32).to(xm.xla_device())
>>> xs.mark_sharding(input, mesh, (0, None)) # 4-way data parallel
>>> linear = nn.Linear(32, 10).to(xm.xla_device())
>>> xs.mark_sharding(linear.weight, mesh, (None, 1)) # 2-way model parallel
torch_xla.distributed.spmd.clear_sharding(t: Union[Tensor, XLAShardedTensor]) Tensor[source]

从输入张量中清除分片注释并返回一个cpu转换的张量。这是一个原地操作,但也会返回相同的torch.Tensor。

Parameters

t (Union[torch.Tensor, XLAShardedTensor]) – 我们想要清除分片的张量

Returns

没有分片的张量。

Return type

t (torch.Tensor)

示例

>>> import torch_xla.distributed.spmd as xs
>>> torch_xla.runtime.use_spmd()
>>> t1 = torch.randn(8,8).to(torch_xla.device())
>>> mesh = xs.get_1d_mesh()
>>> xs.mark_sharding(t1, mesh, (0, None))
>>> xs.clear_sharding(t1)
torch_xla.distributed.spmd.set_global_mesh(mesh: 网格)[source]

设置可用于当前进程的全局网格。

Parameters

mesh – (Mesh) 将成为全局网格的网格对象。

示例

>>> import torch_xla.distributed.spmd as xs
>>> mesh = xs.get_1d_mesh("data")
>>> xs.set_global_mesh(mesh)
torch_xla.distributed.spmd.get_global_mesh() Optional[网格][source]

获取当前进程的全局网格。

Returns

(可选[Mesh])如果设置了全局网格,则返回网格对象,否则返回None。

Return type

网格

示例

>>> import torch_xla.distributed.spmd as xs
>>> xs.get_global_mesh()
torch_xla.distributed.spmd.get_1d_mesh(axis_name: Optional[str] = None) 网格[source]

辅助函数,用于返回包含所有设备的一维网格。

Parameters

axis_name – (Optional[str]) 可选的字符串,用于表示网格的轴名称

Returns

网格对象

Return type

Mesh

示例

>>> # This example is assuming 1 TPU v4-8
>>> import torch_xla.distributed.spmd as xs
>>> mesh = xs.get_1d_mesh("data")
>>> print(mesh.mesh_shape)
(4,)
>>> print(mesh.axis_names)
('data',)
class torch_xla.distributed.spmd.Mesh(device_ids: Union[ndarray, List], mesh_shape: Tuple[int, ...], axis_names: Optional[Tuple[str, ...]] = None)[source]

描述逻辑XLA设备拓扑网格及其底层资源。

Parameters
  • device_ids (Union[np.ndarray, List]) – 一个按自定义顺序展开的设备(ID)列表。该列表被重塑为一个mesh_shape数组,使用类似C语言的索引顺序填充元素。

  • mesh_shape (Tuple[python:int, ...]) – 一个描述设备网格逻辑拓扑形状的整数元组,每个元素描述了相应轴上的设备数量。

  • axis_names (Tuple[str, ...]) – 一个资源轴名称的序列,用于分配给devices参数的维度。其长度应与devices的秩相匹配。

示例

>>> mesh_shape = (4, 2)
>>> num_devices = len(xm.get_xla_supported_devices())
>>> device_ids = np.array(range(num_devices))
>>> mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
>>> mesh.get_logical_mesh()
>>> array([[0, 1],
          [2, 3],
          [4, 5],
          [6, 7]])
>>> mesh.shape()
OrderedDict([('x', 4), ('y', 2)])
class torch_xla.distributed.spmd.HybridMesh(*, ici_mesh_shape: Tuple[int, ...], dcn_mesh_shape: Optional[Tuple[int, ...]] = None, axis_names: Optional[Tuple[str, ...]] = None)[source]
Creates a hybrid device mesh of devices connected with ICI and DCN networks.

逻辑网格的形状应按网络强度递增的顺序排列,例如 [replica, data, model],其中 mdl 具有最多的网络通信需求。

Parameters
  • ici_mesh_shape – 内部连接设备的逻辑网格形状。

  • dcn_mesh_shape – 外部连接设备的逻辑网格形状。

示例

>>> # This example is assuming 2 slices of v4-8.
>>> ici_mesh_shape = (1, 4, 1) # (data, fsdp, tensor)
>>> dcn_mesh_shape = (2, 1, 1)
>>> mesh = HybridMesh(ici_mesh_shape, dcn_mesh_shape, ('data','fsdp','tensor'))
>>> print(mesh.shape())
>>> >> OrderedDict([('data', 2), ('fsdp', 4), ('tensor', 1)])

实验性

torch_xla.experimental.eager_mode(enable: bool)[source]

配置 torch_xla 的默认执行模式。

在急切模式下,只有被`torch_xla.compile`编译的函数才会被追踪和编译。其他torch操作将会被急切执行。

调试

torch_xla.debug.metrics.metrics_report()[source]

检索包含完整指标和计数器报告的字符串。

torch_xla.debug.metrics.short_metrics_report(counter_names: Optional[list] = None, metric_names: Optional[list] = None)[source]

检索包含完整指标和计数器报告的字符串。

Parameters
  • counter_names (list) – 需要打印数据的计数器名称列表。

  • metric_names (list) – 需要打印数据的指标名称列表。

torch_xla.debug.metrics.counter_names()[source]

检索所有当前活动的计数器名称。

torch_xla.debug.metrics.counter_value(name)[source]

返回活动计数器的值。

Parameters

name (string) – 需要获取值的计数器的名称。

Returns

计数器值为整数。

torch_xla.debug.metrics.metric_names()[source]

检索所有当前活动的指标名称。

torch_xla.debug.metrics.metric_data(name)[source]

返回活动指标的数据。

Parameters

name (string) – 需要检索数据的指标名称。

Returns

指标数据,它是一个元组,包含(TOTAL_SAMPLES, ACCUMULATOR, SAMPLES)。 TOTAL_SAMPLES是已发布到该指标的总样本数。一个指标只保留给定数量的样本(在循环缓冲区中)。 ACCUMULATORTOTAL_SAMPLES上样本的总和。 SAMPLES是一个(TIME, VALUE)元组的列表。