邻居采样器

class dgl.graphbolt.NeighborSampler(datapipe, graph, fanouts, replace=False, prob_name=None, deduplicate=True)[source]

基础类:NeighborSamplerImpl

从图中采样邻居边并返回一个子图。

功能名称: sample_neighbor.

邻居采样器负责从给定数据中采样一个子图。它返回一个诱导子图以及压缩信息。在节点分类任务的上下文中,邻居采样器直接使用提供的节点作为种子节点。然而,在涉及链接预测的场景中,该过程需要另一个预处理操作。即,从给定的节点对中收集唯一节点,包括正节点对和负节点对,并将这些节点用作后续步骤的种子节点。

Parameters:
  • datapipe (DataPipe) – The datapipe.

  • graph (FusedCSCSamplingGraph) – 用于执行子图采样的图。

  • fanouts (list[torch.Tensor] 或 list[int]) – 每个节点要采样的边数,考虑或不考虑边类型。此参数的长度隐式表示正在进行的采样层。 注意:fanout的顺序是从最外层到最内层。 例如,fanout ‘[15, 10, 5]’ 表示15对应最外层,10对应中间层,5对应最内层。

  • replace (bool) – 布尔值,指示样本是否是有放回进行的。如果为True,一个值可以被多次选择。否则,每个值只能被选择一次。

  • prob_name (str, optional) – 用作每个节点采样权重的边属性的名称。此属性张量应包含与节点的每个相邻边对应的(未归一化的)概率。它必须是一个1D浮点或布尔张量,元素数量等于边的总数。

  • deduplicate (bool) – 布尔值,指示是否会在跳转之间对种子进行去重。 如果为True,种子中的相同元素将被删除,只保留一个。 否则,相同的元素将保留。

示例

>>> import torch
>>> import dgl.graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
>>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
>>> graph = gb.fused_csc_sampling_graph(indptr, indices)
>>> seeds = torch.LongTensor([[0, 1], [1, 2]])
>>> item_set = gb.ItemSet(seeds, names="seeds")
>>> datapipe = gb.ItemSampler(item_set, batch_size=1)
>>> datapipe = datapipe.sample_uniform_negative(graph, 2)
>>> datapipe = datapipe.sample_neighbor(graph, [5, 10, 15])
>>> next(iter(datapipe)).sampled_subgraphs
[SampledSubgraphImpl(sampled_csc=CSCFormatBase(
        indptr=tensor([0, 2, 4, 5, 6, 7, 8]),
        indices=tensor([1, 4, 0, 5, 5, 3, 3, 2]),
    ),
    original_row_node_ids=tensor([0, 1, 4, 5, 2, 3]),
    original_edge_ids=None,
    original_column_node_ids=tensor([0, 1, 4, 5, 2, 3]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(
        indptr=tensor([0, 2, 4, 5, 6, 7, 8]),
        indices=tensor([1, 4, 0, 5, 5, 3, 3, 2]),
    ),
    original_row_node_ids=tensor([0, 1, 4, 5, 2, 3]),
    original_edge_ids=None,
    original_column_node_ids=tensor([0, 1, 4, 5, 2, 3]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(
        indptr=tensor([0, 2, 4, 5, 6]),
        indices=tensor([1, 4, 0, 5, 5, 3]),
    ),
    original_row_node_ids=tensor([0, 1, 4, 5, 2, 3]),
    original_edge_ids=None,
    original_column_node_ids=tensor([0, 1, 4, 5]),
)]