torch_geometric.nn.conv.DynamicEdgeConv
- class DynamicEdgeConv(nn: Callable, k: int, aggr: str = 'max', num_workers: int = 1, **kwargs)[source]
基础类:
MessagePassing来自“Dynamic Graph CNN for Learning on Point Clouds”论文的动态边缘卷积算子(参见
torch_geometric.nn.conv.EdgeConv),其中图是使用特征空间中的最近邻动态构建的。- Parameters:
nn (torch.nn.Module) – 一个神经网络 \(h_{\mathbf{\Theta}}\),它将成对连接的节点特征
x从形状 :obj:`[-1, 2 * in_channels] 映射到形状[-1, out_channels],例如由torch.nn.Sequential定义。k (int) – 最近邻的数量。
aggr (str, 可选) – 使用的聚合方案 (
"add","mean","max"). (默认:"max")num_workers (int) – 用于 k-NN 计算的工人数量。 如果
batch不是None,或者输入位于 GPU 上,则此参数无效。(默认值:1)**kwargs (可选) –
torch_geometric.nn.conv.MessagePassing的额外参数。
- Shapes:
输入: 节点特征 \((|\mathcal{V}|, F_{in})\) 或 \(((|\mathcal{V}|, F_{in}), (|\mathcal{V}|, F_{in}))\) 如果是二分图, 批次向量 \((|\mathcal{V}|)\) 或 \(((|\mathcal{V}|), (|\mathcal{V}|))\) 如果是二分图 (可选)
输出: 节点特征 \((|\mathcal{V}|, F_{out})\) 或 \((|\mathcal{V}_t|, F_{out})\) 如果是二分图