diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/utils/recompute.py
index 78503baf2fd5d2833e557a8d4e2f7271545aeca7..89b14258c195ca77d20e99d0a67686a5f7ddee1f 100755
--- a/python/paddle/distributed/fleet/utils/recompute.py
+++ b/python/paddle/distributed/fleet/utils/recompute.py
@@ -145,23 +145,25 @@ class RecomputeFunction(PyLayer):
 
             # run backward() with only tensor that requires grad
             forward_outputs_with_grad = []
-            backward_inputs = list(args)
+            # NOTE: In a Transformer-like network, if the user puts the attention mask into the
+            # recompute segment outputs, PyLayer forces the stop_gradient of the attention mask
+            # to be False, so the number of tensors that need grad no longer matches the number
+            # of backward inputs. backward_inputs_with_grad below is used to avoid this case.
+            backward_inputs_with_grad = []
             for i in range(len(outputs)):
                 if isinstance(outputs[i],
                               core.VarBase) and not outputs[i].stop_gradient:
                     forward_outputs_with_grad.append(outputs[i])
+                    backward_inputs_with_grad.append(args[i])
+
             if len(forward_outputs_with_grad) == 0:
                 raise RuntimeError(
                     "none of output has requires_grad=True, this recompute() is not necessary"
                 )
 
-            assert len(backward_inputs) == len(
-                forward_outputs_with_grad
-            ), "number of forward outputs is [{}], but the backward got [{}] inputs".format(
-                len(forward_outputs_with_grad), len(backward_inputs))
-
             # actually backward
-            paddle.autograd.backward(forward_outputs_with_grad, backward_inputs)
+            paddle.autograd.backward(forward_outputs_with_grad,
+                                     backward_inputs_with_grad)
 
             grads = list(inp._grad_ivar() for inp in detached_inputs
                          if isinstance(inp, core.VarBase))
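
For context, a minimal sketch of the scenario the NOTE describes: a recompute segment whose outputs include the attention mask. The `AttnBlock` module, its shapes, and the tensor names are illustrative assumptions, not part of the patch; only `paddle.distributed.fleet.utils.recompute` and standard `paddle.nn` APIs are from the library.

```python
# Hypothetical sketch (not from the patch) of the case the NOTE above describes:
# a recompute segment that returns the attention mask alongside the hidden states.
import paddle
from paddle.distributed.fleet.utils import recompute


class AttnBlock(paddle.nn.Layer):
    def __init__(self, hidden):
        super().__init__()
        self.proj = paddle.nn.Linear(hidden, hidden)

    def forward(self, x, attn_mask):
        # The mask is returned unchanged; only the projected hidden
        # states actually need a gradient.
        return self.proj(x) + attn_mask, attn_mask


block = AttnBlock(16)
x = paddle.randn([4, 16])
x.stop_gradient = False
mask = paddle.zeros([4, 16])  # stop_gradient=True by default: no grad needed

# PyLayer may force the returned mask's stop_gradient to False, so the grads
# arriving in backward() (args) can outnumber the recomputed outputs that need
# grad. The patch aligns them by index instead of asserting equal lengths.
out, out_mask = recompute(block, x, mask)
out.sum().backward()
```

With the pre-patch code, a run like this could trip the removed length assert; after the patch, `backward_inputs_with_grad` keeps the incoming grads index-aligned with `forward_outputs_with_grad`.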