6.4 Implementing Custom Graph Samplers
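To implement a custom sampler, subclass the dgl.graphbolt.SubgraphSampler base class and implement its abstract sample_subgraphs method. The sample_subgraphs method should take in the seed nodes, which are the nodes whose neighbors will be sampled: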
def sample_subgraphs(self, seed_nodes):
    # Sample around seed_nodes here, then return both results.
    return input_nodes, sampled_subgraphs
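The method should return the list of input node IDs and a list of subgraphs. Each subgraph is a SampledSubgraph object.
Any other data required during sampling, such as the graph structure or the fanout sizes, should be passed to the sampler through its constructor.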
The code below implements a classical neighbor sampler:
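import dgl
from torch.utils.data import functional_datapipe

@functional_datapipe("customized_sample_neighbor")
class CustomizedNeighborSampler(dgl.graphbolt.SubgraphSampler):
    def __init__(self, datapipe, graph, fanouts):
        super().__init__(datapipe)
        # Everything needed during sampling is stored on the sampler.
        self.graph = graph
        self.fanouts = fanouts

    def sample_subgraphs(self, seed_nodes):
        subgs = []
        # Sample from the output layer back toward the input layer.
        for fanout in reversed(self.fanouts):
            # Sample a fixed number of neighbors of the current seed nodes.
            input_nodes, sg = self.graph.sample_neighbors(seed_nodes, fanout)
            # Prepend so that the list runs from input layer to output layer.
            subgs.insert(0, sg)
            # The nodes just reached become the seeds of the next hop.
            seed_nodes = input_nodes
        return input_nodes, subgs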
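The loop above can be traced without DGL on a toy graph. The following is a minimal sketch in plain Python, not the library's implementation: the adjacency dictionary, the `sample_neighbors` helper, and the `(src, dst)` edge lists are all hypothetical stand-ins for the graph object and SampledSubgraph. It shows why the fanouts are visited in reverse and why each new subgraph is inserted at the front of the list.

```python
import random


def sample_neighbors(adj, seeds, fanout, rng):
    """Toy stand-in for a graph's neighbor sampling op (hypothetical helper).

    adj maps each node to the list of its in-neighbors. Returns the nodes
    needed as input for this layer and the sampled (src, dst) edges.
    """
    edges = []
    for dst in seeds:
        neighbors = adj.get(dst, [])
        picked = rng.sample(neighbors, min(fanout, len(neighbors)))
        edges.extend((src, dst) for src in picked)
    # Keep the seeds first, then append newly reached sources, no duplicates.
    input_nodes = list(dict.fromkeys(list(seeds) + [src for src, _ in edges]))
    return input_nodes, edges


def sample_subgraphs(adj, seed_nodes, fanouts, rng):
    subgs = []
    for fanout in reversed(fanouts):
        input_nodes, sg = sample_neighbors(adj, seed_nodes, fanout, rng)
        subgs.insert(0, sg)       # the layer sampled last ends up first
        seed_nodes = input_nodes  # the next hop expands from these nodes
    return input_nodes, subgs


# Toy graph: node -> in-neighbors.
adj = {0: [1, 2], 1: [2, 3], 2: [3], 3: [0]}
input_nodes, subgs = sample_subgraphs(adj, [0], [2, 1], random.Random(0))
```

Here `subgs[-1]` holds only edges that point at the original seed, while `subgs[0]` is the layer farthest from it, so the list can be consumed in order by a multi-layer model.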
To use this sampler with DataLoader:
import dgl.graphbolt as gb

# train_set, g, feature, device, model, compute_loss and opt are assumed
# to be defined elsewhere.
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.customized_sample_neighbor(g, [10, 10])  # 2 layers.
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)

for data in dataloader:
    input_features = data.node_features["feat"]
    output_labels = data.labels
    output_predictions = model(data.blocks, input_features)
    loss = compute_loss(output_labels, output_predictions)
    opt.zero_grad()
    loss.backward()
    opt.step()
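The pipeline above calls `datapipe.customized_sample_neighbor(...)`, a method name that comes from the string passed to `@functional_datapipe`. The sketch below is a toy illustration of that general idea (registering a class under a method name so that stages can be chained); it is not the PyTorch implementation, and `ToyDataPipe`, `toy_functional_datapipe`, `Doubler`, and `Batcher` are hypothetical names invented for this example.

```python
class ToyDataPipe:
    """Minimal iterable pipeline stage (hypothetical, not the torch class)."""

    _registry = {}

    def __init__(self, source):
        self.source = source

    def __iter__(self):
        return iter(self.source)

    def __getattr__(self, name):
        # Reached only when normal attribute lookup fails.
        registry = ToyDataPipe._registry
        if name in registry:
            cls = registry[name]
            # Build the next stage with this stage as its source.
            return lambda *args, **kwargs: cls(self, *args, **kwargs)
        raise AttributeError(name)


def toy_functional_datapipe(name):
    """Register the decorated class under a chainable method name."""
    def decorator(cls):
        ToyDataPipe._registry[name] = cls
        return cls
    return decorator


@toy_functional_datapipe("double")
class Doubler(ToyDataPipe):
    def __iter__(self):
        for item in self.source:
            yield item * 2


@toy_functional_datapipe("batch")
class Batcher(ToyDataPipe):
    def __init__(self, source, batch_size):
        super().__init__(source)
        self.batch_size = batch_size

    def __iter__(self):
        batch = []
        for item in self.source:
            batch.append(item)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if batch:  # emit the final, possibly smaller, batch
            yield batch


pipe = ToyDataPipe(range(5)).double().batch(2)
batches = list(pipe)
```

Each chained call wraps the previous stage, which mirrors how the sampler in this section receives the upstream `datapipe` as its first constructor argument.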
Samplers for Heterogeneous Graphs
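When writing a sampler for heterogeneous graphs, keep in mind that the graph argument is a heterogeneous graph and that the seeds may be a dictionary of ID tensors keyed by node type. Most of DGL's graph sampling operators (for example, the sample_neighbors function used in the example above, as well as to_block) work on heterogeneous graphs natively, so many samplers apply to heterogeneous graphs without changes. For example, the CustomizedNeighborSampler above can be used on a heterogeneous graph: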
import torch
import dgl.graphbolt as gb

# Placeholder from the original example: in practice, load or build the
# heterogeneous sampling graph (constructor arguments are omitted here).
hg = gb.FusedCSCSamplingGraph()
train_set = gb.ItemSetDict(
    {
        "user": gb.ItemSet(
            (torch.arange(0, 5), torch.arange(5, 10)),
            names=("seeds", "labels"),
        ),
        "item": gb.ItemSet(
            (torch.arange(5, 10), torch.arange(10, 15)),
            names=("seeds", "labels"),
        ),
    }
)
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.customized_sample_neighbor(hg, [10, 10])  # 2 layers.
datapipe = datapipe.fetch_feature(
    feature, node_feature_keys={"user": ["feat"], "item": ["feat"]}
)
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)

for data in dataloader:
    # Collect the input features of every source node type.
    input_features = {
        ntype: data.node_features[(ntype, "feat")]
        for ntype in data.blocks[0].srctypes
    }
    output_labels = data.labels["user"]
    output_predictions = model(data.blocks, input_features)["user"]
    loss = compute_loss(output_labels, output_predictions)
    opt.zero_grad()
    loss.backward()
    opt.step()
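To make the "dictionary of seeds" idea concrete, here is a minimal sketch in plain Python of one sampling hop on a toy heterogeneous graph. It is not DGL code: the `hetero_adj` layout, the edge-type tuples, and the `sample_neighbors_hetero` helper are assumptions made for illustration only. The point is that both the seeds and the returned input nodes are dictionaries keyed by node type, and edges are grouped by edge type.

```python
import random

# Hypothetical toy heterogeneous graph: for each edge type
# (src_type, relation, dst_type), map a destination ID to its source IDs.
hetero_adj = {
    ("user", "follows", "user"): {0: [1, 2], 1: [2]},
    ("item", "bought_by", "user"): {0: [0, 1], 1: [1]},
    ("user", "buys", "item"): {0: [0], 1: [0, 1]},
}


def sample_neighbors_hetero(adj, seeds, fanout, rng):
    """Sample up to `fanout` in-neighbors per seed and per edge type.

    seeds is a dict {node_type: [ids]}; input_nodes uses the same layout.
    """
    edges = {}
    # Start from a copy of the seeds so the caller's dict is not modified.
    input_nodes = {ntype: list(ids) for ntype, ids in seeds.items()}
    for etype, neighbors_of in adj.items():
        src_type, _, dst_type = etype
        for dst in seeds.get(dst_type, []):
            candidates = neighbors_of.get(dst, [])
            picked = rng.sample(candidates, min(fanout, len(candidates)))
            for src in picked:
                edges.setdefault(etype, []).append((src, dst))
                bucket = input_nodes.setdefault(src_type, [])
                if src not in bucket:
                    bucket.append(src)
    return input_nodes, edges


seeds = {"user": [0]}
input_nodes, edges = sample_neighbors_hetero(
    hetero_adj, seeds, 2, random.Random(0)
)
```

Starting from a single "user" seed, the hop reaches both "user" and "item" nodes, which is why the training loop above gathers input features per node type.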
Excluding Edges After Sampling
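In some cases, we may want to exclude the seed edges from the sampled subgraph. For example, in link prediction tasks, we want to exclude the training edges from the sampled subgraph to prevent information leakage. To do so, add an extra datapipe stage right after sampling, as follows: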
datapipe = datapipe.customized_sample_neighbor(g, [10, 10]) # 2 layers.
datapipe = datapipe.transform(gb.exclude_seed_edges)
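See the API page of exclude_seed_edges() for more details. That API is built on top of exclude_edges().
If you want to exclude edges from the sampled subgraph based on other criteria, you can write your own transform function, using exclude_edges() as a reference.
You can also refer to the examples in Link Prediction.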
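As a rough illustration of what a custom exclusion step does, the sketch below filters a list of sampled `(src, dst)` edges in plain Python. It is an assumption-laden toy, not the DGL API: `drop_edges` and its `include_reverse` flag are hypothetical names, and a real transform passed to `datapipe.transform` would operate on the minibatch and its SampledSubgraph objects rather than on Python lists.

```python
def drop_edges(sampled_edges, edges_to_exclude, include_reverse=True):
    """Drop every sampled (src, dst) edge that appears in edges_to_exclude.

    Hypothetical helper: with include_reverse=True, the reversed copy of
    each excluded edge is dropped as well.
    """
    banned = set(edges_to_exclude)
    if include_reverse:
        banned |= {(dst, src) for src, dst in edges_to_exclude}
    return [edge for edge in sampled_edges if edge not in banned]


sampled = [(1, 0), (2, 0), (0, 1), (3, 2)]
seed_edges = [(1, 0)]

kept = drop_edges(sampled, seed_edges)
kept_forward_only = drop_edges(sampled, seed_edges, include_reverse=False)
```

Dropping the reversed edge as well matters for graphs that store each undirected edge in both directions, since the reverse copy would otherwise leak the same link.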