注意力边缘权重

class AttentionEdgeWeighting(message_dim: int, num_heads: int = 8, dropout: float = 0.1)[源代码]

基础:EdgeWeighting

通过注意力机制进行消息加权。

初始化模块。

Parameters:
  • message_dim (int) – >0 消息维度。必须能被num_heads整除 .. todo:: 改为乘法而不是除法,以便更容易使用

  • num_heads (int) – >0 注意力头的数量

  • dropout (float) – 注意力机制的dropout

Raises:

ValueError – 如果 message_dim 不能被 num_heads 整除

属性摘要

needs_message

边缘加权是否需要访问消息

方法总结

forward(source, target[, message, x_e])

计算边的权重。

属性文档

needs_message: ClassVar[bool] = True

边缘加权是否需要访问消息

方法文档

forward(source: Tensor, target: Tensor, message: Tensor | None = None, x_e: Tensor | None = None) Tensor[来源]

计算边的权重。

Parameters:
  • source (Tensor) – 形状: (num_edges,) 源索引。

  • target (Tensor) – 形状: (num_edges,) 目标索引。

  • 消息 (Tensor | None) – 形状 (num_edges, dim) 实际要加权的消息

  • x_e (Tensor | None) – 形状 (num_nodes, dim) 节点状态直到加权点

Returns:

形状: (num_edges, dim) 使用边权重加权的消息。

Return type:

Tensor