triton.language.associative_scan¶
- triton.language.associative_scan(input, axis, combine_fn, reverse=False)¶
将combine_fn应用于
input张量中沿指定axis的每个元素,并更新carry值- Parameters:
input (Tensor) – 输入张量,或张量元组
axis (int) - 指定进行缩减操作的维度
combine_fn (Callable) - 一个用于合并两组标量张量的函数(必须用@triton.jit标记)
reverse (bool) - 是否沿轴以反向方向应用关联扫描
该函数也可以作为成员函数在
tensor上调用, 使用x.associative_scan(...)而非associative_scan(x, ...)的形式。