triton.language

编程模型

tensor

表示一个N维的值数组或指针数组。

program_id

返回当前程序实例沿给定axis轴的ID。

num_programs

返回沿给定axis轴启动的程序实例数量。

创建操作

arange

返回半开区间 [start, end) 内的连续值。

cat

连接给定的块

full

返回一个根据给定的shapedtype填充标量值的张量。

zeros

返回一个用标量值0填充的张量,根据给定的shapedtype

zeros_like

返回一个与给定张量形状和类型相同的全零张量。

cast

将张量转换为指定的 dtype

形状操作运算

broadcast

尝试将给定的两个块广播到一个共同的兼容形状。

broadcast_to

尝试将给定的张量广播到新的shape

expand_dims

通过插入新的长度为1的维度来扩展张量的形状。

interleave

沿着最后一个维度交错两个张量的值。

join

在新增的次要维度上连接给定的张量。

permute

对张量的维度进行排列。

ravel

返回x的一个连续扁平化视图。

reshape

返回一个与输入元素数量相同但具有指定形状的张量。

split

沿着张量的最后一个维度将其分成两部分,该维度的大小必须为2。

trans

对张量的维度进行排列。

view

返回一个与input具有相同元素但形状不同的张量。

线性代数运算

dot

返回两个块的矩阵乘积。

dot_scaled

返回微缩格式下两个矩阵块的矩阵乘积。

内存/指针操作

load

返回一个数据张量,其值从由pointer定义的内存位置加载:

store

将数据张量存储到由pointer定义的内存位置中。

make_block_ptr

返回指向父张量中某块的指针

advance

推进块指针

索引操作

flip

沿维度dim翻转张量x

where

根据condition条件,返回一个由xy中的元素组成的张量。

swizzle2d

将行优先size_i * size_j矩阵的索引转换为每组size_g行的列优先矩阵索引。

数学运算

abs

计算 x 的逐元素绝对值。

cdiv

计算x除以div的向上取整结果

ceil

计算 x 逐元素的上限值。

clamp

将输入张量 x 限制在 [min, max] 范围内。

cos

计算x的逐元素余弦值。

div_rn

计算xy的逐元素精确除法(根据IEEE标准四舍五入)。

erf

计算x的逐元素误差函数。

exp

计算x的逐元素指数。

exp2

计算x逐元素的以2为底的指数。

fdiv

计算xy的逐元素快速除法。

floor

计算x逐元素的下取整。

fma

计算xyz的逐元素融合乘加运算。

log

计算x逐元素的自然对数。

log2

计算x的逐元素对数(以2为底)。

maximum

计算xy的逐元素最大值。

minimum

计算xy逐元素的最小值。

rsqrt

计算x的逐元素平方根倒数。

sigmoid

计算x的逐元素sigmoid函数值。

sin

计算x的逐元素正弦值。

softmax

计算x的逐元素softmax。

sqrt

计算x的逐元素快速平方根。

sqrt_rn

计算x逐元素的精确平方根(根据IEEE标准四舍五入到最近值)。

umulhi

计算xy的2N位乘积中每个元素最高有效的N位。

归约操作

argmax

返回沿指定axis轴方向上input张量中所有元素的最大索引

argmin

返回沿指定axis方向上input张量中所有元素的最小索引

max

返回沿指定axis方向上input张量中所有元素的最大值

min

返回input张量中沿指定axis的所有元素的最小值

reduce

将combine_fn应用于input张量中沿指定axis的所有元素

sum

返回input张量沿指定axis轴上所有元素的总和

xor_sum

返回input张量中所有元素沿指定axis的异或和

扫描/排序操作

associative_scan

将combine_fn应用于input张量中沿指定axis的每个元素,并更新carry值

cumprod

返回input张量中所有元素沿指定axis的累积乘积

cumsum

返回input张量中所有元素沿指定axis轴的累加和

histogram

基于输入张量计算具有num_bins个分箱的直方图,分箱宽度为1且从0开始。

sort

gather

沿给定维度从张量中收集数据。

原子操作

atomic_add

pointer指定的内存位置执行原子加法操作。

atomic_and

pointer指定的内存位置执行原子逻辑与操作。

atomic_cas

pointer指定的内存位置执行原子比较并交换操作。

atomic_max

pointer指定的内存位置执行原子最大值操作。

atomic_min

pointer指定的内存位置执行原子最小值操作。

atomic_or

pointer指定的内存位置执行原子逻辑或操作。

atomic_xchg

pointer指定的内存位置执行原子交换操作。

atomic_xor

pointer指定的内存位置执行原子逻辑异或操作。

随机数生成

randint4x

给定一个seed标量和一个offset块,返回四个随机int32块。

randint

给定一个seed标量和一个offset块,返回一个随机的int32块。

rand

给定一个seed标量和一个offset块,返回一个在\(U(0, 1)\)范围内的随机float32块。

randn

给定一个seed标量和一个offset块,返回一个在\(\mathcal{N}(0, 1)\)范围内的随机float32块。

迭代器

range

一个永远向上计数的迭代器。

static_range

一个永远向上计数的迭代器。

内联汇编

inline_asm_elementwise

在张量上执行内联汇编。

编译器提示操作

assume

允许编译器假设 cond 为真。

debug_barrier

插入一个屏障来同步块中的所有线程。

max_constancy

让编译器知道input中的前value个值是常量。

max_contiguous

让编译器知道input中的前value个值是连续的。

multiple_of

让编译器知道input中的值都是value的倍数。

调试操作

static_print

在编译时打印值。

static_assert

在编译时断言条件。

device_print

在运行时从设备打印数值。

device_assert

在设备运行时断言条件。