torch_geometric.nn.conv.DynamicEdgeConv

class DynamicEdgeConv(nn: Callable, k: int, aggr: str = 'max', num_workers: int = 1, **kwargs)[source]

基础类: MessagePassing

来自“Dynamic Graph CNN for Learning on Point Clouds”论文的动态边缘卷积算子(参见torch_geometric.nn.conv.EdgeConv),其中图是使用特征空间中的最近邻动态构建的。

Parameters:
  • nn (torch.nn.Module) – 一个神经网络 \(h_{\mathbf{\Theta}}\),它将成对连接的节点特征 x 从形状 :obj:`[-1, 2 * in_channels] 映射到形状 [-1, out_channels],例如由 torch.nn.Sequential 定义。

  • k (int) – 最近邻的数量。

  • aggr (str, 可选) – 使用的聚合方案 ("add", "mean", "max"). (默认: "max")

  • num_workers (int) – 用于 k-NN 计算的工人数量。 如果 batch 不是 None,或者输入位于 GPU 上,则此参数无效。(默认值:1

  • **kwargs (可选) – torch_geometric.nn.conv.MessagePassing 的额外参数。

Shapes:
  • 输入: 节点特征 \((|\mathcal{V}|, F_{in})\)\(((|\mathcal{V}|, F_{in}), (|\mathcal{V}|, F_{in}))\) 如果是二分图, 批次向量 \((|\mathcal{V}|)\)\(((|\mathcal{V}|), (|\mathcal{V}|))\) 如果是二分图 (可选)

  • 输出: 节点特征 \((|\mathcal{V}|, F_{out})\)\((|\mathcal{V}_t|, F_{out})\) 如果是二分图

forward(x: Union[Tensor, Tuple[Tensor, Tensor]], batch: Union[Tensor, None, Tuple[Tensor, Tensor]] = None) Tensor[source]

运行模块的前向传播。

Return type:

Tensor

reset_parameters()[source]

重置模块的所有可学习参数。