5.4 图分类
有时候,数据可能以多个图的形式存在,而不是一个大的单一图,例如不同类型的人群社区列表。通过用图来描述同一社区中人们之间的友谊关系,可以得到一个图列表进行分类。在这种情况下,图分类模型可以帮助识别社区的类型,即根据结构和整体信息对每个图进行分类。
概述
图分类与节点分类或链接预测的主要区别在于,预测结果表征了整个输入图的属性。可以像之前的任务一样在节点/边上执行消息传递,但也需要检索图级别的表示。
图分类流程如下:

图分类过程
从左到右,常见的做法是:
准备一批图表
在批处理图上执行消息传递以更新节点/边特征
将节点/边的特征聚合为图级别的表示
基于图级表示对图进行分类
图批次
通常,图分类任务会在大量图上进行训练,如果在训练模型时每次只使用一个图,效率会非常低。借鉴常见的深度学习实践中的小批量训练思想,可以构建一个包含多个图的批次,并将它们一起发送进行一次训练迭代。
在DGL中,可以从一系列图中构建一个单一的批处理图。这个批处理图可以简单地用作一个单一的大图,其中连接的组件对应于原始的小图。

批量图
以下示例在图形列表上调用dgl.batch()
。
批处理图是一个单一的图,同时它也携带了关于列表的信息。
import dgl
import torch as th
g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))
bg = dgl.batch([g1, g2])
bg
# Graph(num_nodes=7, num_edges=7,
# ndata_schemes={}
# edata_schemes={})
bg.batch_size
# 2
bg.batch_num_nodes()
# tensor([4, 3])
bg.batch_num_edges()
# tensor([3, 4])
bg.edges()
# (tensor([0, 1, 2, 4, 4, 4, 5], tensor([1, 2, 3, 4, 5, 6, 4]))
请注意,大多数dgl转换函数会丢弃批次信息。
为了保持这些信息,请在转换后的图上使用dgl.DGLGraph.set_batch_num_nodes()
和dgl.DGLGraph.set_batch_num_edges()
。
图读取
数据中的每个图可能都有其独特的结构,以及其节点和边的特征。为了做出单一的预测,通常需要对可能丰富的信息进行聚合和总结。这种操作被称为readout。常见的readout操作包括对所有节点或边特征进行求和、平均、最大值或最小值。
给定一个图 \(g\),可以定义平均节点特征读取为
其中 \(h_g\) 是 \(g\) 的表示,\(\mathcal{V}\) 是 \(g\) 中的节点集合,\(h_v\) 是节点 \(v\) 的特征。
DGL 提供了对常见读取操作的内置支持。例如,
dgl.mean_nodes()
实现了上述读取操作。
一旦\(h_g\)可用,可以将其通过一个MLP层进行分类输出。
编写神经网络模型
模型的输入是带有节点和边特征的批量图。
批量图上的计算
首先,批次中的不同图是完全分离的,即任何两个图之间没有边。有了这个良好的特性,所有消息传递函数仍然具有相同的结果。
其次,批量图上的读取函数将分别对每个图执行。假设批量大小为\(B\),并且要聚合的特征具有维度\(D\),则读取结果的形状将为\((B, D)\)。
import dgl
import torch
g1 = dgl.graph(([0, 1], [1, 0]))
g1.ndata['h'] = torch.tensor([1., 2.])
g2 = dgl.graph(([0, 1], [1, 2]))
g2.ndata['h'] = torch.tensor([1., 2., 3.])
dgl.readout_nodes(g1, 'h')
# tensor([3.]) # 1 + 2
bg = dgl.batch([g1, g2])
dgl.readout_nodes(bg, 'h')
# tensor([3., 6.]) # [1 + 2, 1 + 2 + 3]
最后,批处理图中的每个节点/边特征是通过按顺序连接所有图中的相应特征获得的。
bg.ndata['h']
# tensor([1., 2., 1., 2., 3.])
模型定义
了解上述计算规则后,可以如下定义一个模型。
import dgl.nn.pytorch as dglnn
import torch.nn as nn
class Classifier(nn.Module):
def __init__(self, in_dim, hidden_dim, n_classes):
super(Classifier, self).__init__()
self.conv1 = dglnn.GraphConv(in_dim, hidden_dim)
self.conv2 = dglnn.GraphConv(hidden_dim, hidden_dim)
self.classify = nn.Linear(hidden_dim, n_classes)
def forward(self, g, h):
# Apply graph convolution and activation.
h = F.relu(self.conv1(g, h))
h = F.relu(self.conv2(g, h))
with g.local_scope():
g.ndata['h'] = h
# Calculate graph representation by average readout.
hg = dgl.mean_nodes(g, 'h')
return self.classify(hg)
训练循环
数据加载
一旦模型定义完成,就可以开始训练。由于图分类处理的是许多相对较小的图,而不是一个大的单一图,因此可以在图的随机小批量上进行高效训练,而无需设计复杂的图采样算法。
假设有一个图分类数据集,如第4章:图数据管道中介绍的那样。
import dgl.data
dataset = dgl.data.GINDataset('MUTAG', False)
图分类数据集中的每一项都是一对图和其标签。通过利用GraphDataLoader来迭代小批量图数据集,可以加快数据加载过程。
from dgl.dataloading import GraphDataLoader
dataloader = GraphDataLoader(
dataset,
batch_size=1024,
drop_last=False,
shuffle=True)
训练循环简单地涉及遍历数据加载器并更新模型。
import torch.nn.functional as F
# Only an example, 7 is the input feature size
model = Classifier(7, 20, 5)
opt = torch.optim.Adam(model.parameters())
for epoch in range(20):
for batched_graph, labels in dataloader:
feats = batched_graph.ndata['attr']
logits = model(batched_graph, feats)
loss = F.cross_entropy(logits, labels)
opt.zero_grad()
loss.backward()
opt.step()
关于图分类的端到端示例,请参见
DGL的GIN示例。
训练循环位于
main.py
中的train
函数内。
模型实现位于
gin.py,
其中包含更多组件,例如使用
dgl.nn.pytorch.GINConv
(也可在MXNet和Tensorflow中使用)
作为图卷积层、批量归一化等。
异构图
使用异构图进行图分类与使用同构图进行图分类有些不同。除了与异构图兼容的图卷积模块外,还需要在读取函数中聚合不同类型的节点。
以下展示了如何对每种节点类型的节点表示求平均值的示例。
class RGCN(nn.Module):
def __init__(self, in_feats, hid_feats, out_feats, rel_names):
super().__init__()
self.conv1 = dglnn.HeteroGraphConv({
rel: dglnn.GraphConv(in_feats, hid_feats)
for rel in rel_names}, aggregate='sum')
self.conv2 = dglnn.HeteroGraphConv({
rel: dglnn.GraphConv(hid_feats, out_feats)
for rel in rel_names}, aggregate='sum')
def forward(self, graph, inputs):
# inputs is features of nodes
h = self.conv1(graph, inputs)
h = {k: F.relu(v) for k, v in h.items()}
h = self.conv2(graph, h)
return h
class HeteroClassifier(nn.Module):
def __init__(self, in_dim, hidden_dim, n_classes, rel_names):
super().__init__()
self.rgcn = RGCN(in_dim, hidden_dim, hidden_dim, rel_names)
self.classify = nn.Linear(hidden_dim, n_classes)
def forward(self, g):
h = g.ndata['feat']
h = self.rgcn(g, h)
with g.local_scope():
g.ndata['h'] = h
# Calculate graph representation by average readout.
hg = 0
for ntype in g.ntypes:
hg = hg + dgl.mean_nodes(g, 'h', ntype=ntype)
return self.classify(hg)
其余代码与同构图代码没有区别。
# etypes is the list of edge types as strings.
model = HeteroClassifier(10, 20, 5, etypes)
opt = torch.optim.Adam(model.parameters())
for epoch in range(20):
for batched_graph, labels in dataloader:
logits = model(batched_graph)
loss = F.cross_entropy(logits, labels)
opt.zero_grad()
loss.backward()
opt.step()