内存高效聚合

MessagePassing 接口在 中依赖于一种聚集-分散方案来聚合来自相邻节点的消息。 例如,考虑消息传递层

\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \textrm{MLP}(\mathbf{x}_j - \mathbf{x}_i),\]

可以实现为:

from torch_geometric.nn import MessagePassing

x = ...           # Node features of shape [num_nodes, num_features]
edge_index = ...  # Edge indices of shape [2, num_edges]

class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr="add")

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        return MLP(x_j - x_i)

在底层,MessagePassing 实现生成的代码如下所示:

from torch_geometric.utils import scatter

x = ...           # Node features of shape [num_nodes, num_features]
edge_index = ...  # Edge indices of shape [2, num_edges]

x_j = x[edge_index[0]]  # Source node features [num_edges, num_features]
x_i = x[edge_index[1]]  # Target node features [num_edges, num_features]

msg = MLP(x_j - x_i)  # Compute message for each edge

# Aggregate messages based on target node indices
out = scatter(msg, edge_index[1], dim=0, dim_size=x.size(0), reduce='sum')

虽然聚集-分散的公式推广了许多有用的GNN实现,但它有一个缺点,即显式地实例化x_jx_i,导致在大型和密集的图上内存占用较高。

幸运的是,并非所有的GNN都需要通过显式地实例化x_j和/或x_i来实现。 在某些情况下,GNN也可以实现为简单的稀疏矩阵乘法。 作为一个通用的经验法则,这对于在计算消息时不使用中心节点特征x_i或多维边特征的GNN是成立的。 例如,GINConv

\[\mathbf{x}^{\prime}_i = \textrm{MLP} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right),\]

等同于计算

\[\mathbf{X}^{\prime} = \textrm{MLP} \left( (1 + \epsilon) \cdot \mathbf{X} + \mathbf{A}\mathbf{X} \right),\]

其中 \(\mathbf{A}\) 表示形状为 [num_nodes, num_nodes] 的稀疏邻接矩阵。 这种表述允许利用专门且快速的稀疏矩阵乘法实现。

PyG >= 1.6.0 中,我们正式引入了对稀疏矩阵乘法 GNN 的更好支持,从而实现了 更低的内存占用更快的执行时间。 因此,我们引入了 SparseTensor 类(来自 torch_sparse 包),该类基于 “Design Principles for Sparse Matrix Multiplication on the GPU” 论文实现了稀疏矩阵乘法的快速前向和反向传播。

使用SparseTensor类非常简单,类似于scipy处理稀疏矩阵的方式:

from torch_sparse import SparseTensor

adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=...,
                   sparse_sizes=(num_nodes, num_nodes))
# value is optional and can be None

# Obtain different representations (COO, CSR, CSC):
row,    col, value = adj.coo()
rowptr, col, value = adj.csr()
colptr, row, value = adj.csc()

adj = adj[:100, :100]  # Slicing, indexing and masking support
adj = adj.set_diag()   # Add diagonal entries
adj_t = adj.t()        # Transpose
out = adj.matmul(x)    # Sparse-dense matrix multiplication
adj = adj.matmul(adj)  # Sparse-sparse matrix multiplication

# Creating SparseTensor instances:
adj = SparseTensor.from_dense(mat)
adj = SparseTensor.eye(100, 100)
adj = SparseTensor.from_scipy(mat)

我们的 MessagePassing 接口可以处理 torch.TensorSparseTensor 作为消息传播的输入。 然而,当在 SparseTensor 中持有有向图时,你需要确保将 转置的稀疏矩阵 输入到 propagate() 中:

conv = GCNConv(16, 32)
out1 = conv(x, edge_index)
out2 = conv(x, adj.t())
assert torch.allclose(out1, out2)

conv = GINConv(nn=Sequential(Linear(16, 32), ReLU(), Linear(32, 32)))
out1 = conv(x, edge_index)
out2 = conv(x, adj.t())
assert torch.allclose(out1, out2)

为了利用稀疏矩阵乘法,MessagePassing 接口引入了 message_and_aggregate() 函数(它将 message()aggregate() 函数融合为一个计算步骤),每当它被实现并接收到 SparseTensor 作为 edge_index 的输入时,该函数就会被调用。 有了它,GINConv 层现在可以如下实现:

import torch_sparse

class GINConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr="add")

    def forward(self, x, edge_index):
        out = self.propagate(edge_index, x=x)
        return MLP((1 + eps) x + out)

    def message(self, x_j):
        return x_j

    def message_and_aggregate(self, adj_t, x):
        return torch_sparse.matmul(adj_t, x, reduce=self.aggr)

使用新的SparseTensor格式非常简单,因为我们的所有GNN都可以直接使用它。 要将edge_index格式转换为新引入的SparseTensor格式,你可以使用torch_geometric.transforms.ToSparseTensor转换:

import torch
import torch.nn.functional as F

from torch_geometric.nn import GCNConv
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid

dataset = Planetoid("Planetoid", name="Cora", transform=T.ToSparseTensor())
data = dataset[0]
>>> Data(adj_t=[2708, 2708, nnz=10556], x=[2708, 1433], y=[2708], ...)


class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, 16, cached=True)
        self.conv2 = GCNConv(16, dataset.num_classes, cached=True)

    def forward(self, x, adj_t):
        x = self.conv1(x, adj_t)
        x = F.relu(x)
        x = self.conv2(x, adj_t)
        return F.log_softmax(x, dim=1)

model = GNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train(data):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.adj_t)
    loss = F.nll_loss(out, data.y)
    loss.backward()
    optimizer.step()
    return float(loss)

for epoch in range(1, 201):
    loss = train(data)

所有代码保持不变,除了通过T.ToSparseTensor()进行的data转换。 作为一个额外的优势,利用SparseTensor类的MessagePassing实现在GPU上是确定性的,因为聚合不再依赖于原子操作。

值得注意的是,当GNN层将单维或多维边缘信息edge_weightedge_attr分别纳入其消息传递公式时,GNN层的执行会略有变化。 特别是,现在期望这些属性直接作为值添加到SparseTensor对象中。 而不是像这样调用GNN:

conv = GMMConv(16, 32, dim=3)
out = conv(x, edge_index, edge_attr)

我们现在执行我们的GNN操作符如下

conv = GMMConv(16, 32, dim=3)
adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_attr)
out = conv(x, adj.t())

注意

由于此功能仍处于实验阶段,某些操作,例如,图池化方法,可能仍需要您输入edge_index格式。 您可以通过以下方式将adj_t转换回(edge_index, edge_attr)

row, col, edge_attr = adj_t.t().coo()
edge_index = torch.stack([row, col], dim=0)

请告诉我们您对SparseTensor的看法,我们如何改进它,以及您遇到任何意外行为时的情况。