torch_geometric.nn.aggr.GraphMultisetTransformer
- class GraphMultisetTransformer(channels: int, k: int, num_encoder_blocks: int = 1, heads: int = 1, layer_norm: bool = False, dropout: float = 0.0)[source]
基础类:
Aggregation来自“Accurate Learning of Graph Representations with Graph Multiset Pooling”论文的图多集变换器池化操作符。
GraphMultisetTransformer通过基于注意力的池化将元素聚合为 \(k\) 个代表性元素,通过num_encoder_blocks自注意力块计算它们之间的交互, 最后通过基于注意力的池化将代表性元素聚合成一个单一的簇。注意
GraphMultisetTransformer需要排序后的索引index作为输入。具体来说,如果你将此聚合作为MessagePassing的一部分使用,请确保edge_index按目标节点排序,可以通过手动使用sort_edge_index()排序边索引,或者通过调用torch_geometric.data.Data.sort()来实现。- Parameters:
- forward(x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2, max_num_elements: Optional[int] = None) Tensor[source]
前向传播。
- Parameters:
x (torch.Tensor) – 源张量。
index (torch.Tensor, optional) – 用于应用聚合的元素的索引。 必须定义
index或ptr中的一个。 (默认值:None)ptr (torch.Tensor, optional) – 如果提供,将基于CSR表示中的排序输入计算聚合。 必须定义
index或ptr中的一个。 (默认:None)dim (int, optional) – 聚合的维度。 (default:
-2)max_num_elements (
Optional[int], 默认值:None) – (int, 可选): 单个聚合组中的最大元素数量。(默认值:None)
- Return type: