torch_geometric.nn.aggr.QuantileAggregation
- class QuantileAggregation(q: Union[float, List[float]], interpolation: str = 'linear', fill_value: float = 0.0)[source]
Bases:
Aggregation一个聚合操作符,返回集合 \(\mathcal{X}\) 的特征方向的 \(q\) 分位数。
也就是说,对于每个特征 \(d\),它计算
\[\begin{split}{\mathrm{Q}_q(\mathcal{X})}_d = \begin{cases} x_{\pi_i,d} & i = q \cdot n, \\ f(x_{\pi_i,d}, x_{\pi_{i+1},d}) & i < q \cdot n < i + 1,\\ \end{cases}\end{split}\]其中 \(x_{\pi_1,d} \le \dots \le x_{\pi_i,d} \le \dots \le x_{\pi_n,d}\) 并且 \(f(a, b)\) 是由
interpolation定义的插值函数。- Parameters:
q (float 或 list) – 分位数值 \(q\)。可以是一个标量或 一个标量列表,范围在 \([0, 1]\) 之间。如果传递了多个 分位数,结果将被连接起来。
插值 (str) –
如果分位数点 \(q\cdot n\) 位于两个值 \(a \le b\) 之间,则应用插值方法。可以是以下之一:
"lower": 返回最小值。"higher": 返回最大值。"midpoint": 返回两个值的平均值。"nearest": 返回索引最接近分位数的值。"linear": 返回两个元素的线性组合,定义为 \(f(a, b) = a + (b - a)\cdot(q\cdot n - i)\)。
(默认:
"linear")fill_value (float, optional) – 在给定索引找不到条目时的默认值(默认值:
0.0)。
- forward(x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) Tensor[source]
前向传播。
- Parameters:
x (torch.Tensor) – The source tensor.
index (torch.Tensor, optional) – The indices of elements for applying the aggregation. One of
indexorptrmust be defined. (default:None)ptr (torch.Tensor, optional) – If given, computes the aggregation based on sorted inputs in CSR representation. One of
indexorptrmust be defined. (default:None)dim_size (int, optional) – The size of the output tensor at dimension
dimafter aggregation. (default:None)dim (int, optional) – The dimension in which to aggregate. (default:
-2)max_num_elements – (int, optional): The maximum number of elements within a single aggregation group. (default:
None)
- Return type: