分层邻域采样

的设计原则之一是模型和数据加载例程应该是可互换的,以便进行灵活的 GNN 和数据加载实验。因此,模型通常可以以数据加载无关的方式编写,独立于是否通过 DataLoaderNeighborLoaderClusterLoader 应用全批次或小批次训练策略。然而,在某些情况下,这种灵活性是以性能为代价的,因为模型无法利用底层数据加载例程的特殊特性。一个这样的限制是,使用 NeighborLoader 例程训练的 GNN 会迭代地为网络的所有深度的所有节点构建表示,尽管在后续跳数中采样的节点不再对后续 GNN 层中的种子节点的表示做出贡献,从而执行无用的计算。

分层邻域采样分层图邻接矩阵 (HGAM) 中可用的一种技术,用于消除这种开销并加速小批量 GNN 的训练和推理。 其主要思想是在将返回的子图的邻接矩阵输入到每个 GNN 层之前,逐步修剪它。 它在多个模型中无缝工作,基本上减少了为给定小批量的种子节点生成表示所需的计算量。

至关重要的是,HGAM认识到最终节点表示的计算仅对种子节点(这是批量计算的真正目标)是必要的。 因此,HGAM允许GNN的每一层仅计算该层所需的节点表示,从而减少计算量并加速训练过程,这种加速随着GNN深度的增加而增长。 在实践中,这是通过在GNN层之间进行计算时修剪邻接矩阵和各种特征矩阵来实现的。 这与以下事实一致:为了计算种子/目标节点的表示(通过采样方法从中构建小批量),相关邻域的深度随着我们通过GNN层的进行而缩小。 HGAM应用的修剪是可能的,因为通过采样构建的子图的节点按照广度优先搜索(BFS)策略排序,这意味着邻接矩阵的行和列指的是从种子节点(以任何顺序)开始,然后是第一个种子节点的1跳邻居,接着是第二个种子节点的1跳采样邻居,依此类推的节点排序。 小批量中节点的BFS排序允许对子图的邻接矩阵进行增量修剪(减少)。 由于BFS排序使得距离种子节点更远的节点在有序节点列表中出现得更远,这种逐步修剪以计算上方便的方式进行。

为了支持这种修剪并有效实施,NeighborLoader 中的实现还返回了在每一跳中采样的节点和边的数量。这些信息允许快速操作邻接矩阵,从而大大减少计算量。NeighborLoader 通过专用属性 num_sampled_nodesnum_sampled_edges 准备这些元数据。可以从为同质图和异质图返回的 Batch 对象中访问这些信息。

总之,HGAM 是一种特殊的数据结构,能够在 NeighborLoader 场景中实现高效的消息传递计算。 HGAM 在 中实现,可以通过特殊的 trim_to_layer() 功能来使用。 HGAM 目前是 用户可以自由选择开启或关闭的选项 (当前默认关闭)

用法

在这里,我们展示了如何结合使用HGAM功能和NeighborLoader的示例:

  • 同质数据示例:

    from torch_geometric.datasets import Planetoid
    from torch_geometric.loader import NeighborLoader
    
    data = Planetoid(path, name='Cora')[0]
    
    loader = NeighborLoader(
        data,
        num_neighbors=[10] * 3,
        batch_size=128,
    )
    
    batch = next(iter(loader))
    print(batch)
    >>> Data(x=[1883, 1433], edge_index=[2, 5441], y=[1883], train_mask=[1883],
             val_mask=[1883], test_mask=[1883], batch_size=128,
             num_sampled_nodes=[4], num_sampled_edges=[3])
    
    print(batch.num_sampled_nodes)
    >>> [128, 425, 702, 628]  # Number of sampled nodes per hop/layer.
    print(batch.num_sampled_edges)
    >>> [520, 2036, 2885]  # Number of sampled edges per hop/layer.
    
  • 异构数据示例:

    from torch_geometric.datasets import OGB_MAG
    from torch_geometric.loader import NeighborLoader
    
    data = OGB_MAG(path)[0]
    
    loader = NeighborLoader(
        data,
        num_neighbors=[10] * 3,
        batch_size=128,
        input_nodes='paper',
    )
    
    batch = next(iter(loader))
    print(batch)
    >>> HeteroData(
        paper={
            x=[2275, 128],
            num_sampled_nodes=[3],
            batch_size=128,
        },
        author={
            num_nodes=2541,
            num_sampled_nodes=[3],
        },
        institution={
            num_nodes=0,
            num_sampled_nodes=[3],
        },
        field_of_study={
            num_nodes=0,
            num_sampled_nodes=[3],
        },
        (author, affiliated_with, institution)={
            edge_index=[2, 0],
            num_sampled_edges=[2],
        },
        (author, writes, paper)={
            edge_index=[2, 3255],
            num_sampled_edges=[2],
        },
        (paper, cites, paper)={
            edge_index=[2, 2691],
            num_sampled_edges=[2],
        },
        (paper, has_topic, field_of_study)={
            edge_index=[2, 0],
            num_sampled_edges=[2],
        }
        )
    print(batch['paper'].num_sampled_nodes)
    >>> [128, 508, 1598]  # Number of sampled paper nodes per hop/layer.
    
    print(batch['author', 'writes', 'paper'].num_sampled_edges)
    >>>> [629, 2621]  # Number of sampled autor<>paper edges per hop/layer.
    

属性 num_sampled_nodesnum_sampled_edges 可以被 GNN 中的 trim_to_layer() 函数使用:

from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import trim_to_layer

dataset = Reddit(path)
loader = NeighborLoader(data, num_neighbors=[10, 5, 5], ...)

class GNN(torch.nn.Module):
    def __init__(self, in_channels: int, out_channels: int, num_layers: int):
        super().__init__()

        self.convs = ModuleList([SAGEConv(in_channels, 64)])
        for _ in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.lin = Linear(hidden_channels, out_channels)

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        num_sampled_nodes_per_hop: List[int],
        num_sampled_edges_per_hop: List[int],
    ) -> Tensor:

        for i, conv in enumerate(self.convs):
            # Trim edge and node information to the current layer `i`.
            x, edge_index, _ = trim_to_layer(
                i, num_sampled_nodes_per_hop, num_sampled_edges_per_hop,
                x, edge_index)

            x = conv(x, edge_index).relu()

        return self.lin(x)

示例

我们在examples/文件夹中提供了HGAM的完整示例:

  • examples/hierarchical_sampling.py: 一个示例,展示HGAM的基本用法。

  • examples/hetero/hierarchical_sage.py: 一个关于异构图上的HGAM的示例