From 2005b98b4b411745062c8a38285b0f203973b5d6 Mon Sep 17 00:00:00 2001 From: Guoxia Wang Date: Tue, 21 Dec 2021 16:13:02 +0800 Subject: [PATCH] fix recompute no grad warning (#38293) --- python/paddle/distributed/fleet/utils/recompute.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/utils/recompute.py index 29c541cfb13..c83c350b9e4 100755 --- a/python/paddle/distributed/fleet/utils/recompute.py +++ b/python/paddle/distributed/fleet/utils/recompute.py @@ -63,7 +63,8 @@ def swith_rng_state(rng_state): class RecomputeFunction(PyLayer): @staticmethod def forward(ctx, run_function, preserve_rng_state, *args): - check_recompute_necessary(args) + if framework._dygraph_tracer()._has_grad: + check_recompute_necessary(args) # store for recomputing ctx.run_function = run_function -- GitLab