torch_geometric.nn.models.TGNMemory
- class TGNMemory(num_nodes: int, raw_msg_dim: int, memory_dim: int, time_dim: int, message_module: Callable, aggregator_module: Callable)[source]
Bases:
Module来自“Temporal Graph Networks for Deep Learning on Dynamic Graphs”论文的时间图网络(TGN)记忆模型。
注意
有关使用TGN的示例,请参见examples/tgn.py。
- Parameters:
num_nodes (int) – 要保存记忆的节点数量。
raw_msg_dim (int) – 原始消息的维度。
memory_dim (int) – 隐藏内存的维度。
time_dim (int) – 时间编码的维度。
message_module (torch.nn.Module) – 消息函数,用于结合源节点和目标节点的内存嵌入、原始消息和时间编码。
aggregator_module (torch.nn.Module) – 消息聚合函数,将发送到同一目的地的消息聚合成一个单一的表示。