jax.lax.cond#
- jax.lax.cond(pred, true_fun, false_fun, *operands, operand=<object object>)[源代码][源代码]#
有条件地应用
true_fun或false_fun。封装了 XLA 的 Conditional 操作符。
提供的参数类型正确,
cond()具有与以下 Python 实现等效的语义,其中pred必须是标量类型:def cond(pred, true_fun, false_fun, *operands): if pred: return true_fun(*operands) else: return false_fun(*operands)
与
jax.lax.select()相比,使用cond表示只执行两个分支中的一个(取决于编译器重写和优化)。然而,当通过vmap()转换以对一批谓词进行操作时,cond被转换为select()。- 参数:
pred – 布尔标量类型,指示应用哪个分支函数。
true_fun (Callable) – 函数(A -> B),如果
pred为 True 则应用。false_fun (Callable) – 函数 (A -> B),如果
pred为 False 则应用。operands – 操作数(A)根据
pred输入到任一分支。类型可以是标量、数组,或任何 pytree(嵌套的 Python 元组/列表/字典)。
- 返回:
根据
pred的值,值 (B) 可以是true_fun(*operands)或false_fun(*operands)中的一个。类型可以是标量、数组,或者是任何 pytree(嵌套的 Python 元组/列表/字典)。