torch.autograd.Function.forward¶
- static Function.forward(ctx, *args, **kwargs)¶
定义自定义 autograd 函数的 forward 方法。
此函数应由所有子类重写。 定义forward有两种方式:
用法 1(组合正向和上下文):
@staticmethod def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: pass
它必须接受一个上下文 ctx 作为第一个参数,后面可以跟随任意数量的参数(张量或其他类型)。
参见 组合或单独的 forward() 和 setup_context() 了解更多详情
用法 2(分离的前向和上下文):
@staticmethod def forward(*args: Any, **kwargs: Any) -> Any: pass @staticmethod def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: pass
forward 不再接受 ctx 参数。
相反,您还必须重写
torch.autograd.Function.setup_context()静态方法以处理设置ctx对象。output是前向传播的输出,inputs是前向传播输入的元组。参见 扩展 torch.autograd 了解更多详情
上下文可以用于存储任意数据,这些数据可以在反向传播过程中被检索。张量不应直接存储在ctx上(尽管目前出于向后兼容性考虑,这一点并未强制执行)。相反,如果张量旨在用于
backward(等效地,vjp),则应使用ctx.save_for_backward()保存;如果张量旨在用于jvp,则应使用ctx.save_for_forward()保存。- Return type