torch_geometric.nn.aggr.MultiAggregation
- class MultiAggregation(aggrs: List[Union[Aggregation, str]], aggrs_kwargs: Optional[List[Dict[str, Any]]] = None, mode: Optional[str] = 'cat', mode_kwargs: Optional[Dict[str, Any]] = None)[source]
Bases:
Aggregation执行一个或多个聚合器的聚合并组合聚合结果,如“Principal Neighbourhood Aggregation for Graph Nets”和“Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions”论文中所述。
- Parameters:
aggrs (list) – 要使用的聚合方案列表。
aggrs_kwargs (dict, optional) – 在自动解析的情况下传递给相应聚合函数的参数。(默认值:
None)mode (str, optional) – 用于组合多个聚合结果的组合模式 (
"cat","proj","sum","mean","max","min","logsumexp","std","var","attn"). (默认:"cat")mode_kwargs (dict, optional) – 传递给组合
mode的参数。当使用"proj"或"attn"作为组合mode时,需要分别指定in_channels(int或tuple)和out_channels(int),用于指定从各自的聚合输出中组合的每个输入样本的大小以及组合后每个输出样本的大小。当使用"attn"模式时,需要指定num_heads(int)以指定并行注意力头的数量。(默认值:None)
- forward(x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) Tensor[source]
前向传播。
- Parameters:
x (torch.Tensor) – The source tensor.
index (torch.Tensor, optional) – The indices of elements for applying the aggregation. One of
indexorptrmust be defined. (default:None)ptr (torch.Tensor, optional) – If given, computes the aggregation based on sorted inputs in CSR representation. One of
indexorptrmust be defined. (default:None)dim_size (int, optional) – The size of the output tensor at dimension
dimafter aggregation. (default:None)dim (int, optional) – The dimension in which to aggregate. (default:
-2)max_num_elements – (int, optional): The maximum number of elements within a single aggregation group. (default:
None)
- Return type: