Shortcuts

torch.cond

torch.cond(pred, true_fn, false_fn, operands)

有条件地应用true_fnfalse_fn

警告

torch.cond 是 PyTorch 中的一个原型功能。它对输入和输出类型的支持有限,并且目前不支持训练。请期待未来版本的 PyTorch 中更稳定的实现。 了解更多关于功能分类的信息:https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

cond 是一个结构化的控制流操作符。也就是说,它类似于 Python 的 if 语句, 但对 true_fnfalse_fnoperands 有特定的限制,使其能够 使用 torch.compile 和 torch.export 进行捕获。

假设 cond 的参数约束得到满足,cond 等价于以下内容:

def cond(pred, true_branch, false_branch, operands):
    if pred:
        return true_branch(*operands)
    else:
        return false_branch(*operands)
Parameters
  • pred (Union[bool, torch.Tensor]) – 一个布尔表达式或一个包含单个元素的张量,指示应用哪个分支函数。

  • true_fn (可调用函数) – 一个在正在追踪的范围内可调用的函数 (a -> b)。

  • false_fn (可调用函数) – 一个在正在追踪的范围内可调用的函数 (a -> b)。真分支和假分支必须具有一致的输入和输出,这意味着输入必须相同,输出必须具有相同的类型和形状。

  • 操作数元组可能嵌套的字典/列表/元组torch.Tensor)—— 输入到真/假函数的元组。

示例:

def true_fn(x: torch.Tensor):
    return x.cos()
def false_fn(x: torch.Tensor):
    return x.sin()
return cond(x.shape[0] > 4, true_fn, false_fn, (x,))
Restrictions:
  • 条件语句(又名 pred)必须满足以下约束之一:

    • 它是一个只有一个元素且数据类型为torch.bool的torch.Tensor

    • 这是一个布尔表达式,例如 x.shape[0] > 10x.dim() > 1 and x.shape[1] > 10

  • 分支函数(又名 true_fn/false_fn)必须满足以下所有约束条件:

    • 函数签名必须与操作数匹配。

    • 该函数必须返回具有相同元数据的张量,例如形状、数据类型等。

    • 函数不能对输入或全局变量进行原地修改。 (注意:对于中间结果,允许在分支中使用诸如 add_ 的原地张量操作)

警告

时间限制:

  • cond 目前仅支持推理。自动微分将在未来得到支持。

  • 分支的输出必须是一个单一的Tensor。未来将支持张量的Pytree。

优云智算