torch_geometric.nn.conv.GravNetConv
- class GravNetConv(in_channels: int, out_channels: int, space_dimensions: int, propagate_dimensions: int, k: int, num_workers: Optional[int] = None, **kwargs)[source]
Bases:
MessagePassing来自“使用距离加权图网络学习不规则粒子探测器几何的表示”论文的GravNet操作符,其中图是使用最近邻动态构建的。邻居是在特征空间的可学习低维投影中构建的。然后,输入特征空间的第二个投影从邻居传播到每个顶点,使用通过将高斯函数应用于距离得出的距离权重。
- Parameters:
in_channels (int) – Size of each input sample, or
-1to derive the size from the first input(s) to the forward method.out_channels (int) – 输出通道的数量。
space_dimensions (int) – 用于构建邻居的空间的维度;在论文中称为 \(S\)。
propagate_dimensions (int) – 要在顶点之间传播的特征数量;在论文中称为 \(F_{\textrm{LR}}\)。
k (int) – 最近邻居的数量。
**kwargs (optional) – Additional arguments of
torch_geometric.nn.conv.MessagePassing.
- Shapes:
输入: 节点特征 \((|\mathcal{V}|, F_{in})\) 或 \(((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{in}))\) 如果是二分图, 批次向量 \((|\mathcal{V}|)\) 或 \(((|\mathcal{V}_s|), (|\mathcal{V}_t|))\) 如果是二分图 (可选)
output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite