6.2 使用邻域采样训练GNN进行边分类
边缘分类/回归的训练与节点分类/回归的训练有些相似,但有几个显著的区别。
Define a neighborhood sampler and data loader
你可以使用 与节点分类相同的邻居采样器。
datapipe = datapipe.sample_neighbor(g, [10, 10])
# Or equivalently
datapipe = dgl.graphbolt.NeighborSampler(datapipe, g, [10, 10])
定义数据加载器的代码与节点分类的代码相同。唯一的区别是它在训练集中迭代的是边(即节点对)而不是节点。
import dgl.graphbolt as gb
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
g = gb.SamplingGraph()
seeds = torch.arange(0, 1000).reshape(-1, 2)
labels = torch.randint(0, 2, (5,))
train_set = gb.ItemSet((seeds, labels), names=("seeds", "labels"))
datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
# Or equivalently:
# datapipe = gb.NeighborSampler(datapipe, g, [10, 10])
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)
遍历DataLoader将产生MiniBatch,其中包含一系列特别创建的图,这些图表示每层上的计算依赖关系。您可以通过mini_batch.blocks访问消息流图(MFGs)。
注意
请参阅:doc:`随机训练教程 <../notebooks/stochastic_training/neighbor_sampling_overview.nblink>`__ 了解消息流图的概念。
If you wish to develop your own neighborhood sampler or you want a more detailed explanation of the concept of MFGs, please refer to 6.4 Implementing Custom Graph Samplers.
从原始图中移除小批量中的边以进行邻居采样
在训练边分类模型时,有时你希望从计算依赖中移除训练数据中出现的边,就好像它们从未存在过一样。否则,模型将“知道”两个节点之间存在边的事实,并可能利用这一点来获得优势。
因此,在边分类中,有时你可能希望从采样的minibatch中排除种子边及其反向边。你可以使用exclude_seed_edges()与MiniBatchTransformer一起实现这一点。
import dgl.graphbolt as gb
from functools import partial
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
g = gb.SamplingGraph()
seeds = torch.arange(0, 1000).reshape(-1, 2)
labels = torch.randint(0, 2, (5,))
train_set = gb.ItemSet((seeds, labels), names=("seeds", "labels"))
datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
exclude_seed_edges = partial(gb.exclude_seed_edges, include_reverse_edges=True)
datapipe = datapipe.transform(exclude_seed_edges)
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)
Adapt your model for minibatch training
边缘分类模型通常由两部分组成:
获取事件节点表示的一部分。
计算来自入射节点表示的边缘得分的另一部分。
前一部分与 节点分类中的部分 完全相同,我们可以直接重用。输入仍然是DGL提供的数据加载器生成的MFG列表,以及输入特征。
class StochasticTwoLayerGCN(nn.Module):
def __init__(self, in_features, hidden_features, out_features):
super().__init__()
self.conv1 = dglnn.GraphConv(in_features, hidden_features)
self.conv2 = dglnn.GraphConv(hidden_features, out_features)
def forward(self, blocks, x):
x = F.relu(self.conv1(blocks[0], x))
x = F.relu(self.conv2(blocks[1], x))
return x
后一部分的输入通常是前一部分的输出,以及由小批量中的边引起的原始图的子图(节点对)。该子图由相同的数据加载器生成。
以下代码展示了通过连接相邻节点特征并使用密集层进行投影来预测边上分数的示例。
class ScorePredictor(nn.Module):
def __init__(self, num_classes, in_features):
super().__init__()
self.W = nn.Linear(2 * in_features, num_classes)
def forward(self, seeds, x):
src_x = x[seeds[:, 0]]
dst_x = x[seeds[:, 1]]
data = torch.cat([src_x, dst_x], 1)
return self.W(data)
整个模型将采用由数据加载器生成的MFGs列表和边,以及如下所示的输入节点特征:
class Model(nn.Module):
def __init__(self, in_features, hidden_features, out_features, num_classes):
super().__init__()
self.gcn = StochasticTwoLayerGCN(
in_features, hidden_features, out_features)
self.predictor = ScorePredictor(num_classes, out_features)
def forward(self, blocks, x, seeds):
x = self.gcn(blocks, x)
return self.predictor(seeds, x)
DGL 确保边子图中的节点与生成的 MFG 列表中最后一个 MFG 的输出节点相同。
Training Loop
训练循环与节点分类非常相似。您可以遍历数据加载器,并获取由小批量中的边引起的子图,以及计算其相关节点表示所需的MFG列表。
import torch.nn.functional as F
model = Model(in_features, hidden_features, out_features, num_classes)
model = model.to(device)
opt = torch.optim.Adam(model.parameters())
for data in dataloader:
blocks = data.blocks
x = data.edge_features("feat")
y_hat = model(data.blocks, x, data.compacted_seeds)
loss = F.cross_entropy(data.labels, y_hat)
opt.zero_grad()
loss.backward()
opt.step()
For heterogeneous graphs
计算异质图上节点表示的模型也可以用于计算边分类/回归的事件节点表示。
class StochasticTwoLayerRGCN(nn.Module):
def __init__(self, in_feat, hidden_feat, out_feat, rel_names):
super().__init__()
self.conv1 = dglnn.HeteroGraphConv({
rel : dglnn.GraphConv(in_feat, hidden_feat, norm='right')
for rel in rel_names
})
self.conv2 = dglnn.HeteroGraphConv({
rel : dglnn.GraphConv(hidden_feat, out_feat, norm='right')
for rel in rel_names
})
def forward(self, blocks, x):
x = self.conv1(blocks[0], x)
x = self.conv2(blocks[1], x)
return x
对于分数预测,同构图和异构图之间的唯一实现区别在于我们正在遍历边类型。
class ScorePredictor(nn.Module):
def __init__(self, num_classes, in_features):
super().__init__()
self.W = nn.Linear(2 * in_features, num_classes)
def forward(self, seeds, x):
scores = {}
for etype in seeds.keys():
src, dst = seeds[etype].T
data = torch.cat([x[etype][src], x[etype][dst]], 1)
scores[etype] = self.W(data)
return scores
class Model(nn.Module):
def __init__(self, in_features, hidden_features, out_features, num_classes,
etypes):
super().__init__()
self.rgcn = StochasticTwoLayerRGCN(
in_features, hidden_features, out_features, etypes)
self.pred = ScorePredictor(num_classes, out_features)
def forward(self, seeds, blocks, x):
x = self.rgcn(blocks, x)
return self.pred(seeds, x)
数据加载器的定义与同构图几乎相同。唯一的区别是,现在的train_set是ItemSetDict的实例,而不是ItemSet。
import dgl.graphbolt as gb
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
g = gb.SamplingGraph()
seeds = torch.arange(0, 1000).reshape(-1, 2)
labels = torch.randint(0, 3, (1000,))
seeds_labels = {
"user:like:item": gb.ItemSet(
(seeds, labels), names=("seeds", "labels")
),
"user:follow:user": gb.ItemSet(
(seeds, labels), names=("seeds", "labels")
),
}
train_set = gb.ItemSetDict(seeds_labels)
datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
datapipe = datapipe.fetch_feature(
feature, node_feature_keys={"item": ["feat"], "user": ["feat"]}
)
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)
如果你想在异质图上排除反向边,情况会有所不同。在异质图上,反向边通常具有与边本身不同的边类型,以便区分“正向”和“反向”关系(例如,follow 和 followed_by 是彼此的反向关系,like 和 liked_by 是彼此的反向关系,等等)。
如果一种类型中的每条边在另一种类型中都有一个具有相同ID的反向边,您可以指定边类型及其反向类型之间的映射。然后,排除小批量中的边及其反向边的方法如下。
exclude_seed_edges = partial(
gb.exclude_seed_edges,
include_reverse_edges=True,
reverse_etypes_mapping={
"user:like:item": "item:liked_by:user",
"user:follow:user": "user:followed_by:user",
},
)
datapipe = datapipe.transform(exclude_seed_edges)
训练循环再次与同构图上的几乎相同,除了compute_loss的实现,这里将接受两个节点类型和预测的字典。
import torch.nn.functional as F
model = Model(in_features, hidden_features, out_features, num_classes, etypes)
model = model.to(device)
opt = torch.optim.Adam(model.parameters())
for data in dataloader:
blocks = data.blocks
x = data.edge_features(("user:like:item", "feat"))
y_hat = model(data.blocks, x, data.compacted_seeds)
loss = F.cross_entropy(data.labels, y_hat)
opt.zero_grad()
loss.backward()
opt.step()