torch_geometric.nn.conv.ChebConv

class ChebConv(in_channels: int, out_channels: int, K: int, normalization: Optional[str] = 'sym', bias: bool = True, **kwargs)[source]

Bases: MessagePassing

The Chebyshev spectral graph convolutional operator from the “Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering” paper.

\[\mathbf{X}^{\prime} = \sum_{k=1}^{K} \mathbf{Z}^{(k)} \cdot \mathbf{\Theta}^{(k)}\]

where \(\mathbf{Z}^{(k)}\) is computed recursively by

\[ \begin{align}\begin{aligned}\mathbf{Z}^{(1)} &= \mathbf{X}\\\mathbf{Z}^{(2)} &= \mathbf{\hat{L}} \cdot \mathbf{X}\\\mathbf{Z}^{(k)} &= 2 \cdot \mathbf{\hat{L}} \cdot \mathbf{Z}^{(k-1)} - \mathbf{Z}^{(k-2)}\end{aligned}\end{align} \]

and \(\mathbf{\hat{L}}\) denotes the scaled and normalized Laplacian \(\frac{2\mathbf{L}}{\lambda_{\max}} - \mathbf{I}\).
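The recursion above can be sketched with dense tensors. This is an illustrative re-implementation, not the library's actual (sparse, message-passing) code; the function name `cheb_conv_dense` and the explicit eigenvalue computation are assumptions made for clarity:

```python
import torch

def cheb_conv_dense(X, A, thetas):
    """Dense sketch of Chebyshev graph convolution.

    X: node features (N, F_in); A: adjacency matrix (N, N);
    thetas: list of K weight matrices, each (F_in, F_out).
    """
    # Symmetrically normalized Laplacian: L = I - D^{-1/2} A D^{-1/2}
    deg = A.sum(dim=1)
    d_inv_sqrt = deg.pow(-0.5)
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0  # isolated nodes
    I = torch.eye(A.size(0))
    L = I - d_inv_sqrt.view(-1, 1) * A * d_inv_sqrt.view(1, -1)

    # Rescale so eigenvalues lie in [-1, 1]: L_hat = 2L / lambda_max - I
    lambda_max = torch.linalg.eigvalsh(L).max()
    L_hat = 2.0 * L / lambda_max - I

    # Z^{(1)} = X, Z^{(2)} = L_hat X,
    # Z^{(k)} = 2 L_hat Z^{(k-1)} - Z^{(k-2)}
    Z_prev, Z = X, L_hat @ X
    out = X @ thetas[0]
    for k in range(1, len(thetas)):
        out = out + Z @ thetas[k]
        Z_prev, Z = Z, 2.0 * L_hat @ Z - Z_prev
    return out
```

With `K = 1` the layer reduces to a plain linear transform of the node features, since only \(\mathbf{Z}^{(1)} = \mathbf{X}\) contributes.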

Parameters:
  • in_channels (int) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method.

  • out_channels (int) – Size of each output sample.

  • K (int) – Chebyshev filter size \(K\).

  • normalization (str, optional) –

    The normalization scheme for the graph Laplacian (default: "sym"):

    1. None: No normalization \(\mathbf{L} = \mathbf{D} - \mathbf{A}\)

    2. "sym": Symmetric normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2}\)

    3. "rw": Random-walk normalization \(\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}\)

    You need to pass lambda_max to the forward() method of this operator in case the normalization is non-symmetric. lambda_max should be a torch.Tensor of size [num_graphs] in a mini-batch scenario, and a scalar/zero-dimensional tensor when operating on single graphs. You can pre-compute lambda_max via the torch_geometric.transforms.LaplacianLambdaMax transform.

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • Input: node features \((|\mathcal{V}|, F_{in})\), edge indices \((2, |\mathcal{E}|)\), edge weights \((|\mathcal{E}|)\) (optional), batch vector \((|\mathcal{V}|)\) (optional), maximum lambda value \((|\mathcal{G}|)\) (optional)

  • Output: node features \((|\mathcal{V}|, F_{out})\)

forward(x: Tensor, edge_index: Tensor, edge_weight: Optional[Tensor] = None, batch: Optional[Tensor] = None, lambda_max: Optional[Tensor] = None) Tensor[source]

Runs the forward pass of the module.

Return type:

Tensor

reset_parameters()[source]

Resets all learnable parameters of the module.