triton.language.dot_scaled

triton.language.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, fast_math=False, lhs_k_pack=True, rhs_k_pack=True, out_dtype=triton.language.float32)

返回微缩格式下两个矩阵块的矩阵乘积。

lhs和rhs使用此处描述的微缩放格式: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

软件模拟使得能够针对不支持原生微缩放操作的硬件架构进行目标适配。目前对于这种情况,微缩放的左侧/右侧操作数会预先上转为bf16元素类型以进行点积计算,但有一个例外:特别针对AMD CDNA3架构,如果其中一个输入是fp16元素类型,则另一个输入也会被上转为fp16元素类型而非其他类型。此行为是实验性的,未来可能会发生变化。

Parameters:
  • lhs (表示fp4、fp8或bf16元素的2D张量。Fp4元素被打包到uint8输入中,第一个元素位于低位。Fp8存储为uint8或对应的fp8类型。) – 第一个要相乘的张量。

  • lhs_scale (以uint8张量表示的e8m0类型。) – lhs张量的缩放因子。

  • lhs_format (str) – lhs张量的格式。可用格式:{e2m1, e4m3, e5m2, bf16, fp16}。

  • rhs (表示fp4fp8bf16元素的2D张量。Fp4元素被打包到uint8输入中,第一个元素位于低位。Fp8存储为uint8对应的fp8类型。) – 要相乘的第二个张量。

  • rhs_scale (以uint8张量表示的e8m0类型。) – rhs张量的缩放因子。

  • rhs_format (str) - rhs张量的格式。可用格式:{e2m1, e4m3, e5m2, bf16, fp16}。

  • acc – 累加器张量。如果不为None,结果将添加到此张量中。

  • lhs_k_pack (bool, optional) – 如果为false,则lhs张量将沿M维度打包为uint8类型。

  • rhs_k_pack (bool, optional) – 如果为false,则rhs张量会沿N维度打包为uint8类型。