jax.lax.select_n#
- jax.lax.select_n(which, *cases)[源代码][源代码]#
从多个案例中选择数组值。
泛化了 XLA 的 Select 操作符。与 XLA 的版本不同,该操作符是可变参数的,并且可以使用整数 pred 从多个情况中进行选择。
- 参数:
which (ArrayLike) – 确定应返回哪种情况。必须是一个包含布尔值或整数值的数组。可以是标量或与
cases形状匹配。对于每个数组元素,which的值决定了cases中的哪一个被采用。which必须在范围[0 .. len(cases))内;对于该范围之外的值,行为由实现定义。*cases (ArrayLike) – 一个非空数组案例列表。所有数组必须具有相同的dtypes和相同的形状。
- 返回:
一个形状和数据类型与案例相等的数组,其值根据
which选择。- 返回类型: