From 28795771408a6dcd757ed367d348fb0ead5ab507 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 2 Mar 2022 16:40:05 +0800 Subject: [PATCH] run recompute's real backward with amp disabled (#40042) --- python/paddle/distributed/fleet/utils/recompute.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/utils/recompute.py index dccd7f6205..4ccb48ef72 100755 --- a/python/paddle/distributed/fleet/utils/recompute.py +++ b/python/paddle/distributed/fleet/utils/recompute.py @@ -182,9 +182,10 @@ class RecomputeFunction(PyLayer): "none of output has requires_grad=True, this recompute() is not necessary" ) - # actually backward - paddle.autograd.backward(forward_outputs_with_grad, - backward_inputs_with_grad) + # actually backward + with paddle.amp.auto_cast(enable=False): + paddle.autograd.backward(forward_outputs_with_grad, + backward_inputs_with_grad) grads = list(inp._grad_ivar() for inp in detached_inputs if isinstance(inp, core.VarBase)) -- GitLab