torch_geometric.nn.models.CorrectAndSmooth
- class CorrectAndSmooth(num_correction_layers: int, correction_alpha: float, num_smoothing_layers: int, smoothing_alpha: float, autoscale: bool = True, scale: float = 1.0)[source]
基类:
Module来自“Combining Label Propagation And Simple Models Out-performs Graph Neural Networks”论文的正确且平滑(C&S)后处理模型,其中软预测\(\mathbf{Z}\)(从简单的基础预测器获得)首先基于真实训练标签信息\(\mathbf{Y}\)和残差传播进行校正。
\[\begin{split}\mathbf{e}^{(0)}_i &= \begin{cases} \mathbf{y}_i - \mathbf{z}_i, & \text{if }i \text{ is training node,}\\ \mathbf{0}, & \text{else} \end{cases}\end{split}\]\[ \begin{align}\begin{aligned}\mathbf{E}^{(\ell)} &= \alpha_1 \mathbf{D}^{-1/2}\mathbf{A} \mathbf{D}^{-1/2} \mathbf{E}^{(\ell - 1)} + (1 - \alpha_1) \mathbf{E}^{(\ell - 1)}\\\mathbf{\hat{Z}} &= \mathbf{Z} + \gamma \cdot \mathbf{E}^{(L_1)},\end{aligned}\end{align} \]其中 \(\gamma\) 表示缩放因子(可以是固定的或自动确定的),然后通过标签传播在图上进行平滑处理
\[\begin{split}\mathbf{\hat{z}}^{(0)}_i &= \begin{cases} \mathbf{y}_i, & \text{if }i\text{ is training node,}\\ \mathbf{\hat{z}}_i, & \text{else} \end{cases}\end{split}\]\[\mathbf{\hat{Z}}^{(\ell)} = \alpha_2 \mathbf{D}^{-1/2}\mathbf{A} \mathbf{D}^{-1/2} \mathbf{\hat{Z}}^{(\ell - 1)} + (1 - \alpha_2) \mathbf{\hat{Z}}^{(\ell - 1)}\]获得最终预测 \(\mathbf{\hat{Z}}^{(L_2)}\)。
注意
有关使用C&S模型的示例,请参见 examples/correct_and_smooth.py。
- Parameters:
num_correction_layers (int) – 传播次数 \(L_1\)。
correction_alpha (float) – 系数 \(\alpha_1\)。
num_smoothing_layers (int) – 传播次数 \(L_2\)。
smoothing_alpha (float) – 这是\(\alpha_2\)系数。
autoscale (bool, optional) – 如果设置为
True,将自动确定缩放因子 \(\gamma\)。(默认值:True)scale (float, 可选) – 缩放因子 \(\gamma\),在
autoscale = False的情况下。(默认值:1.0)
- correct(y_soft: Tensor, y_true: Tensor, mask: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]
前向传播。
- Parameters:
y_soft (torch.Tensor) – 从简单的基础预测器获得的软预测 \(\mathbf{Z}\)。
y_true (torch.Tensor) – 训练节点的真实标签信息 \(\mathbf{Y}\)。
mask (torch.Tensor) – 一个掩码或索引张量,表示哪些节点用于训练。
edge_index (torch.Tensor 或 SparseTensor) – 边的连接性。
edge_weight (torch.Tensor, optional) – 边的权重。 (默认:
None)
- Return type:
- smooth(y_soft: Tensor, y_true: Tensor, mask: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]
前向传播。
- Parameters:
y_soft (torch.Tensor) – 从
correct()获得的修正预测 \(\mathbf{Z}\)。y_true (torch.Tensor) – 训练节点的真实标签信息 \(\mathbf{Y}\)。
mask (torch.Tensor) – 一个掩码或索引张量,表示哪些节点用于训练。
edge_index (torch.Tensor 或 SparseTensor) – 边的连接性。
edge_weight (torch.Tensor, optional) – 边的权重。 (默认:
None)
- Return type: