torch_geometric.nn.models.EdgeCNN
- class EdgeCNN(in_channels: int, hidden_channels: int, num_layers: int, out_channels: Optional[int] = None, dropout: float = 0.0, act: Optional[Union[str, Callable]] = 'relu', act_first: bool = False, act_kwargs: Optional[Dict[str, Any]] = None, norm: Optional[Union[str, Callable]] = None, norm_kwargs: Optional[Dict[str, Any]] = None, jk: Optional[str] = None, **kwargs)[source]
基础类:
BasicGNN来自“Dynamic Graph CNN for Learning on Point Clouds”论文的图神经网络,使用
EdgeConv操作符进行消息传递。- Parameters:
in_channels (int) – 每个输入样本的大小。
hidden_channels (int) – 每个隐藏样本的大小。
num_layers (int) – 消息传递层的数量。
out_channels (int, optional) – 如果未设置为
None,将应用一个最终的线性变换,将隐藏节点嵌入转换为输出大小out_channels。(默认值:None)dropout (float, 可选) – 丢弃概率。 (默认:
0.)act (str 或 Callable, 可选) – 使用的非线性激活函数。(默认:
"relu")act_kwargs (Dict[str, Any], optional) – 传递给由
act定义的相应激活函数的参数。 (默认值:None)norm_kwargs (Dict[str, Any], optional) – 传递给由
norm定义的相应归一化函数的参数。 (默认值:None)jk (str, optional) – 跳跃知识模式。如果指定,模型将额外应用一个最终的线性变换,将节点嵌入转换为预期的输出特征维度。 (
None,"last","cat","max","lstm"). (默认:None)**kwargs (可选) –
torch_geometric.nn.conv.EdgeConv的额外参数。
- forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None, edge_attr: Optional[Tensor] = None, batch: Optional[Tensor] = None, batch_size: Optional[int] = None, num_sampled_nodes_per_hop: Optional[List[int]] = None, num_sampled_edges_per_hop: Optional[List[int]] = None) Tensor
前向传播。
- Parameters:
x (torch.Tensor) – 输入的节点特征。
edge_index (torch.Tensor 或 SparseTensor) – 边的索引。
edge_weight (torch.Tensor, optional) – 边的权重(如果底层的GNN层支持)。(默认值:
None)edge_attr (torch.Tensor, optional) – 边的特征(如果底层的GNN层支持)。(默认:
None)batch (torch.Tensor, optional) – 批次向量 \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), 它将 每个元素分配给特定的示例。 只有在基础归一化 层需要
batch信息时才需要传递。 (默认:None)batch_size (int, optional) – 示例的数量 \(B\)。 如果未给出,则自动计算。 仅在底层归一化层需要
batch信息时需要传递。 (默认:None)num_sampled_nodes_per_hop (List[int], optional) – 每跳采样的节点数。 在
NeighborLoader场景中非常有用,仅对最小大小的表示进行操作。 (默认:None)num_sampled_edges_per_hop (List[int], optional) – 每跳采样的边数。 在
NeighborLoader场景中,仅对最小大小的表示进行操作时有用。 (默认:None)
- Return type:
- reset_parameters()
重置模块的所有可学习参数。