Shortcuts

torch.autograd.function.FunctionCtx.set_materialize_grads

FunctionCtx.set_materialize_grads(value)[源代码]

设置是否具体化梯度张量。默认是 True

这应该仅在 forward() 方法内部调用

如果 True,在调用 backward()jvp() 方法之前,未定义的梯度张量将被扩展为全零张量。

Example::
>>> class SimpleFunc(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         return g1 + g2  # 不需要检查None
>>>
>>> # 我们修改SimpleFunc以处理非具体化的梯度输出
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         ctx.set_materialize_grads(False)
>>>         ctx.save_for_backward(x)
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         x, = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         if g1 is not None:  # 现在我们必须检查None
>>>             grad_input += g1
>>>         if g2 is not None:
>>>             grad_input += g2
>>>         return grad_input
>>>
>>> a = torch.tensor(1., requires_grad=True)
>>> b, _ = Func.apply(a)  # 导致g2未定义
优云智算