5.4 图分类

(中文版)

有时候,数据可能以多个图的形式存在,而不是一个大的单一图,例如不同类型的人群社区列表。通过用图来描述同一社区中人们之间的友谊关系,可以得到一个图列表进行分类。在这种情况下,图分类模型可以帮助识别社区的类型,即根据结构和整体信息对每个图进行分类。

概述

图分类与节点分类或链接预测的主要区别在于,预测结果表征了整个输入图的属性。可以像之前的任务一样在节点/边上执行消息传递,但也需要检索图级别的表示。

图分类流程如下:

Graph Classification Process

图分类过程

从左到右,常见的做法是:

  • 准备一批图表

  • 在批处理图上执行消息传递以更新节点/边特征

  • 将节点/边的特征聚合为图级别的表示

  • 基于图级表示对图进行分类

图批次

通常,图分类任务会在大量图上进行训练,如果在训练模型时每次只使用一个图,效率会非常低。借鉴常见的深度学习实践中的小批量训练思想,可以构建一个包含多个图的批次,并将它们一起发送进行一次训练迭代。

在DGL中,可以从一系列图中构建一个单一的批处理图。这个批处理图可以简单地用作一个单一的大图,其中连接的组件对应于原始的小图。

Batched Graph

批量图

以下示例在图形列表上调用dgl.batch()。 批处理图是一个单一的图,同时它也携带了关于列表的信息。

import dgl
import torch as th

g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))

bg = dgl.batch([g1, g2])
bg
# Graph(num_nodes=7, num_edges=7,
#       ndata_schemes={}
#       edata_schemes={})
bg.batch_size
# 2
bg.batch_num_nodes()
# tensor([4, 3])
bg.batch_num_edges()
# tensor([3, 4])
bg.edges()
# (tensor([0, 1, 2, 4, 4, 4, 5], tensor([1, 2, 3, 4, 5, 6, 4]))

请注意,大多数dgl转换函数会丢弃批次信息。 为了保持这些信息,请在转换后的图上使用dgl.DGLGraph.set_batch_num_nodes()dgl.DGLGraph.set_batch_num_edges()

图读取

数据中的每个图可能都有其独特的结构,以及其节点和边的特征。为了做出单一的预测,通常需要对可能丰富的信息进行聚合和总结。这种操作被称为readout。常见的readout操作包括对所有节点或边特征进行求和、平均、最大值或最小值。

给定一个图 \(g\),可以定义平均节点特征读取为

\[h_g = \frac{1}{|\mathcal{V}|}\sum_{v\in \mathcal{V}}h_v\]

其中 \(h_g\)\(g\) 的表示,\(\mathcal{V}\)\(g\) 中的节点集合,\(h_v\) 是节点 \(v\) 的特征。

DGL 提供了对常见读取操作的内置支持。例如, dgl.mean_nodes() 实现了上述读取操作。

一旦\(h_g\)可用,可以将其通过一个MLP层进行分类输出。

编写神经网络模型

模型的输入是带有节点和边特征的批量图。

批量图上的计算

首先,批次中的不同图是完全分离的,即任何两个图之间没有边。有了这个良好的特性,所有消息传递函数仍然具有相同的结果。

其次,批量图上的读取函数将分别对每个图执行。假设批量大小为\(B\),并且要聚合的特征具有维度\(D\),则读取结果的形状将为\((B, D)\)

import dgl
import torch

g1 = dgl.graph(([0, 1], [1, 0]))
g1.ndata['h'] = torch.tensor([1., 2.])
g2 = dgl.graph(([0, 1], [1, 2]))
g2.ndata['h'] = torch.tensor([1., 2., 3.])

dgl.readout_nodes(g1, 'h')
# tensor([3.])  # 1 + 2

bg = dgl.batch([g1, g2])
dgl.readout_nodes(bg, 'h')
# tensor([3., 6.])  # [1 + 2, 1 + 2 + 3]

最后,批处理图中的每个节点/边特征是通过按顺序连接所有图中的相应特征获得的。

bg.ndata['h']
# tensor([1., 2., 1., 2., 3.])

模型定义

了解上述计算规则后,可以如下定义一个模型。

import dgl.nn.pytorch as dglnn
import torch.nn as nn

class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = dglnn.GraphConv(in_dim, hidden_dim)
        self.conv2 = dglnn.GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g, h):
        # Apply graph convolution and activation.
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        with g.local_scope():
            g.ndata['h'] = h
            # Calculate graph representation by average readout.
            hg = dgl.mean_nodes(g, 'h')
            return self.classify(hg)

训练循环

数据加载

一旦模型定义完成,就可以开始训练。由于图分类处理的是许多相对较小的图,而不是一个大的单一图,因此可以在图的随机小批量上进行高效训练,而无需设计复杂的图采样算法。

假设有一个图分类数据集,如第4章:图数据管道中介绍的那样。

import dgl.data
dataset = dgl.data.GINDataset('MUTAG', False)

图分类数据集中的每一项都是一对图和其标签。通过利用GraphDataLoader来迭代小批量图数据集,可以加快数据加载过程。

from dgl.dataloading import GraphDataLoader
dataloader = GraphDataLoader(
    dataset,
    batch_size=1024,
    drop_last=False,
    shuffle=True)

训练循环简单地涉及遍历数据加载器并更新模型。

import torch.nn.functional as F

# Only an example, 7 is the input feature size
model = Classifier(7, 20, 5)
opt = torch.optim.Adam(model.parameters())
for epoch in range(20):
    for batched_graph, labels in dataloader:
        feats = batched_graph.ndata['attr']
        logits = model(batched_graph, feats)
        loss = F.cross_entropy(logits, labels)
        opt.zero_grad()
        loss.backward()
        opt.step()

关于图分类的端到端示例,请参见 DGL的GIN示例。 训练循环位于 main.py 中的train函数内。 模型实现位于 gin.py, 其中包含更多组件,例如使用 dgl.nn.pytorch.GINConv(也可在MXNet和Tensorflow中使用) 作为图卷积层、批量归一化等。

异构图

使用异构图进行图分类与使用同构图进行图分类有些不同。除了与异构图兼容的图卷积模块外,还需要在读取函数中聚合不同类型的节点。

以下展示了如何对每种节点类型的节点表示求平均值的示例。

class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()

        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        # inputs is features of nodes
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h

class HeteroClassifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes, rel_names):
        super().__init__()

        self.rgcn = RGCN(in_dim, hidden_dim, hidden_dim, rel_names)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        h = g.ndata['feat']
        h = self.rgcn(g, h)
        with g.local_scope():
            g.ndata['h'] = h
            # Calculate graph representation by average readout.
            hg = 0
            for ntype in g.ntypes:
                hg = hg + dgl.mean_nodes(g, 'h', ntype=ntype)
            return self.classify(hg)

其余代码与同构图代码没有区别。

# etypes is the list of edge types as strings.
model = HeteroClassifier(10, 20, 5, etypes)
opt = torch.optim.Adam(model.parameters())
for epoch in range(20):
    for batched_graph, labels in dataloader:
        logits = model(batched_graph)
        loss = F.cross_entropy(logits, labels)
        opt.zero_grad()
        loss.backward()
        opt.step()