torch_geometric.nn.models.TGNMemory

class TGNMemory(num_nodes: int, raw_msg_dim: int, memory_dim: int, time_dim: int, message_module: Callable, aggregator_module: Callable)[source]

Bases: Module

来自“Temporal Graph Networks for Deep Learning on Dynamic Graphs”论文的时间图网络(TGN)记忆模型。

注意

有关使用TGN的示例,请参见examples/tgn.py

Parameters:
  • num_nodes (int) – 要保存记忆的节点数量。

  • raw_msg_dim (int) – 原始消息的维度。

  • memory_dim (int) – 隐藏内存的维度。

  • time_dim (int) – 时间编码的维度。

  • message_module (torch.nn.Module) – 消息函数,用于结合源节点和目标节点的内存嵌入、原始消息和时间编码。

  • aggregator_module (torch.nn.Module) – 消息聚合函数,将发送到同一目的地的消息聚合成一个单一的表示。

forward(n_id: Tensor) Tuple[Tensor, Tensor][source]

返回所有节点 n_id 的当前内存和最后更新时间戳。

Return type:

Tuple[Tensor, Tensor]

reset_parameters()[source]

重置模块的所有可学习参数。

reset_state()[source]

将内存重置为其初始状态。

detach()[source]

将内存从梯度计算中分离。

update_state(src: Tensor, dst: Tensor, t: Tensor, raw_msg: Tensor)[source]

使用新遇到的交互更新内存 (src, dst, t, raw_msg)

train(mode: bool = True)[source]

将模块设置为训练模式。