diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/utils/recompute.py index dccd7f6205302663117c0957c19138270bf32feb..4ccb48ef72e714e1739a6e9da88374cd4ce17ed4 100755 --- a/python/paddle/distributed/fleet/utils/recompute.py +++ b/python/paddle/distributed/fleet/utils/recompute.py @@ -182,9 +182,10 @@ class RecomputeFunction(PyLayer): "none of output has requires_grad=True, this recompute() is not necessary" ) - # actually backward - paddle.autograd.backward(forward_outputs_with_grad, - backward_inputs_with_grad) + # actually backward + with paddle.amp.auto_cast(enable=False): + paddle.autograd.backward(forward_outputs_with_grad, + backward_inputs_with_grad) grads = list(inp._grad_ivar() for inp in detached_inputs if isinstance(inp, core.VarBase))