标签传播

class dgl.nn.pytorch.utils.LabelPropagation(k, alpha, norm_type='sym', clamp=True, normalize=False, reset=False)[source]

Bases: Module

标签传播来自 从有标签和无标签数据中学习标签传播

\[\mathbf{Y}^{(t+1)} = \alpha \tilde{A} \mathbf{Y}^{(t)} + (1 - \alpha) \mathbf{Y}^{(0)}\]

其中未标记的数据最初设置为零,并通过传播从标记的数据中推断出来。\(\alpha\) 是一个权重参数,用于平衡更新后的标签和初始标签。\(\tilde{A}\) 表示归一化的邻接矩阵。

Parameters:
  • k (int) – 传播步骤的数量。

  • alpha (float) – 范围在 [0, 1] 内的 \(\alpha\) 系数。

  • norm_type (str, optional) –

    应用于邻接矩阵的归一化类型,必须是以下选项之一:

    • row: 行归一化邻接矩阵,如 \(D^{-1}A\)

    • sym: 对称归一化邻接矩阵,如 \(D^{-1/2}AD^{-1/2}\)

    默认值:'sym'。

  • clamp (bool, optional) – 一个布尔标志,用于指示在传播后是否将标签限制在[0, 1]范围内。 默认值:True。

  • normalize (bool, optional) – 一个布尔标志,用于指示在传播后是否应用行归一化。 默认值:False。

  • reset (bool, optional) – 一个布尔标志,用于指示是否在每次传播步骤后重置已知标签。默认值:False。

示例

>>> import torch
>>> import dgl
>>> from dgl.nn import LabelPropagation
>>> label_propagation = LabelPropagation(k=5, alpha=0.5, clamp=False, normalize=True)
>>> g = dgl.rand_graph(5, 10)
>>> labels = torch.tensor([0, 2, 1, 3, 0]).long()
>>> mask = torch.tensor([0, 1, 1, 1, 0]).bool()
>>> new_labels = label_propagation(g, labels, mask)
forward(g, labels, mask=None)[source]

计算标签传播过程。

Parameters:
  • g (DGLGraph) – The input graph.

  • labels (torch.Tensor) –

    输入的节点标签。支持三种情况。

    • 形状为 \((N, 1)\)\((N,)\) 的 LongTensor,用于多类分类中的节点类别标签,其中 \(N\) 是节点数量。

    • 形状为 \((N, C)\) 的 LongTensor,用于多类分类中的节点类别标签的 one-hot 编码,其中 \(C\) 是类别数量。

    • 形状为 \((N, L)\) 的 LongTensor,用于多标签二元分类中的节点标签,其中 \(L\) 是标签数量。

  • mask (torch.Tensor) – 形状为 \((N,)\) 的布尔指示器,True 表示标记的节点。 默认值:None,表示所有节点都已标记。

Returns:

传播的节点标签的形状为 \((N, D)\),类型为浮点型,其中 \(D\) 是类别或标签的数量。

Return type:

torch.Tensor