torch_geometric.nn.dense.DenseSAGEConv
- class DenseSAGEConv(in_channels: int, out_channels: int, normalize: bool = False, bias: bool = True)[源代码]
基础类:
Module参见
torch_geometric.nn.conv.SAGEConv。注意
DenseSAGEConv期望在二进制邻接矩阵上工作。 如果您想使用加权的密集邻接矩阵,请使用torch_geometric.nn.dense.DenseGraphConv代替。- forward(x: Tensor, adj: Tensor, mask: Optional[Tensor] = None) Tensor[源代码]
前向传播。
- Parameters:
x (torch.Tensor) – 节点特征张量 \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), 其中 批次大小 \(B\), (最大) 节点数 \(N\) 对于 每个图,以及特征维度 \(F\).
adj (torch.Tensor) – 邻接张量 \(\mathbf{A} \in \mathbb{R}^{B \times N \times N}\). 邻接张量在批次维度上是可广播的, 从而为整个批次共享一个邻接矩阵。
mask (torch.Tensor, optional) – 掩码矩阵 \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) 表示 每个图中有效的节点。(默认:
None)
- Return type: