torch_geometric.nn.models.LabelPropagation
- class LabelPropagation(num_layers: int, alpha: float)[source]
Bases:
MessagePassing标签传播操作符,首次在 “从带标签和无标签数据中学习标签传播” 论文中引入。
\[\mathbf{Y}^{\prime} = \alpha \cdot \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2} \mathbf{Y} + (1 - \alpha) \mathbf{Y},\]其中未标记的数据通过传播由标记的数据推断出来。 这里的具体实现源自“结合标签传播和简单模型优于图神经网络”论文。
注意
有关使用
LabelPropagation的示例,请参见 examples/label_prop.py。- forward(y: Tensor, edge_index: Union[Tensor, SparseTensor], mask: Optional[Tensor] = None, edge_weight: Optional[Tensor] = None, post_step: Optional[Callable[[Tensor], Tensor]] = None) Tensor[source]
前向传播。
- Parameters:
y (torch.Tensor) – 真实标签信息 \(\mathbf{Y}\).
edge_index (torch.Tensor or SparseTensor) – The edge connectivity.
mask (torch.Tensor, optional) – 一个掩码或索引张量,表示哪些节点用于标签传播。 (默认:
None)edge_weight (torch.Tensor, optional) – The edge weights. (default:
None)post_step (可调用的, 可选的) – 一个在标签传播后应用的后处理步骤函数。如果没有指定后处理步骤函数,输出将被限制在0和1之间。 (默认:
None)
- Return type: