triton.language.expand_dims

triton.language.expand_dims(input, axis)

通过插入新的长度为1的维度来扩展张量的形状。

轴索引是相对于结果张量的,因此对于每个轴,result.shape[axis]的值将为1。

Parameters:
  • input (tl.tensor) – 输入张量。

  • axis (int | Sequence[int]) - 要添加新轴的索引

该函数也可以作为成员函数在tensor上调用, 使用x.expand_dims(...)而非 expand_dims(x, ...)的形式。