torch_geometric.nn.models.CorrectAndSmooth

class CorrectAndSmooth(num_correction_layers: int, correction_alpha: float, num_smoothing_layers: int, smoothing_alpha: float, autoscale: bool = True, scale: float = 1.0)[source]

基类:Module

来自“Combining Label Propagation And Simple Models Out-performs Graph Neural Networks”论文的正确且平滑(C&S)后处理模型,其中软预测\(\mathbf{Z}\)(从简单的基础预测器获得)首先基于真实训练标签信息\(\mathbf{Y}\)和残差传播进行校正。

\[\begin{split}\mathbf{e}^{(0)}_i &= \begin{cases} \mathbf{y}_i - \mathbf{z}_i, & \text{if }i \text{ is training node,}\\ \mathbf{0}, & \text{else} \end{cases}\end{split}\]
\[ \begin{align}\begin{aligned}\mathbf{E}^{(\ell)} &= \alpha_1 \mathbf{D}^{-1/2}\mathbf{A} \mathbf{D}^{-1/2} \mathbf{E}^{(\ell - 1)} + (1 - \alpha_1) \mathbf{E}^{(\ell - 1)}\\\mathbf{\hat{Z}} &= \mathbf{Z} + \gamma \cdot \mathbf{E}^{(L_1)},\end{aligned}\end{align} \]

其中 \(\gamma\) 表示缩放因子(可以是固定的或自动确定的),然后通过标签传播在图上进行平滑处理

\[\begin{split}\mathbf{\hat{z}}^{(0)}_i &= \begin{cases} \mathbf{y}_i, & \text{if }i\text{ is training node,}\\ \mathbf{\hat{z}}_i, & \text{else} \end{cases}\end{split}\]
\[\mathbf{\hat{Z}}^{(\ell)} = \alpha_2 \mathbf{D}^{-1/2}\mathbf{A} \mathbf{D}^{-1/2} \mathbf{\hat{Z}}^{(\ell - 1)} + (1 - \alpha_2) \mathbf{\hat{Z}}^{(\ell - 1)}\]

获得最终预测 \(\mathbf{\hat{Z}}^{(L_2)}\)

注意

有关使用C&S模型的示例,请参见 examples/correct_and_smooth.py

Parameters:
  • num_correction_layers (int) – 传播次数 \(L_1\)

  • correction_alpha (float) – 系数 \(\alpha_1\)

  • num_smoothing_layers (int) – 传播次数 \(L_2\)

  • smoothing_alpha (float) – 这是\(\alpha_2\)系数。

  • autoscale (bool, optional) – 如果设置为 True,将自动确定缩放因子 \(\gamma\)。(默认值:True

  • scale (float, 可选) – 缩放因子 \(\gamma\),在 autoscale = False 的情况下。(默认值: 1.0)

forward(y_soft: Tensor, *args) Tensor[source]

同时应用correct()smooth()

Return type:

Tensor

correct(y_soft: Tensor, y_true: Tensor, mask: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]

前向传播。

Parameters:
  • y_soft (torch.Tensor) – 从简单的基础预测器获得的软预测 \(\mathbf{Z}\)

  • y_true (torch.Tensor) – 训练节点的真实标签信息 \(\mathbf{Y}\)

  • mask (torch.Tensor) – 一个掩码或索引张量,表示哪些节点用于训练。

  • edge_index (torch.TensorSparseTensor) – 边的连接性。

  • edge_weight (torch.Tensor, optional) – 边的权重。 (默认: None)

Return type:

Tensor

smooth(y_soft: Tensor, y_true: Tensor, mask: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]

前向传播。

Parameters:
  • y_soft (torch.Tensor) – 从correct()获得的修正预测 \(\mathbf{Z}\)

  • y_true (torch.Tensor) – 训练节点的真实标签信息 \(\mathbf{Y}\)

  • mask (torch.Tensor) – 一个掩码或索引张量,表示哪些节点用于训练。

  • edge_index (torch.TensorSparseTensor) – 边的连接性。

  • edge_weight (torch.Tensor, optional) – 边的权重。 (默认: None)

Return type:

Tensor