triton.language.reduce¶
- triton.language.reduce(input, axis, combine_fn, keep_dims=False)¶
将combine_fn应用于
input
张量中沿指定axis
的所有元素- Parameters:
input (Tensor) – 输入张量,或张量元组
axis (int | None) - 指定进行归约操作的维度。如果为None,则对所有维度进行归约
combine_fn (Callable) - 一个用于合并两组标量张量的函数(必须用@triton.jit标记)
keep_dims (bool) - 如果为true,则保留长度为1的缩减维度
该函数也可以作为成员函数在
tensor
上调用, 使用x.reduce(...)
而非reduce(x, ...)
。