torch_geometric.nn.dense.DenseSAGEConv

class DenseSAGEConv(in_channels: int, out_channels: int, normalize: bool = False, bias: bool = True)[源代码]

基础类: Module

参见 torch_geometric.nn.conv.SAGEConv

注意

DenseSAGEConv 期望在二进制邻接矩阵上工作。 如果您想使用加权的密集邻接矩阵,请使用 torch_geometric.nn.dense.DenseGraphConv 代替。

reset_parameters()[source]

重置模块的所有可学习参数。

forward(x: Tensor, adj: Tensor, mask: Optional[Tensor] = None) Tensor[源代码]

前向传播。

Parameters:
  • x (torch.Tensor) – 节点特征张量 \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), 其中 批次大小 \(B\), (最大) 节点数 \(N\) 对于 每个图,以及特征维度 \(F\).

  • adj (torch.Tensor) – 邻接张量 \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\). 邻接张量在批次维度上是可广播的, 从而为整个批次共享一个邻接矩阵。

  • mask (torch.Tensor, optional) – 掩码矩阵 \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) 表示 每个图中有效的节点。(默认: None)

Return type:

Tensor