triton.language.where

triton.language.where(condition, x, y)

根据condition条件,返回一个由xy中的元素组成的张量。

请注意,无论condition的值如何,xy总是会被计算。

如果想避免意外的内存操作,请改用triton.loadtriton.store中的mask参数。

xy 的形状都会被广播为 condition 的形状。 xy 必须具有相同的数据类型。

Parameters:
  • condition (Block of triton.bool) – 当为True(非零)时返回x,否则返回y。

  • x – 在条件为True的索引处选中的值。

  • y – 在条件为False的索引处选中的值。