子图采样器
- class dgl.graphbolt.SubgraphSampler(datapipe, *args, **kwargs)[source]
-
用于从较大图中的给定节点集中采样子图的子图采样器。
功能名称:
sample_subgraph
.这个类是所有子图采样器的基类。任何SubgraphSampler的子类都应该实现
sample_subgraphs()
方法 或者sampling_stages()
方法来定义细粒度的采样阶段,以利用GraphBolt DataLoader提供的优化。- Parameters:
datapipe (DataPipe) – The datapipe.
args (非关键字参数) – 传递给 sampling_stages 的参数。
kwargs (关键字参数) – 传递给 sampling_stages 的参数。
- sample_subgraphs(seeds, seeds_timestamp)[source]
从给定的种子中采样子图,可能带有时间约束。
SubgraphSampler的任何子类都应实现此方法。
- Parameters:
- Returns:
Union[torch.Tensor, Dict[str, torch.Tensor]] – 输入节点。
List[SampledSubgraph] – 采样的子图。
示例
>>> @functional_datapipe("my_sample_subgraph") >>> class MySubgraphSampler(SubgraphSampler): >>> def __init__(self, datapipe, graph, fanouts): >>> super().__init__(datapipe) >>> self.graph = graph >>> self.fanouts = fanouts >>> def sample_subgraphs(self, seeds): >>> # Sample subgraphs from the given seeds. >>> subgraphs = [] >>> subgraphs_nodes = [] >>> for fanout in reversed(self.fanouts): >>> subgraph = self.graph.sample_neighbors(seeds, fanout) >>> subgraphs.insert(0, subgraph) >>> subgraphs_nodes.append(subgraph.nodes) >>> seeds = subgraph.nodes >>> subgraphs_nodes = torch.unique(torch.cat(subgraphs_nodes)) >>> return subgraphs_nodes, subgraphs
- sampling_stages(datapipe)[source]
采样阶段通过链接到数据管道来定义。默认实现期望
sample_subgraphs()
被实现。要定义细粒度的阶段,应重写此方法。