torch_geometric.nn.models.DeepGraphInfomax

class DeepGraphInfomax(hidden_channels: int, encoder: Module, summary: Callable, corruption: Callable)[source]

Bases: Module

来自“Deep Graph Infomax”论文的深度图信息最大化模型基于用户定义的编码器和摘要模型\(\mathcal{E}\)\(\mathcal{R}\),以及一个破坏函数\(\mathcal{C}\)

Parameters:
  • hidden_channels (int) – 潜在空间的维度。

  • encoder (torch.nn.Module) – 编码器模块 \(\mathcal{E}\)

  • 摘要 (可调用的) – 读取函数 \(\mathcal{R}\)

  • corruption (callable) – 腐败函数 \(\mathcal{C}\)

forward(*args, **kwargs) Tuple[Tensor, Tensor, Tensor][source]

返回输入参数、它们的损坏及其摘要表示的潜在空间。

Return type:

Tuple[Tensor, Tensor, Tensor]

reset_parameters()[source]

重置模块的所有可学习参数。

discriminate(z: Tensor, summary: Tensor, sigmoid: bool = True) Tensor[source]

给定补丁-摘要对 zsummary,计算分配给这对补丁-摘要的概率分数。

Parameters:
  • z (torch.Tensor) – 潜在空间。

  • 摘要 (torch.Tensor) – 摘要向量。

  • sigmoid (bool, optional) – If set to False, does not apply the logistic sigmoid function to the output. (default: True)

Return type:

Tensor

loss(pos_z: Tensor, neg_z: Tensor, summary: Tensor) Tensor[source]

计算互信息最大化目标。

Return type:

Tensor

test(train_z: Tensor, train_y: Tensor, test_z: Tensor, test_y: Tensor, solver: str = 'lbfgs', *args, **kwargs) float[source]

通过逻辑回归下游任务评估潜在空间的质量。

Return type:

float