Shortcuts

torch.autograd.function.FunctionCtx.mark_non_differentiable

FunctionCtx.mark_non_differentiable(*args)[源代码]

将输出标记为不可微分。

这最多只能调用一次,仅在 forward() 方法内部调用,并且所有参数都应该是张量输出。

这将标记输出为不需要梯度,从而提高反向传播计算的效率。你仍然需要在 backward() 中接受每个输出的梯度,但它总是一个与相应输出形状相同的零张量。

This is used e.g. for indices returned from a sort. See example::
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x):
>>>         sorted, idx = x.sort()
>>>         ctx.mark_non_differentiable(idx)
>>>         ctx.save_for_backward(x, idx)
>>>         return sorted, idx
>>>
>>>     @staticmethod
>>>     @once_differentiable
>>>     def backward(ctx, g1, g2):  # 仍然需要接受 g2
>>>         x, idx = ctx.saved_tensors
>>>         grad_input = torch.zeros_like(x)
>>>         grad_input.index_add_(0, idx, g1)
>>>         return grad_input
优云智算