torch.autograd.function.FunctionCtx.set_materialize_grads
- FunctionCtx.set_materialize_grads(value)[source]
Sets whether to materialize grad tensors. Default is True. This should be called only from inside the forward() method. If True, undefined grad tensors will be expanded to tensors full of zeros prior to calling the backward() and jvp() methods.
- Example::
>>> class SimpleFunc(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         return g1 + g2  # No check for None necessary
>>>
>>> # We modify SimpleFunc to handle non-materialized grad outputs
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         ctx.set_materialize_grads(False)
>>>         ctx.save_for_backward(x)
>>>         return x.clone(), x.clone()
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):
>>>         x, = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         if g1 is not None:  # We must now check for None
>>>             grad_input += g1
>>>         if g2 is not None:
>>>             grad_input += g2
>>>         return grad_input
>>>
>>> a = torch.tensor(1., requires_grad=True)
>>> b, _ = Func.apply(a)  # induces g2 to be undefined
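A minimal continuation sketch (not part of the original example, assuming the Func class defined above): backpropagating through only the first output b leaves the grad for the second output undefined, so with materialization disabled Func.backward() receives None for g2 yet still produces the correct gradient:

>>> b.backward()  # only b contributes; g2 arrives as None inside backward()
>>> a.grad
tensor(1.)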