diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/utils/recompute.py index 29c541cfb13bfb1add79b7909c4b1aab18641aab..c83c350b9e4f781cb1a17179b01d5ec5dea70397 100755 --- a/python/paddle/distributed/fleet/utils/recompute.py +++ b/python/paddle/distributed/fleet/utils/recompute.py @@ -63,7 +63,8 @@ def swith_rng_state(rng_state): class RecomputeFunction(PyLayer): @staticmethod def forward(ctx, run_function, preserve_rng_state, *args): - check_recompute_necessary(args) + if framework._dygraph_tracer()._has_grad: + check_recompute_necessary(args) # store for recomputing ctx.run_function = run_function