编译图神经网络

torch.compile() 是在 torch >= 2.0.0 中加速你的 代码的最新方法! torch.compile() 通过将 PyTorch 代码 JIT 编译为优化的内核来使其运行得更快,同时只需要最少的代码更改。

在底层,torch.compile() 通过 TorchDynamo 捕获 程序,通过 PrimTorch 规范化超过 2,000 个 操作符,最后通过深度学习编译器 TorchInductor 在多个加速器和后端生成快速代码。

注意

请参阅这里了解如何利用torch.compile()的通用教程,以及这里了解其接口的描述。

在本教程中,我们展示了如何通过torch.compile()优化您的自定义模型。

注意

2.5(及以后版本)开始,torch.compile() 现在完全兼容所有 GNN 层。 如果您使用的是较早版本的 ,请考虑使用 torch_geometric.compile() 代替。

基本用法

一旦你定义了一个模型,只需用torch.compile()包装它,即可获得其优化版本:

import torch
from torch_geometric.nn import GraphSAGE

model = GraphSAGE(in_channels, hidden_channels, num_layers, out_channels)
model = model.to(device)

model = torch.compile(model)

并像往常一样执行它:

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root, name="Cora")
data = dataset[0].to(device)

out = model(data.x, data.edge_index)

最大化性能

torch.compile() 方法提供了两个重要的参数需要注意:

  • 中观察到的大多数小批量数据本质上是动态的,这意味着它们的形状在不同的批次之间会有所不同。 对于这些情况,我们可以通过在中使用dynamic=True参数来强制进行动态形状跟踪:

    torch.compile(model, dynamic=True)
    

    通过这种方式, 将预先尝试生成一个尽可能动态的内核,以避免在小批量大小变化时重新编译。 请注意,当 dynamic 设置为 False 时,永远不会生成动态内核,因此仅在图形大小保证永远不会变化时有效(例如,在小图形的全批量训练中)。 默认情况下,dynamic >= 2.1.0 中设置为 None,并且 将自动检测是否发生了动态变化。 请注意,动态形状跟踪的支持需要安装 >= 2.1.0

  • 为了最大化加速,编译模型中的图形中断应该被限制。 我们可以通过在遇到第一个图形中断时强制编译引发错误,使用fullgraph=True参数来实现:

    torch.compile(model, fullgraph=True)
    

    通常,确认您编写的模型不包含任何图形中断是一个良好的做法。 重要的是,中存在一些操作目前会导致图形中断(但存在解决方法),例如

    1. global_mean_pool()(以及其他池化操作符)在没有传递批量大小size的情况下执行设备同步,导致图形中断。

    2. remove_self_loops()add_remaining_self_loops() 会屏蔽给定的 edge_index,导致设备同步以计算其最终输出形状。 因此,我们建议在将图输入到GNN之前对其进行增强,例如,通过 AddSelfLoopsGCNNorm 转换,并在初始化层(如 GCNConv)时设置 add_self_loops=False/normalize=False

示例脚本

我们在examples/compile中包含了多个示例,进一步展示了torch.compile()的实际用法:

  1. 节点分类 通过 GCN (dynamic=False)

  2. 图分类 通过 GIN (dynamic=True)

如果您注意到torch.compile()对于某个模型失败,请不要犹豫,在 GitHub Slack上联系我们。 我们非常渴望在整个代码库中改进torch.compile()的支持。

基准测试

torch.compile() 对许多 模型效果非常好总的来说,我们观察到运行时间提高了近300%。

具体来说,我们对GCNGraphSAGEGIN进行了基准测试,并比较了从传统急切模式和torch.compile()获得的运行时间。 我们使用了一个包含10,000个节点和200,000条边的合成图,隐藏特征维度为64。 我们报告了500次优化步骤的运行时间:

模型

模式

前进

向后

总计

加速

GCN

渴望

2.6396秒

2.1697秒

4.8093秒

GCN

已编译

1.1082秒

0.5896秒

1.6978秒

2.83x

GraphSAGE

渴望

1.6023秒

1.6428秒

3.2451秒

GraphSAGE

已编译

0.7033秒

0.7465秒

1.4498s

2.24x

GIN

渴望

1.6701秒

1.6990秒

3.3690秒

GIN

已编译

0.7320秒

0.7407秒

1.4727秒

2.29x

要重现这些结果,请运行

python test/nn/models/test_basic_gnn.py

从你从检出的仓库的根文件夹开始。