SubgraphSampler

class dgl.graphbolt.SubgraphSampler(datapipe, *args, **kwargs)[source]

Bases: MiniBatchTransformer

A subgraph sampler that samples subgraphs from a larger graph, starting from a given set of nodes.

Functional name: sample_subgraph.
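A functional name means that, once the class is registered, it can be constructed by calling `sample_subgraph(...)` as a method on an upstream datapipe instead of calling the class directly. The pure-Python sketch below models that mechanism with a hypothetical `ToyPipe` base class and `register_functional` decorator; it is an illustration under those assumptions, not the torchdata or GraphBolt implementation.

```python
# Hypothetical registry modelling how a functional name turns a datapipe
# class into a chainable method. Not the real torchdata code.
class ToyPipe:
    _functional = {}  # functional name -> datapipe class

    def __init__(self, source):
        self.source = source

    def __iter__(self):
        return iter(self.source)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails.
        cls = type(self)._functional.get(name)
        if cls is None:
            raise AttributeError(name)
        # pipe.<name>(...) constructs cls(pipe, ...).
        return lambda *args, **kwargs: cls(self, *args, **kwargs)


def register_functional(name):
    """Hypothetical decorator: register `cls` under a functional name."""
    def decorator(cls):
        ToyPipe._functional[name] = cls
        return cls
    return decorator


@register_functional("sample_subgraph")
class ToySampler(ToyPipe):
    def __init__(self, datapipe, fanout):
        super().__init__(datapipe)
        self.fanout = fanout

    def __iter__(self):
        for seeds in self.source:
            # Stand-in for sampling: keep the first `fanout` seeds.
            yield seeds[: self.fanout]


pipe = ToyPipe([[1, 2, 3], [4, 5, 6]]).sample_subgraph(fanout=2)
print(list(pipe))  # [[1, 2], [4, 5]]
```

Because each registered stage is itself a `ToyPipe`, calls can be chained, which is the same style in which GraphBolt stages are typically composed.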

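This class is the base class of all subgraph samplers. Any subclass of SubgraphSampler should implement either the sample_subgraphs() method or the sampling_stages() method. Overriding sampling_stages() defines fine-grained sampling stages, which lets the sampler take advantage of the optimizations provided by the GraphBolt DataLoader.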

Parameters:
  • datapipe (DataPipe) – The datapipe.

  • args (Non-keyword arguments) – Arguments passed to sampling_stages.

  • kwargs (Keyword arguments) – Arguments passed to sampling_stages.

sample_subgraphs(seeds, seeds_timestamp)[source]

Sample subgraphs from the given seeds, possibly subject to temporal constraints.

Any subclass of SubgraphSampler should implement this method.

Parameters:
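  • seeds (Union[torch.Tensor, Dict[str, torch.Tensor]]) – The seed nodes.

  • seeds_timestamp (Union[torch.Tensor, Dict[str, torch.Tensor]]) – Timestamps of the seed nodes. If provided, the sampled subgraphs must not contain any nodes or edges that are newer than the timestamps of the seed nodes. Default: None.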

Returns:

  • Union[torch.Tensor, Dict[str, torch.Tensor]] – The input nodes.

  • List[SampledSubgraph] – The sampled subgraphs.
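The contract above, including the temporal constraint, can be illustrated without GraphBolt. The sketch below is a hypothetical stand-alone function, not the library method: plain lists stand in for tensors, a list of `(src, dst)` edge tuples stands in for each SampledSubgraph, and the graph is a hypothetical dict mapping each node to `(neighbor, edge_timestamp)` pairs. When `seeds_timestamp` is given, edges newer than the seed's timestamp are skipped; as a simplification only edge timestamps are checked, and nodes keep the limit of the seed that first reached them. A deterministic "first `fanout` neighbors" rule replaces random sampling so the output is reproducible.

```python
def sample_subgraphs(graph, seeds, fanouts, seeds_timestamp=None):
    """Toy multi-hop sampler with an optional temporal constraint.

    graph: dict mapping a destination node to a list of
           (source_node, edge_timestamp) pairs (hypothetical layout).
    Returns (input_nodes, subgraphs); subgraphs[0] is the hop farthest
    from the seeds.
    """
    # Time limit carried by each frontier node; None disables the constraint.
    if seeds_timestamp is None:
        frontier = {s: None for s in seeds}
    else:
        frontier = dict(zip(seeds, seeds_timestamp))
    all_nodes = set(frontier)
    subgraphs = []
    # The last fanout applies to the seeds, so walk the list in reverse.
    for fanout in reversed(fanouts):
        edges = []
        next_frontier = {}
        for dst, limit in frontier.items():
            # Skip edges newer than the seed timestamp this node inherited.
            eligible = [src for src, t in graph.get(dst, [])
                        if limit is None or t <= limit]
            # Deterministic stand-in for random sampling: keep the first `fanout`.
            for src in eligible[:fanout]:
                edges.append((src, dst))
                # A node reached several times keeps the first limit it got.
                next_frontier.setdefault(src, limit)
        subgraphs.insert(0, edges)
        all_nodes.update(next_frontier)
        frontier = next_frontier
    return sorted(all_nodes), subgraphs


graph = {
    0: [(1, 5), (2, 10)],
    1: [(3, 1), (4, 20)],
    2: [(4, 2)],
}
print(sample_subgraphs(graph, [0], [2, 2]))
# ([0, 1, 2, 3, 4], [[(3, 1), (4, 1), (4, 2)], [(1, 0), (2, 0)]])
print(sample_subgraphs(graph, [0], [2, 2], seeds_timestamp=[8]))
# ([0, 1, 3], [[(3, 1)], [(1, 0)]])
```

With `seeds_timestamp=[8]`, the edges stamped 10 and 20 are excluded, so nodes 2 and 4 never enter the sampled subgraphs.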

Examples

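>>> import torch
>>> from torch.utils.data import functional_datapipe
>>> from dgl.graphbolt import SubgraphSampler
>>> @functional_datapipe("my_sample_subgraph")
... class MySubgraphSampler(SubgraphSampler):
...     def __init__(self, datapipe, graph, fanouts):
...         super().__init__(datapipe)
...         self.graph = graph
...         self.fanouts = fanouts
...     def sample_subgraphs(self, seeds, seeds_timestamp=None):
...         # Sample one subgraph per layer, starting from the output layer.
...         # seeds_timestamp is ignored: no temporal constraint is applied.
...         subgraphs = []
...         subgraphs_nodes = []
...         for fanout in reversed(self.fanouts):
...             subgraph = self.graph.sample_neighbors(seeds, fanout)
...             # Prepend so that subgraphs[0] is the layer farthest from the seeds.
...             subgraphs.insert(0, subgraph)
...             # Assumes the sampled subgraph exposes its reached nodes as `.nodes`.
...             subgraphs_nodes.append(subgraph.nodes)
...             # The nodes reached in this hop become the seeds of the next hop.
...             seeds = subgraph.nodes
...         # The input nodes are the unique union of all sampled nodes.
...         subgraphs_nodes = torch.unique(torch.cat(subgraphs_nodes))
...         return subgraphs_nodes, subgraphs
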
sampling_stages(datapipe)[source]

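The sampling stages are defined by chaining them onto the datapipe. The default implementation expects sample_subgraphs() to be implemented. To define fine-grained stages, override this method.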
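The difference between the two extension points can be modelled in plain Python. In the sketch below, `MapStage`, `ToySubgraphSampler`, `OneHopSampler` and `TwoStageSampler` are hypothetical names, not GraphBolt classes: the constructor forwards extra positional and keyword arguments to sampling_stages(), the default sampling_stages() chains a single stage that calls sample_subgraphs(), and one subclass overrides sampling_stages() to chain two finer-grained stages instead. Real GraphBolt stages are datapipes; this sketch only models how stages are chained.

```python
class MapStage:
    """Re-iterable toy stage that applies `fn` to each item of `source`."""

    def __init__(self, source, fn):
        self.source = source
        self.fn = fn

    def __iter__(self):
        for item in self.source:
            yield self.fn(item)


class ToySubgraphSampler:
    """Toy model of the documented contract; not the GraphBolt class."""

    def __init__(self, datapipe, *args, **kwargs):
        # Extra arguments are forwarded to sampling_stages().
        self.datapipe = self.sampling_stages(datapipe, *args, **kwargs)

    def __iter__(self):
        return iter(self.datapipe)

    def sampling_stages(self, datapipe):
        # Default: a single stage that delegates to sample_subgraphs().
        return MapStage(datapipe, self.sample_subgraphs)

    def sample_subgraphs(self, seeds, seeds_timestamp=None):
        raise NotImplementedError


class OneHopSampler(ToySubgraphSampler):
    """Default path: only sample_subgraphs() is implemented."""

    def __init__(self, datapipe, graph):
        self.graph = graph  # store state before the stages are built
        super().__init__(datapipe)

    def sample_subgraphs(self, seeds, seeds_timestamp=None):
        edges = [(nbr, s) for s in seeds for nbr in self.graph.get(s, [])]
        input_nodes = sorted(set(seeds) | {src for src, _ in edges})
        return input_nodes, [edges]


class TwoStageSampler(ToySubgraphSampler):
    """Fine-grained path: sampling_stages() is overridden."""

    def __init__(self, datapipe, graph, fanout):
        self.graph = graph
        # `fanout` reaches sampling_stages() through **kwargs.
        super().__init__(datapipe, fanout=fanout)

    def sampling_stages(self, datapipe, fanout):
        # Stage 1: keep up to `fanout` neighbors per seed (first-k, no randomness).
        datapipe = MapStage(
            datapipe,
            lambda seeds: (seeds, [(nbr, s) for s in seeds
                                   for nbr in self.graph.get(s, [])[:fanout]]),
        )
        # Stage 2: derive the input nodes from the sampled edges.
        return MapStage(datapipe, self._assemble)

    @staticmethod
    def _assemble(item):
        seeds, edges = item
        input_nodes = sorted(set(seeds) | {src for src, _ in edges})
        return input_nodes, [edges]


graph = {0: [1, 2], 1: [2, 3]}
print(list(OneHopSampler([[0], [1]], graph)))
# [([0, 1, 2], [[(1, 0), (2, 0)]]), ([1, 2, 3], [[(2, 1), (3, 1)]])]
print(list(TwoStageSampler([[0], [1]], graph, fanout=1)))
# [([0, 1], [[(1, 0)]]), ([1, 2], [[(2, 1)]])]
```

Splitting the work into separate stages is what gives a pipeline room to schedule or overlap them independently; a single sample_subgraphs() call is simpler but opaque to such optimizations.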