编译图神经网络
torch.compile() 是在 torch >= 2.0.0 中加速你的 PyTorch 代码的最新方法!
torch.compile() 通过将 PyTorch 代码 JIT 编译为优化的内核来使其运行得更快,同时只需要最少的代码更改。
在底层,torch.compile() 通过 TorchDynamo 捕获 PyTorch 程序,通过 PrimTorch 规范化超过 2,000 个 PyTorch 操作符,最后通过深度学习编译器 TorchInductor 在多个加速器和后端生成快速代码。
在本教程中,我们展示了如何通过torch.compile()优化您的自定义PyG模型。
注意
从 PyG 2.5(及以后版本)开始,torch.compile() 现在完全兼容所有 PyG GNN 层。
如果您使用的是较早版本的 PyG,请考虑使用 torch_geometric.compile() 代替。
基本用法
一旦你定义了一个PyG模型,只需用torch.compile()包装它,即可获得其优化版本:
import torch
from torch_geometric.nn import GraphSAGE
model = GraphSAGE(in_channels, hidden_channels, num_layers, out_channels)
model = model.to(device)
model = torch.compile(model)
并像往常一样执行它:
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root, name="Cora")
data = dataset[0].to(device)
out = model(data.x, data.edge_index)
最大化性能
torch.compile() 方法提供了两个重要的参数需要注意:
在PyG中观察到的大多数小批量数据本质上是动态的,这意味着它们的形状在不同的批次之间会有所不同。 对于这些情况,我们可以通过在PyTorch中使用
dynamic=True参数来强制进行动态形状跟踪:torch.compile(model, dynamic=True)
通过这种方式,PyTorch 将预先尝试生成一个尽可能动态的内核,以避免在小批量大小变化时重新编译。 请注意,当
dynamic设置为False时,PyTorch 将永远不会生成动态内核,因此仅在图形大小保证永远不会变化时有效(例如,在小图形的全批量训练中)。 默认情况下,dynamic在 PyTorch>= 2.1.0中设置为None,并且 PyTorch 将自动检测是否发生了动态变化。 请注意,动态形状跟踪的支持需要安装 PyTorch>= 2.1.0。为了最大化加速,编译模型中的图形中断应该被限制。 我们可以通过在遇到第一个图形中断时强制编译引发错误,使用
fullgraph=True参数来实现:torch.compile(model, fullgraph=True)
通常,确认您编写的模型不包含任何图形中断是一个良好的做法。 重要的是,PyG中存在一些操作目前会导致图形中断(但存在解决方法),例如:
global_mean_pool()(以及其他池化操作符)在没有传递批量大小size的情况下执行设备同步,导致图形中断。remove_self_loops()和add_remaining_self_loops()会屏蔽给定的edge_index,导致设备同步以计算其最终输出形状。 因此,我们建议在将图输入到GNN之前对其进行增强,例如,通过AddSelfLoops或GCNNorm转换,并在初始化层(如GCNConv)时设置add_self_loops=False/normalize=False。
示例脚本
我们在examples/compile中包含了多个示例,进一步展示了torch.compile()的实际用法:
如果您注意到torch.compile()对于某个PyG模型失败,请不要犹豫,在 GitHub或 Slack上联系我们。
我们非常渴望在整个PyG代码库中改进torch.compile()的支持。
基准测试
torch.compile() 对许多 PyG 模型效果非常好。
总的来说,我们观察到运行时间提高了近300%。
具体来说,我们对GCN、GraphSAGE和GIN进行了基准测试,并比较了从传统急切模式和torch.compile()获得的运行时间。
我们使用了一个包含10,000个节点和200,000条边的合成图,隐藏特征维度为64。
我们报告了500次优化步骤的运行时间:
模型 |
模式 |
前进 |
向后 |
总计 |
加速 |
|---|---|---|---|---|---|
渴望 |
2.6396秒 |
2.1697秒 |
4.8093秒 |
||
已编译 |
1.1082秒 |
0.5896秒 |
1.6978秒 |
2.83x |
|
渴望 |
1.6023秒 |
1.6428秒 |
3.2451秒 |
||
已编译 |
0.7033秒 |
0.7465秒 |
1.4498s |
2.24x |
|
渴望 |
1.6701秒 |
1.6990秒 |
3.3690秒 |
||
已编译 |
0.7320秒 |
0.7407秒 |
1.4727秒 |
2.29x |
要重现这些结果,请运行
python test/nn/models/test_basic_gnn.py
从你从GitHub检出的PyG仓库的根文件夹开始。