From 0dff82c2ec736bbec84b6fd0782ddab8971527a7 Mon Sep 17 00:00:00 2001
From: JZ-LIANG <38102074+JZ-LIANG@users.noreply.github.com>
Date: Mon, 9 Aug 2021 11:08:32 +0800
Subject: [PATCH] Recompute: fix bug with transformer attention mask (#34664)

---
 .../paddle/distributed/fleet/utils/recompute.py  | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/utils/recompute.py
index 78503baf2fd..89b14258c19 100755
--- a/python/paddle/distributed/fleet/utils/recompute.py
+++ b/python/paddle/distributed/fleet/utils/recompute.py
@@ -145,23 +145,25 @@ class RecomputeFunction(PyLayer):
 
             # run backward() with only tensor that requires grad
             forward_outputs_with_grad = []
-            backward_inputs = list(args)
+            # NOTE: In a Transformer-like network, if the user puts the attention mask into the
+            # recompute segment's outputs, PyLayer forces the mask's stop_gradient to False, so the
+            # number of outputs that require grad no longer matches the number of backward inputs.
+            # backward_inputs_with_grad is collected below to avoid this mismatch.
+            backward_inputs_with_grad = []
             for i in range(len(outputs)):
                 if isinstance(outputs[i],
                               core.VarBase) and not outputs[i].stop_gradient:
                     forward_outputs_with_grad.append(outputs[i])
+                    backward_inputs_with_grad.append(args[i])
+
             if len(forward_outputs_with_grad) == 0:
                 raise RuntimeError(
                     "none of output has requires_grad=True, this recompute() is not necessary"
                 )
 
-            assert len(backward_inputs) == len(
-                forward_outputs_with_grad
-            ), "number of forward outputs is [{}], but the backward got [{}] inputs".format(
-                len(forward_outputs_with_grad), len(backward_inputs))
-
             # actually backward
-            paddle.autograd.backward(forward_outputs_with_grad, backward_inputs)
+            paddle.autograd.backward(forward_outputs_with_grad,
+                                     backward_inputs_with_grad)
 
             grads = list(inp._grad_ivar() for inp in detached_inputs
                          if isinstance(inp, core.VarBase))
-- 
GitLab
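
For context, below is a minimal sketch of the scenario the NOTE in this patch describes: a recompute segment whose outputs include the attention mask. Only the `recompute` entry point (from the module being patched) and the behaviour described in the NOTE come from the source; the `AttnBlock` layer, tensor shapes, and mask construction are illustrative assumptions.

    import paddle
    from paddle.distributed.fleet.utils import recompute


    class AttnBlock(paddle.nn.Layer):
        """Toy block that, like many Transformer implementations, passes the
        attention mask through and returns it alongside its activations."""

        def __init__(self, hidden=64):
            super().__init__()
            self.qkv = paddle.nn.Linear(hidden, hidden)

        def forward(self, x, mask):
            return self.qkv(x), mask


    block = AttnBlock()
    x = paddle.randn([2, 8, 64])
    x.stop_gradient = False
    mask = paddle.ones([2, 1, 8, 8])  # attention mask; no gradient is needed for it

    # Returning `mask` from the recomputed segment used to make the number of
    # grad-requiring outputs differ from the saved backward inputs, tripping the
    # old assert in RecomputeFunction.backward. With this patch, each output that
    # requires grad is paired with its own output gradient instead.
    out, mask_out = recompute(block, x, mask)
    out.sum().backward()
    assert x.grad is not None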