半径图
- class dgl.nn.pytorch.factory.RadiusGraph(r, p=2, self_loop=False, compute_mode='donot_use_mm_for_euclid_dist')[source]
Bases:
Module
将一组点转换为在给定距离内具有邻居的双向图的层。
RadiusGraph 的实现步骤如下:
计算所有点的NxN成对距离矩阵。
选择距离每个点在一定范围内的点作为它们的邻居。
构建一个图,其中每个点作为节点与其邻居相连。
返回的图的节点对应于点,其中每个点的邻居在给定距离内。
- Parameters:
r (float) – Radius of the neighbors.
p (float, optional) –
Power parameter for the Minkowski metric. When
p = 1
it is the equivalent of Manhattan distance (L1 norm) and Euclidean distance (L2 norm) forp = 2
.(default: 2)
self_loop (bool, optional) –
Whether the radius graph will contain self-loops.
(default: False)
compute_mode (str, optional) –
use_mm_for_euclid_dist_if_necessary
- will use matrix multiplication approach to calculate euclidean distance (p = 2) if P > 25 or R > 25use_mm_for_euclid_dist
- will always use matrix multiplication approach to calculate euclidean distance (p = 2)donot_use_mm_for_euclid_dist
- will never use matrix multiplication approach to calculate euclidean distance (p = 2).(default: donot_use_mm_for_euclid_dist)
示例
以下示例使用 PyTorch 后端。
>>> import dgl >>> from dgl.nn.pytorch.factory import RadiusGraph
>>> x = torch.tensor([[0.0, 0.0, 1.0], ... [1.0, 0.5, 0.5], ... [0.5, 0.2, 0.2], ... [0.3, 0.2, 0.4]]) >>> rg = RadiusGraph(0.75) >>> g = rg(x) # Each node has neighbors within 0.75 distance >>> g.edges() (tensor([0, 1, 2, 2, 3, 3]), tensor([3, 2, 1, 3, 0, 2]))
当
get_distances
为 True 时,前向传递返回半径图和对应边的距离。>>> x = torch.tensor([[0.0, 0.0, 1.0], ... [1.0, 0.5, 0.5], ... [0.5, 0.2, 0.2], ... [0.3, 0.2, 0.4]]) >>> rg = RadiusGraph(0.75) >>> g, dist = rg(x, get_distances=True) >>> g.edges() (tensor([0, 1, 2, 2, 3, 3]), tensor([3, 2, 1, 3, 0, 2])) >>> dist tensor([[0.7000], [0.6557], [0.6557], [0.2828], [0.7000], [0.2828]])
- forward(x, get_distances=False)[source]
前向计算。
- Parameters:
x (Tensor) – 点的坐标。 \((N, D)\) 其中 \(N\) 表示 点集中点的数量,\(D\) 表示 特征的大小。它可以在CPU或GPU上。点坐标的设备 指定了半径图的设备。
get_distances (bool, optional) –
Whether to return the distances for the corresponding edges in the radius graph.
(default: False)
- Returns:
DGLGraph – 构建的图。节点ID的顺序与
x
相同。torch.Tensor, optional – 构建图中边的距离。距离的顺序与边ID的顺序相同。