dgl.sampling.sample_neighbors
- dgl.sampling.sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False, copy_ndata=True, copy_edata=True, _dist_training=False, exclude_edges=None, output_device=None)[source]
对给定节点的邻近边进行采样并返回诱导子图。
对于每个节点,将随机选择一定数量的入站(或当
edge_dir == 'out'
时为出站)边。返回的图将包含原始图中的所有节点,但仅包含采样的边。Node/edge features are not preserved. The original IDs of the sampled edges are stored as the dgl.EID feature in the returned graph.
此函数支持GPU采样。更多详情请参阅6.8 使用GPU进行邻域采样。
- Parameters:
g (DGLGraph) – 图。可以在CPU或GPU上。
nodes (tensor 或 dict) –
从中采样邻居的节点ID。
此参数可以接受单个ID张量或节点类型和ID张量的字典。 如果给出单个张量,则图必须只有一种类型的节点。
fanout (int 或 dict[etype, int]) –
每种边类型上每个节点要采样的边数。
此参数可以接受一个整数或一个边类型和整数的字典。 如果给定一个整数,DGL将为每种边类型的每个节点采样此数量的边。
如果为单个边类型给定-1,将选择该边类型且概率不为零的所有相邻边。
edge_dir (str, optional) –
Determines whether to sample inbound or outbound edges.
Can take either
in
for inbound edges orout
for outbound edges.prob (str, optional) –
用作节点每个相邻边的(未归一化)概率的特征名称。该特征必须为每条边只有一个元素。
特征必须是非负浮点数或布尔值。否则,结果将是未定义的。
exclude_edges (tensor 或 dict) –
在为种子节点采样邻居时要排除的边ID。
此参数可以接受单个ID张量或边类型和ID张量的字典。 如果给出单个张量,则图必须只有一种类型的节点。
replace (bool, optional) – 如果为True,则进行有放回的抽样。
copy_ndata (bool, optional) –
如果为True,新图的节点特征将从原始图中复制。如果为False,新图将不会有任何节点特征。
(默认值: True)
copy_edata (bool, optional) –
如果为True,新图的边特征将从原始图中复制。如果为False,新图将不会有任何边特征。
(默认值: True)
_dist_training (bool, optional) –
内部参数。请勿使用。
(默认值: False)
output_device (Framework-specific device context object, optional) – The output device. Default is the same as the input graph.
- Returns:
一个仅包含采样邻边的采样子图。
- Return type:
注释
如果
copy_ndata
或copy_edata
为 True,则相同的张量将用作原始图和新图的节点或边特征。因此,用户应避免对新图的节点特征进行原地操作,以避免特征损坏。示例
假设你有以下图表
>>> g = dgl.graph(([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]))
以及权重
>>> g.edata['prob'] = torch.FloatTensor([0., 1., 0., 1., 0., 1.])
为节点0和节点1采样一个入边:
>>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 1) >>> sg.edges(order='eid') (tensor([1, 0]), tensor([0, 1])) >>> sg.edata[dgl.EID] tensor([2, 0])
为节点0和节点1采样一个入边,概率在边特征
prob
中:>>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 1, prob='prob') >>> sg.edges(order='eid') (tensor([2, 1]), tensor([0, 1]))
当
fanout
大于实际邻居数量且不进行替换时, DGL将选择所有邻居:>>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 3) >>> sg.edges(order='eid') (tensor([1, 2, 0, 1]), tensor([0, 0, 1, 1]))
在种子节点的采样过程中排除某些EID:
>>> g = dgl.graph(([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0])) >>> g_edges = g.all_edges(form='all')`` (tensor([0, 0, 1, 1, 2, 2]), tensor([1, 2, 0, 1, 2, 0]), tensor([0, 1, 2, 3, 4, 5])) >>> sg = dgl.sampling.sample_neighbors(g, [0, 1], 3, exclude_edges=[0, 1, 2]) >>> sg.all_edges(form='all') (tensor([2, 1]), tensor([0, 1]), tensor([0, 1])) >>> sg.has_edges_between(g_edges[0][:3],g_edges[1][:3]) tensor([False, False, False]) >>> g = dgl.heterograph({ ... ('drug', 'interacts', 'drug'): ([0, 0, 1, 1, 3, 2], [1, 2, 0, 1, 2, 0]), ... ('drug', 'interacts', 'gene'): ([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0]), ... ('drug', 'treats', 'disease'): ([0, 0, 1, 1, 2, 2], [1, 2, 0, 1, 2, 0])}) >>> g_edges = g.all_edges(form='all', etype=('drug', 'interacts', 'drug')) (tensor([0, 0, 1, 1, 3, 2]), tensor([1, 2, 0, 1, 2, 0]), tensor([0, 1, 2, 3, 4, 5])) >>> excluded_edges = {('drug', 'interacts', 'drug'): g_edges[2][:3]} >>> sg = dgl.sampling.sample_neighbors(g, {'drug':[0, 1]}, 3, exclude_edges=excluded_edges) >>> sg.all_edges(form='all', etype=('drug', 'interacts', 'drug')) (tensor([2, 1]), tensor([0, 1]), tensor([0, 1])) >>> sg.has_edges_between(g_edges[0][:3],g_edges[1][:3],etype=('drug', 'interacts', 'drug')) tensor([False, False, False])