torch_geometric.nn.conv.GravNetConv

class GravNetConv(in_channels: int, out_channels: int, space_dimensions: int, propagate_dimensions: int, k: int, num_workers: Optional[int] = None, **kwargs)[source]

Bases: MessagePassing

来自“使用距离加权图网络学习不规则粒子探测器几何的表示”论文的GravNet操作符,其中图是使用最近邻动态构建的。邻居是在特征空间的可学习低维投影中构建的。然后,输入特征空间的第二个投影从邻居传播到每个顶点,使用通过将高斯函数应用于距离得出的距离权重。

Parameters:
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – 输出通道的数量。

  • space_dimensions (int) – 用于构建邻居的空间的维度;在论文中称为 \(S\)

  • propagate_dimensions (int) – 要在顶点之间传播的特征数量;在论文中称为 \(F_{\textrm{LR}}\)

  • k (int) – 最近邻居的数量。

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • 输入: 节点特征 \((|\mathcal{V}|, F_{in})\)\(((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{in}))\) 如果是二分图, 批次向量 \((|\mathcal{V}|)\)\(((|\mathcal{V}_s|), (|\mathcal{V}_t|))\) 如果是二分图 (可选)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite

forward(x: Union[Tensor, Tuple[Tensor, Tensor]], batch: Union[Tensor, None, Tuple[Tensor, Tensor]] = None) Tensor[source]

运行模块的前向传播。

Return type:

Tensor

reset_parameters()[source]

重置模块的所有可学习参数。