torch_geometric.nn.aggr.QuantileAggregation

class QuantileAggregation(q: Union[float, List[float]], interpolation: str = 'linear', fill_value: float = 0.0)[source]

Bases: Aggregation

一个聚合操作符,返回集合 \(\mathcal{X}\) 的特征方向的 \(q\) 分位数。

也就是说,对于每个特征 \(d\),它计算

\[\begin{split}{\mathrm{Q}_q(\mathcal{X})}_d = \begin{cases} x_{\pi_i,d} & i = q \cdot n, \\ f(x_{\pi_i,d}, x_{\pi_{i+1},d}) & i < q \cdot n < i + 1,\\ \end{cases}\end{split}\]

其中 \(x_{\pi_1,d} \le \dots \le x_{\pi_i,d} \le \dots \le x_{\pi_n,d}\) 并且 \(f(a, b)\) 是由 interpolation 定义的插值函数。

Parameters:
  • q (floatlist) – 分位数值 \(q\)。可以是一个标量或 一个标量列表,范围在 \([0, 1]\) 之间。如果传递了多个 分位数,结果将被连接起来。

  • 插值 (str) –

    如果分位数点 \(q\cdot n\) 位于两个值 \(a \le b\) 之间,则应用插值方法。可以是以下之一:

    • "lower": 返回最小值。

    • "higher": 返回最大值。

    • "midpoint": 返回两个值的平均值。

    • "nearest": 返回索引最接近分位数的值。

    • "linear": 返回两个元素的线性组合,定义为 \(f(a, b) = a + (b - a)\cdot(q\cdot n - i)\)

    (默认: "linear")

  • fill_value (float, optional) – 在给定索引找不到条目时的默认值(默认值:0.0)。

forward(x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) Tensor[source]

前向传播。

Parameters:
  • x (torch.Tensor) – The source tensor.

  • index (torch.Tensor, optional) – The indices of elements for applying the aggregation. One of index or ptr must be defined. (default: None)

  • ptr (torch.Tensor, optional) – If given, computes the aggregation based on sorted inputs in CSR representation. One of index or ptr must be defined. (default: None)

  • dim_size (int, optional) – The size of the output tensor at dimension dim after aggregation. (default: None)

  • dim (int, optional) – The dimension in which to aggregate. (default: -2)

  • max_num_elements – (int, optional): The maximum number of elements within a single aggregation group. (default: None)

Return type:

Tensor