triton.language.dot

triton.language.dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=triton.language.float32)

返回两个块的矩阵乘积。

这两个块必须都是二维或三维的,并且具有兼容的内部维度。 对于三维块,tl.dot执行批量矩阵乘积, 其中每个块的第一个维度代表批量维度。

Parameters:
  • input (2D或3D张量,标量类型为{int8, float8_e5m2, float16, bfloat16, float32}) - 第一个待相乘的张量。

  • other (2D或3D张量,标量类型为{int8, float8_e5m2, float16, bfloat16, float32}) - 要相乘的第二个张量。

  • acc (2D或3D张量,标量类型为{float16, float32, int32}) - 累加器张量。如果不为None,结果将被添加到此张量中。

  • input_precision (字符串。NVIDIA可用选项:"tf32", "tf32x3", "ieee"。默认值:"tf32"。AMD可用选项:"ieee", (仅CDNA3) "tf32"。) – 如何为f32 x f32运算使用张量核心。如果设备没有张量核心或输入不是f32类型,此选项将被忽略。对于具有张量核心的设备,默认精度为tf32。

  • allow_tf32已弃用。 如果为true,input_precision将被设置为"tf32"。 input_precisionallow_tf32中只能指定一个(即至少有一个必须是None)。