Shortcuts

torch.autograd.Function.forward

static Function.forward(ctx, *args, **kwargs)

定义自定义 autograd 函数的 forward 方法。

此函数应由所有子类重写。 定义forward有两种方式:

用法 1(组合正向和上下文):

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:
    pass

用法 2(分离的前向和上下文):

@staticmethod
def forward(*args: Any, **kwargs: Any) -> Any:
    pass

@staticmethod
def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
    pass
  • forward 不再接受 ctx 参数。

  • 相反,您还必须重写 torch.autograd.Function.setup_context() 静态方法以处理设置 ctx 对象。 output 是前向传播的输出,inputs 是前向传播输入的元组。

  • 参见 扩展 torch.autograd 了解更多详情

上下文可以用于存储任意数据,这些数据可以在反向传播过程中被检索。张量不应直接存储在ctx上(尽管目前出于向后兼容性考虑,这一点并未强制执行)。相反,如果张量旨在用于backward(等效地,vjp),则应使用ctx.save_for_backward()保存;如果张量旨在用于jvp,则应使用ctx.save_for_forward()保存。

Return type

任意

优云智算