6.4 Implementing Custom Graph Samplers

To implement a custom sampler, subclass dgl.graphbolt.SubgraphSampler and implement its abstract sample_subgraphs method. The method takes the seed nodes, that is, the nodes whose neighbors will be sampled:

def sample_subgraphs(self, seed_nodes):
    return input_nodes, sampled_subgraphs

The method should return the list of input node IDs and a list of subgraphs. Each subgraph is a SampledSubgraph object.

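Any other data needed during sampling, such as the graph structure and the fanout sizes, should be passed to the sampler through its constructor.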

The following code implements a classical neighbor sampler:

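import dgl
from torch.utils.data import functional_datapipe


@functional_datapipe("customized_sample_neighbor")
class CustomizedNeighborSampler(dgl.graphbolt.SubgraphSampler):
    def __init__(self, datapipe, graph, fanouts):
        super().__init__(datapipe)
        # Everything the sampler needs besides the seeds is stored here.
        self.graph = graph
        self.fanouts = fanouts

    def sample_subgraphs(self, seed_nodes):
        subgs = []
        # Sample from the output layer back toward the input layer.
        for fanout in reversed(self.fanouts):
            # Sample a fixed number of neighbors of the current seed nodes.
            input_nodes, sg = self.graph.sample_neighbors(seed_nodes, fanout)
            # Prepend so that subgs is ordered from the input layer to the output layer.
            subgs.insert(0, sg)
            # The sampled input nodes become the seeds of the next (outer) layer.
            seed_nodes = input_nodes
        return input_nodes, subgs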
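
To make the loop above concrete without depending on DGL, here is a minimal, framework-free sketch of the same layer-wise procedure on a plain adjacency dictionary. The names toy_sample_neighbors and toy_sample_subgraphs and the in-neighbor dictionary format are hypothetical, chosen only for illustration; they are not part of the GraphBolt API, and this is not how GraphBolt samples internally.

```python
import random


def toy_sample_neighbors(in_neighbors, seed_nodes, fanout, rng):
    """Sample up to `fanout` in-neighbors for each seed node.

    Returns (input_nodes, edges): input_nodes lists the seeds first, followed
    by newly discovered neighbors; edges is a list of (src, dst) pairs.
    """
    input_nodes = list(seed_nodes)
    seen = set(seed_nodes)
    edges = []
    for dst in seed_nodes:
        candidates = in_neighbors.get(dst, [])
        picked = rng.sample(candidates, min(fanout, len(candidates)))
        for src in picked:
            edges.append((src, dst))
            if src not in seen:
                seen.add(src)
                input_nodes.append(src)
    return input_nodes, edges


def toy_sample_subgraphs(in_neighbors, seed_nodes, fanouts, seed=0):
    """Layer-wise sampling: start at the seeds and expand outward."""
    rng = random.Random(seed)
    subgraphs = []
    for fanout in reversed(fanouts):
        input_nodes, edges = toy_sample_neighbors(
            in_neighbors, seed_nodes, fanout, rng
        )
        # Prepend, so the list runs from the outermost layer to the seed layer.
        subgraphs.insert(0, edges)
        # Nodes reached in this layer are the seeds of the next one.
        seed_nodes = input_nodes
    return input_nodes, subgraphs


# Hypothetical toy graph: node ID -> list of in-neighbor IDs.
in_neighbors = {0: [1, 2, 3], 1: [4], 2: [4, 5], 3: [], 4: [], 5: [0]}
input_nodes, subgraphs = toy_sample_subgraphs(in_neighbors, [0], fanouts=[2, 2])
```

The last entry of subgraphs holds the edges that point into the original seeds, while the first entry holds the outermost layer. This matches the ordering produced by subgs.insert(0, sg) in the sampler above.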

To use this sampler with DataLoader:

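import dgl.graphbolt as gb

# g, feature, train_set, and device are assumed to be defined elsewhere.
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)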
datapipe = datapipe.customized_sample_neighbor(g, [10, 10]) # 2 layers.
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)

for data in dataloader:
    input_features = data.node_features["feat"]
    output_labels = data.labels
    output_predictions = model(data.blocks, input_features)
    loss = compute_loss(output_labels, output_predictions)
    opt.zero_grad()
    loss.backward()
    opt.step()

Sampler for Heterogeneous Graphs

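To write a sampler for heterogeneous graphs, keep in mind that the graph argument is a heterogeneous graph and that the seeds may be a dictionary of ID tensors. Most of DGL's graph sampling operators (for example, sample_neighbors in the example above, and to_block) work natively on heterogeneous graphs, so many samplers apply to heterogeneous graphs automatically. For example, the CustomizedNeighborSampler above can be used on a heterogeneous graph: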

import torch
import dgl.graphbolt as gb

# Placeholder from the original example: build or load the heterogeneous graph here.
hg = gb.FusedCSCSamplingGraph()
train_set = gb.ItemSetDict(
    {
        "user": gb.ItemSet(
            (torch.arange(0, 5), torch.arange(5, 10)),
            names=("seeds", "labels"),
        ),
        "item": gb.ItemSet(
            (torch.arange(5, 10), torch.arange(10, 15)),
            names=("seeds", "labels"),
        ),
    }
)
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.customized_sample_neighbor(hg, [10, 10])  # 2 layers.
datapipe = datapipe.fetch_feature(
    feature, node_feature_keys={"user": ["feat"], "item": ["feat"]}
)
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)

for data in dataloader:
    input_features = {
        ntype: data.node_features[(ntype, "feat")]
        for ntype in data.blocks[0].srctypes
    }
    output_labels = data.labels["user"]
    output_predictions = model(data.blocks, input_features)["user"]
    loss = compute_loss(output_labels, output_predictions)
    opt.zero_grad()
    loss.backward()
    opt.step()
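
As an illustration of what seeds given as a dictionary means in practice, the following is a small, framework-free sketch of one sampling step on a heterogeneous graph. The function toy_hetero_sample_neighbors, the edge-type triples, and the nested dictionary layout are all hypothetical and exist only for this example; GraphBolt's own operators work on its graph and tensor types instead.

```python
import random


def toy_hetero_sample_neighbors(in_neighbors, seeds, fanout, rng):
    """Sample up to `fanout` in-neighbors per seed node and per edge type.

    in_neighbors maps (src_type, relation, dst_type) -> {dst_id: [src_id, ...]}.
    seeds maps node type -> list of node IDs.
    Returns (input_nodes, edges), both keyed by type.
    """
    input_nodes = {ntype: list(ids) for ntype, ids in seeds.items()}
    edges = {}
    for etype, adjacency in in_neighbors.items():
        src_type, _, dst_type = etype
        # Only edge types that point into a seeded node type contribute.
        for dst in seeds.get(dst_type, []):
            candidates = adjacency.get(dst, [])
            for src in rng.sample(candidates, min(fanout, len(candidates))):
                edges.setdefault(etype, []).append((src, dst))
                bucket = input_nodes.setdefault(src_type, [])
                if src not in bucket:
                    bucket.append(src)
    return input_nodes, edges


# Hypothetical toy graph with two node types and two edge types.
in_neighbors = {
    ("user", "buys", "item"): {0: [0, 1], 1: [1, 2]},
    ("item", "bought_by", "user"): {0: [0], 1: [0, 1], 2: [1]},
}
rng = random.Random(0)
input_nodes, edges = toy_hetero_sample_neighbors(
    in_neighbors, {"item": [0, 1]}, fanout=1, rng=rng
)
```

Because only "item" nodes are seeded here, only the ("user", "buys", "item") edge type produces sampled edges, and the returned input nodes gain a "user" entry alongside the original "item" seeds.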

Excluding Edges After Sampling

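In some cases we may want to exclude the seed edges from the sampled subgraph. For example, in link prediction tasks we want to exclude the edges in the training set from the sampled subgraph to prevent information leakage. To do so, add one more datapipe stage right after sampling, as follows: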

datapipe = datapipe.customized_sample_neighbor(g, [10, 10]) # 2 layers.
datapipe = datapipe.transform(gb.exclude_seed_edges)

See the API page of exclude_seed_edges() for more details.

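The API above is built on exclude_edges(). If you want to exclude edges from the sampled subgraph by some other criterion, you can write your own transform function. Please refer to that method when doing so.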
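
The idea behind edge exclusion can be sketched on plain Python edge lists. The helper toy_exclude_edges and its include_reverse flag are hypothetical names used only for this illustration; a real transform function would operate on the sampled subgraphs of a mini-batch, and this sketch does not reproduce the signature of GraphBolt's exclude_edges().

```python
def toy_exclude_edges(sampled_edges, edges_to_exclude, include_reverse=False):
    """Return sampled_edges without any (src, dst) pair in edges_to_exclude.

    With include_reverse=True, the reversed pair (dst, src) is dropped as well,
    which is a common choice when the graph stores both directions of an edge.
    """
    banned = set(edges_to_exclude)
    if include_reverse:
        banned |= {(dst, src) for src, dst in edges_to_exclude}
    return [edge for edge in sampled_edges if edge not in banned]


# Hypothetical sampled edges and one seed edge to hide from the model.
sampled = [(1, 0), (2, 0), (0, 1), (3, 2)]
seed_edges = [(1, 0)]
kept = toy_exclude_edges(sampled, seed_edges)
kept_both_directions = toy_exclude_edges(sampled, seed_edges, include_reverse=True)
```

Any other criterion, such as dropping edges by type or by timestamp, would follow the same pattern: build the set of edges to remove, then filter the sampled edges against it.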

You can also refer to the example in Link Prediction.