InSubgraphSampler

class dgl.graphbolt.InSubgraphSampler(datapipe, graph)

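Bases: SubgraphSampler

Sample the subgraph induced by the inbound edges of the given nodes.

Functional name: sample_in_subgraph.

The in-subgraph sampler samples a subgraph from the given data and returns the induced subgraph along with its compaction information (the mapping from compacted node IDs back to the original node IDs).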
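
To make the description concrete, here is a minimal pure-Python sketch of what taking an in-subgraph and compacting it amounts to for a graph stored in CSC form. It is not GraphBolt's implementation: the helper name `in_subgraph_csc` and the use of plain lists are assumptions made for illustration, and the sketch assumes that compacted row IDs list the seeds first, followed by newly seen source nodes in order of appearance (which is consistent with the example output on this page).

```python
def in_subgraph_csc(indptr, indices, seeds):
    """Return (sub_indptr, sub_indices, row_ids, col_ids) for the in-edges of `seeds`.

    `indptr` and `indices` describe the full graph in CSC form: the sources of
    the in-edges of node v are indices[indptr[v]:indptr[v + 1]].
    """
    # Assumption of this sketch: seeds occupy the first compacted row IDs.
    row_ids = list(seeds)
    compact = {node: i for i, node in enumerate(row_ids)}

    sub_indptr = [0]
    sub_indices = []
    for v in seeds:
        for u in indices[indptr[v]:indptr[v + 1]]:
            if u not in compact:
                # First time this source node is seen: give it the next compact ID.
                compact[u] = len(row_ids)
                row_ids.append(u)
            sub_indices.append(compact[u])
        sub_indptr.append(len(sub_indices))
    return sub_indptr, sub_indices, row_ids, list(seeds)


indptr = [0, 3, 5, 7, 9, 12, 14]
indices = [0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4]
print(in_subgraph_csc(indptr, indices, [2, 3]))
# ([0, 2, 4], [2, 3, 4, 0], [2, 3, 0, 5, 1], [2, 3])
```

The four returned lists play the roles of `sampled_csc.indptr`, `sampled_csc.indices`, `original_row_node_ids` and `original_column_node_ids` respectively.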

Parameters:
  • datapipe (DataPipe) – The datapipe.

  • graph (FusedCSCSamplingGraph) – The graph on which to perform in_subgraph sampling.

Examples

>>> import dgl.graphbolt as gb
>>> import torch
>>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])
>>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])
>>> graph = gb.fused_csc_sampling_graph(indptr, indices)
>>> item_set = gb.ItemSet(len(indptr) - 1, names="seeds")
>>> item_sampler = gb.ItemSampler(item_set, batch_size=2)
>>> insubgraph_sampler = gb.InSubgraphSampler(item_sampler, graph)
>>> for _, data in enumerate(insubgraph_sampler):
...     print(data.sampled_subgraphs[0].sampled_csc)
...     print(data.sampled_subgraphs[0].original_row_node_ids)
...     print(data.sampled_subgraphs[0].original_column_node_ids)
CSCFormatBase(indptr=tensor([0, 3, 5]),
            indices=tensor([0, 1, 2, 3, 4]),
)
tensor([0, 1, 4, 2, 3])
tensor([0, 1])
CSCFormatBase(indptr=tensor([0, 2, 4]),
            indices=tensor([2, 3, 4, 0]),
)
tensor([2, 3, 0, 5, 1])
tensor([2, 3])
CSCFormatBase(indptr=tensor([0, 3, 5]),
            indices=tensor([2, 3, 1, 4, 0]),
)
tensor([4, 5, 0, 3, 1])
tensor([4, 5])
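
The printed `indices` are compacted IDs, not original node IDs. As a sanity check, the sketch below maps one batch back to original (source, destination) pairs using `original_row_node_ids` and `original_column_node_ids`. The helper `original_edges` is hypothetical and works on plain lists copied from the output above; it is not part of GraphBolt.

```python
def original_edges(sub_indptr, sub_indices, row_ids, col_ids):
    """Translate a compacted CSC subgraph back to (src, dst) pairs in original IDs."""
    edges = []
    for col, dst in enumerate(col_ids):
        # Positions sub_indptr[col]:sub_indptr[col + 1] hold the in-edges of `dst`.
        for pos in range(sub_indptr[col], sub_indptr[col + 1]):
            edges.append((row_ids[sub_indices[pos]], dst))
    return edges


# Values copied from the second batch printed above (seeds 2 and 3).
edges = original_edges(
    sub_indptr=[0, 2, 4],
    sub_indices=[2, 3, 4, 0],
    row_ids=[2, 3, 0, 5, 1],
    col_ids=[2, 3],
)
print(edges)
# [(0, 2), (5, 2), (1, 3), (2, 3)]
```

These pairs agree with the input graph: the in-neighbors of node 2 are `indices[5:7] = [0, 5]` and those of node 3 are `indices[7:9] = [1, 2]`.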

sample_subgraphs(seeds, seeds_timestamp)

Sample subgraphs from the given seeds, possibly with temporal constraints.

Any subclass of SubgraphSampler should implement this method.

Parameters:
  • seeds (Union[torch.Tensor, Dict[str, torch.Tensor]]) – The seed nodes.

  • seeds_timestamp (Union[torch.Tensor, Dict[str, torch.Tensor]]) – The timestamps of the seed nodes. If given, the sampled subgraphs should not contain any nodes or edges that are newer than the timestamps of the seed nodes. Default: None.

Returns:

  • Union[torch.Tensor, Dict[str, torch.Tensor]] – The input nodes.

  • List[SampledSubgraph] – The sampled subgraphs.
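
The temporal constraint described for `seeds_timestamp` can be illustrated with a small pure-Python sketch. It is only an illustration under stated assumptions: the helper `filter_in_edges_by_time` and the per-node timestamp list `node_ts` are hypothetical, only node timestamps are considered (edge timestamps would be handled analogously), and the comparison `<=` follows the "not newer than" wording above. An actual sampler may use a different comparison or ignore timestamps entirely.

```python
def filter_in_edges_by_time(indptr, indices, node_ts, seeds, seeds_timestamp):
    """For each seed, keep only in-neighbors whose timestamp is not newer than the seed's."""
    kept = {}
    for seed, ts in zip(seeds, seeds_timestamp):
        sources = indices[indptr[seed]:indptr[seed + 1]]
        kept[seed] = [u for u in sources if node_ts[u] <= ts]
    return kept


indptr = [0, 3, 5, 7, 9, 12, 14]
indices = [0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4]
node_ts = [10, 20, 30, 40, 50, 60]  # hypothetical timestamps, one per node

kept = filter_in_edges_by_time(
    indptr, indices, node_ts, seeds=[0, 1], seeds_timestamp=[25, 35]
)
print(kept)
# {0: [0, 1], 1: [2]}
```

Node 4 is dropped for seed 0 because its timestamp (50) is newer than the seed's (25), and node 3 is dropped for seed 1 for the same reason (40 > 35).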

Examples

>>> import torch
>>> from torch.utils.data import functional_datapipe
>>> from dgl.graphbolt import SubgraphSampler
>>> @functional_datapipe("my_sample_subgraph")
... class MySubgraphSampler(SubgraphSampler):
...     def __init__(self, datapipe, graph, fanouts):
...         super().__init__(datapipe)
...         self.graph = graph
...         self.fanouts = fanouts
...     def sample_subgraphs(self, seeds, seeds_timestamp=None):
...         # Sample one subgraph per fanout, from the last layer to the first.
...         subgraphs = []
...         subgraphs_nodes = []
...         for fanout in reversed(self.fanouts):
...             subgraph = self.graph.sample_neighbors(seeds, fanout)
...             subgraphs.insert(0, subgraph)
...             subgraphs_nodes.append(subgraph.nodes)
...             # The nodes of this layer become the seeds of the next one.
...             seeds = subgraph.nodes
...         # The input nodes are the union of the nodes seen across all layers.
...         subgraphs_nodes = torch.unique(torch.cat(subgraphs_nodes))
...         return subgraphs_nodes, subgraphs