Triton 语义¶
Triton 主要遵循 NumPy 的语义,只有少量例外。本文档将介绍 Triton 支持的部分数组计算特性,并说明 Triton 与 NumPy 语义存在差异的地方。
类型提升¶
类型提升发生在使用不同数据类型的张量进行运算时。对于与dunder方法相关的二元运算以及三元函数tl.where的最后两个参数,Triton会自动按照类型层次结构(数据类型集合)将输入张量转换为共同的数据类型:{bool} < {integral dypes} < {floating point dtypes}。
算法如下:
类型 如果一个张量的数据类型属于更高阶的类型,另一个张量会被提升为该类型:
(int32, bfloat16) -> bfloat16宽度 如果两个张量的数据类型属于同一类别,且其中一个具有更高的位宽,则另一个会被提升为该数据类型:
(float32, float16) -> float32优先使用float16 如果两个张量具有相同的位宽和符号属性但数据类型不同(
float16和bfloat16或不同的fp8类型),它们都会被提升为float16。(float16, bfloat16) -> float16优先使用无符号类型 否则(相同位宽但符号不同时),它们会被提升为无符号数据类型:
(int32, uint32) -> uint32
当涉及标量时,规则会有些不同。这里的标量指的是数字字面量、标记为tl.constexpr的变量或这些的组合。它们由NumPy标量表示,类型为bool、int和float。
当一个操作涉及张量和标量时:
如果标量的类型低于或等于张量,它将不参与类型提升:
(uint8, int) -> uint8如果标量属于更高类型,我们会在整数类型
int32<uint32<int64<uint64和浮点类型float32<float64中选择能容纳该标量的最低精度数据类型。然后,张量和标量都会被提升到这个数据类型:(int16, 4.0) -> float32
广播¶
广播(Broadcasting)允许对不同形状的张量进行操作,通过自动扩展它们的形状到兼容的尺寸而不需要复制数据。这遵循以下规则:
如果其中一个张量形状较短,则在左侧填充1,直到两个张量具有相同的维度数:
((3, 4), (5, 3, 4)) -> ((1, 3, 4), (5, 3, 4))如果两个维度相等,或者其中一个维度为1,则这两个维度是兼容的。维度1将被扩展以匹配另一个张量的维度。
((1, 3, 4), (5, 3, 4)) -> ((5, 3, 4), (5, 3, 4))
与NumPy的差异¶
C语言整数除法中的舍入规则 Triton中的运算符遵循C语言语义而非Python语义以提高效率。因此,int // int对于混合符号的整数实现了向零舍入的C语言规则,而不是像Python那样向负无穷舍入。出于同样的原因,取模运算符int % int(定义为a % b = a - b * (a // b))也遵循C语言语义而非Python语义。
可能令人困惑的是,当所有输入都是标量时,整数除法和模运算遵循Python语义。