dgl.sampling.PinSAGESampler
- class dgl.sampling.PinSAGESampler(G, ntype, other_type, num_traversals, termination_prob, num_random_walks, num_neighbors, weight_column='weights')[source]
类似PinSAGE的邻居采样器。
这个可调用对象适用于具有边类型
(ntype, fwtype, other_type)
和(other_type, bwtype, ntype)
的双向二分图(其中ntype
、fwtype
、bwtype
和other_type
可以是任意类型名称)。它将生成一个节点类型为ntype
的同构图,其中每个给定节点的邻居是从该给定节点开始的多个随机游走中最常访问的同类型节点。每个随机游走由多个基于元路径的遍历组成,每次遍历后都有终止的概率。元路径始终为[fwtype, bwtype]
,从节点类型ntype
到节点类型other_type
,然后再回到ntype
。返回的齐次图的边将从最常访问的节点连接到给定的节点,并带有一个表示访问次数的特征。
此采样器支持UVA和GPU采样。 更多详情请参考6.8 使用GPU进行邻域采样。
- Parameters:
G (DGLGraph) –
双向二分图。
该图应仅包含两种节点类型:
ntype
和other_type
。 该图应仅包含两种边类型,一种从ntype
连接到other_type
,另一种从other_type
连接到ntype
。ntype (str) – 用于构建图的节点类型。
other_type (str) – 另一个节点类型。
num_traversals (int) –
单个随机游走中基于元路径的遍历的最大次数。
通常被视为超参数。
termination_prob (int) –
每次基于元路径的遍历后的终止概率。
通常被视为超参数。
num_random_walks (int) –
每个给定节点尝试的随机游走次数。
通常被视为超参数。
num_neighbors (int) – 为每个给定节点选择的邻居(或最常访问的节点)的数量。
weight_column (str, default "weights") – 要存储在返回图上的边特征的名称,该特征包含访问次数。
示例
生成一个具有3000个“A”节点和5000个“B”节点的随机双向二分图。
>>> g = scipy.sparse.random(3000, 5000, 0.003) >>> G = dgl.heterograph({ ... ('A', 'AB', 'B'): g.nonzero(), ... ('B', 'BA', 'A'): g.T.nonzero()})
然后我们创建一个PinSage邻居采样器,它采样一个节点类型为“A”的图。每个节点最多有10个邻居。
>>> sampler = dgl.sampling.PinSAGESampler(G, 'A', 'B', 3, 0.5, 200, 10)
这是我们根据PinSAGE算法为类型“A”的节点#0、#1和#2选择邻居的方式:
>>> seeds = torch.LongTensor([0, 1, 2]) >>> frontier = sampler(seeds) >>> frontier.all_edges(form='uv') (tensor([ 230, 0, 802, 47, 50, 1639, 1533, 406, 2110, 2687, 2408, 2823, 0, 972, 1230, 1658, 2373, 1289, 1745, 2918, 1818, 1951, 1191, 1089, 1282, 566, 2541, 1505, 1022, 812]), tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]))
有关PinSAGE模型的端到端示例,包括在多层上进行采样并使用采样图进行计算,请参考我们在
examples/pytorch/pinsage
中的PinSage示例。参考文献
- Graph Convolutional Neural Networks for Web-Scale Recommender Systems
Ying 等人,2018年,https://arxiv.org/abs/1806.01973
- __init__(G, ntype, other_type, num_traversals, termination_prob, num_random_walks, num_neighbors, weight_column='weights')[source]
方法
__init__
(G, ntype, other_type, ...[, ...])