torch_geometric.nn.models.GAE

class GAE(encoder: Module, decoder: Optional[Module] = None)[source]

Bases: Module

图自动编码器模型来自 “变分图自动编码器” 论文,基于用户定义的编码器和解码器模型。

Parameters:
forward(*args, **kwargs) Tensor[source]

encode() 的别名。

Return type:

Tensor

reset_parameters()[source]

重置模块的所有可学习参数。

encode(*args, **kwargs) Tensor[source]

运行编码器并计算节点级别的潜在变量。

Return type:

Tensor

decode(*args, **kwargs) Tensor[source]

运行解码器并计算边缘概率。

Return type:

Tensor

recon_loss(z: Tensor, pos_edge_index: Tensor, neg_edge_index: Optional[Tensor] = None) Tensor[source]

给定潜在变量 z,计算正边 pos_edge_index 和负采样边的二元交叉熵损失。

Parameters:
  • z (torch.Tensor) – 潜在空间 \(\mathbf{Z}\)

  • pos_edge_index (torch.Tensor) – 用于训练的正向边。

  • neg_edge_index (torch.Tensor, optional) – 用于训练的负边。如果未提供,则使用负采样来计算负边。(默认值:None

Return type:

Tensor

test(z: Tensor, pos_edge_index: Tensor, neg_edge_index: Tensor) Tuple[Tensor, Tensor][source]

给定潜在变量 z,正边 pos_edge_index 和负边 neg_edge_index, 计算ROC曲线下面积(AUC)和平均精度(AP) 分数。

Parameters:
Return type:

Tuple[Tensor, Tensor]