torch_geometric.nn.models.GAE
- class GAE(encoder: Module, decoder: Optional[Module] = None)[source]
Bases:
Module图自动编码器模型来自 “变分图自动编码器” 论文,基于用户定义的编码器和解码器模型。
- Parameters:
encoder (torch.nn.Module) – 编码器模块。
decoder (torch.nn.Module, optional) – 解码器模块。如果设置为
None,将默认使用torch_geometric.nn.models.InnerProductDecoder。 (默认值:None)
- recon_loss(z: Tensor, pos_edge_index: Tensor, neg_edge_index: Optional[Tensor] = None) Tensor[source]
给定潜在变量
z,计算正边pos_edge_index和负采样边的二元交叉熵损失。- Parameters:
z (torch.Tensor) – 潜在空间 \(\mathbf{Z}\)。
pos_edge_index (torch.Tensor) – 用于训练的正向边。
neg_edge_index (torch.Tensor, optional) – 用于训练的负边。如果未提供,则使用负采样来计算负边。(默认值:
None)
- Return type:
- test(z: Tensor, pos_edge_index: Tensor, neg_edge_index: Tensor) Tuple[Tensor, Tensor][source]
给定潜在变量
z,正边pos_edge_index和负边neg_edge_index, 计算ROC曲线下面积(AUC)和平均精度(AP) 分数。- Parameters:
z (torch.Tensor) – 潜在空间 \(\mathbf{Z}\).
pos_edge_index (torch.Tensor) – 用于评估的正向边。
neg_edge_index (torch.Tensor) – 用于评估的负边。
- Return type: