分层邻域采样
PyG 的设计原则之一是模型和数据加载例程应该是可互换的,以便进行灵活的 GNN 和数据加载实验。因此,模型通常可以以数据加载无关的方式编写,独立于是否通过 DataLoader、NeighborLoader 或 ClusterLoader 应用全批次或小批次训练策略。然而,在某些情况下,这种灵活性是以性能为代价的,因为模型无法利用底层数据加载例程的特殊特性。一个这样的限制是,使用 NeighborLoader 例程训练的 GNN 会迭代地为网络的所有深度的所有节点构建表示,尽管在后续跳数中采样的节点不再对后续 GNN 层中的种子节点的表示做出贡献,从而执行无用的计算。
分层邻域采样 或 分层图邻接矩阵 (HGAM) 是 PyG 中可用的一种技术,用于消除这种开销并加速小批量 GNN 的训练和推理。 其主要思想是在将返回的子图的邻接矩阵输入到每个 GNN 层之前,逐步修剪它。 它在多个模型中无缝工作,基本上减少了为给定小批量的种子节点生成表示所需的计算量。
至关重要的是,HGAM认识到最终节点表示的计算仅对种子节点(这是批量计算的真正目标)是必要的。 因此,HGAM允许GNN的每一层仅计算该层所需的节点表示,从而减少计算量并加速训练过程,这种加速随着GNN深度的增加而增长。 在实践中,这是通过在GNN层之间进行计算时修剪邻接矩阵和各种特征矩阵来实现的。 这与以下事实一致:为了计算种子/目标节点的表示(通过采样方法从中构建小批量),相关邻域的深度随着我们通过GNN层的进行而缩小。 HGAM应用的修剪是可能的,因为通过采样构建的子图的节点按照广度优先搜索(BFS)策略排序,这意味着邻接矩阵的行和列指的是从种子节点(以任何顺序)开始,然后是第一个种子节点的1跳邻居,接着是第二个种子节点的1跳采样邻居,依此类推的节点排序。 小批量中节点的BFS排序允许对子图的邻接矩阵进行增量修剪(减少)。 由于BFS排序使得距离种子节点更远的节点在有序节点列表中出现得更远,这种逐步修剪以计算上方便的方式进行。
为了支持这种修剪并有效实施,NeighborLoader 在 PyG 和 pyg-lib 中的实现还返回了在每一跳中采样的节点和边的数量。这些信息允许快速操作邻接矩阵,从而大大减少计算量。NeighborLoader 通过专用属性 num_sampled_nodes 和 num_sampled_edges 准备这些元数据。可以从为同质图和异质图返回的 Batch 对象中访问这些信息。
总之,HGAM 是一种特殊的数据结构,能够在 NeighborLoader 场景中实现高效的消息传递计算。
HGAM 在 PyG 中实现,可以通过特殊的 trim_to_layer() 功能来使用。
HGAM 目前是 PyG 用户可以自由选择开启或关闭的选项 (当前默认关闭)。
用法
在这里,我们展示了如何结合使用HGAM功能和NeighborLoader的示例:
同质数据示例:
from torch_geometric.datasets import Planetoid from torch_geometric.loader import NeighborLoader data = Planetoid(path, name='Cora')[0] loader = NeighborLoader( data, num_neighbors=[10] * 3, batch_size=128, ) batch = next(iter(loader)) print(batch) >>> Data(x=[1883, 1433], edge_index=[2, 5441], y=[1883], train_mask=[1883], val_mask=[1883], test_mask=[1883], batch_size=128, num_sampled_nodes=[4], num_sampled_edges=[3]) print(batch.num_sampled_nodes) >>> [128, 425, 702, 628] # Number of sampled nodes per hop/layer. print(batch.num_sampled_edges) >>> [520, 2036, 2885] # Number of sampled edges per hop/layer.
异构数据示例:
from torch_geometric.datasets import OGB_MAG from torch_geometric.loader import NeighborLoader data = OGB_MAG(path)[0] loader = NeighborLoader( data, num_neighbors=[10] * 3, batch_size=128, input_nodes='paper', ) batch = next(iter(loader)) print(batch) >>> HeteroData( paper={ x=[2275, 128], num_sampled_nodes=[3], batch_size=128, }, author={ num_nodes=2541, num_sampled_nodes=[3], }, institution={ num_nodes=0, num_sampled_nodes=[3], }, field_of_study={ num_nodes=0, num_sampled_nodes=[3], }, (author, affiliated_with, institution)={ edge_index=[2, 0], num_sampled_edges=[2], }, (author, writes, paper)={ edge_index=[2, 3255], num_sampled_edges=[2], }, (paper, cites, paper)={ edge_index=[2, 2691], num_sampled_edges=[2], }, (paper, has_topic, field_of_study)={ edge_index=[2, 0], num_sampled_edges=[2], } ) print(batch['paper'].num_sampled_nodes) >>> [128, 508, 1598] # Number of sampled paper nodes per hop/layer. print(batch['author', 'writes', 'paper'].num_sampled_edges) >>>> [629, 2621] # Number of sampled autor<>paper edges per hop/layer.
属性 num_sampled_nodes 和 num_sampled_edges 可以被 GNN 中的 trim_to_layer() 函数使用:
from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import trim_to_layer
dataset = Reddit(path)
loader = NeighborLoader(data, num_neighbors=[10, 5, 5], ...)
class GNN(torch.nn.Module):
def __init__(self, in_channels: int, out_channels: int, num_layers: int):
super().__init__()
self.convs = ModuleList([SAGEConv(in_channels, 64)])
for _ in range(num_layers - 1):
self.convs.append(SAGEConv(hidden_channels, hidden_channels))
self.lin = Linear(hidden_channels, out_channels)
def forward(
self,
x: Tensor,
edge_index: Tensor,
num_sampled_nodes_per_hop: List[int],
num_sampled_edges_per_hop: List[int],
) -> Tensor:
for i, conv in enumerate(self.convs):
# Trim edge and node information to the current layer `i`.
x, edge_index, _ = trim_to_layer(
i, num_sampled_nodes_per_hop, num_sampled_edges_per_hop,
x, edge_index)
x = conv(x, edge_index).relu()
return self.lin(x)
示例
我们在PyG的examples/文件夹中提供了HGAM的完整示例: