torch_geometric.nn.models.MaskLabel

class MaskLabel(num_classes: int, out_channels: int, method: str = 'add')[source]

Bases: Module

来自“Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification”论文的标签嵌入和掩码层。

在这里,节点标签 y 根据 mask 合并到初始节点特征 x 中,针对它们的节点子集。

注意

有关使用 MaskLabel 的示例,请参见 examples/unimp_arxiv.py

Parameters:
  • num_classes (int) – 类别数量。

  • out_channels (int) – Size of each output sample.

  • method (str, optional) – 如果设置为 "add",标签嵌入会被添加到输入中。如果设置为 "concat",标签嵌入会被连接。在 method="add" 的情况下,out_channels 需要与节点特征的输入维度相同。(默认值:"add"

forward(x: Tensor, y: Tensor, mask: Tensor) Tensor[source]
Return type:

Tensor

reset_parameters()[source]

重置模块的所有可学习参数。

static ratio_mask(mask: Tensor, ratio: float)[source]

通过将True条目的ratio设置为False来修改mask。不会在原地操作。

Parameters:
  • mask (torch.Tensor) – 用于重新掩码的掩码。

  • ratio (float) – 要保留的条目比例。