diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/utils/recompute.py
index 423536b095a402b27a0edb5a5406f4725d6db5b3..1f4439cf1171f9dd51f05d98ff969a8960b8873d 100755
--- a/python/paddle/distributed/fleet/utils/recompute.py
+++ b/python/paddle/distributed/fleet/utils/recompute.py
@@ -73,8 +73,6 @@ class EagerRecomputeFunction(EagerPyLayer):
     @staticmethod
     def forward(ctx, run_function, preserve_rng_state, *args):
         from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
-        if framework._dygraph_tracer()._has_grad:
-            check_recompute_necessary(args)
 
         # store for recomputing
         ctx.run_function = run_function
@@ -211,8 +209,6 @@ class RecomputeFunction(PyLayer):
     @staticmethod
     def forward(ctx, run_function, preserve_rng_state, *args):
         from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker
-        if framework._dygraph_tracer()._has_grad:
-            check_recompute_necessary(args)
 
         # store for recomputing
         ctx.run_function = run_function
@@ -466,6 +462,9 @@ def recompute(function, *args, **kwargs):
         raise ValueError("Unexpected keyword arguments: " + ",".join(
             arg for arg in kwargs))
 
+    if framework._dygraph_tracer()._has_grad:
+        check_recompute_necessary(args)
+
     if in_dygraph_mode():
         return EagerRecomputeFunction.apply(function, preserve, *args)
     else:
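
With this patch, `check_recompute_necessary` runs once at the public `recompute()` entry point rather than being duplicated inside both `EagerRecomputeFunction.forward` and `RecomputeFunction.forward`, so the validation happens in exactly one place regardless of which PyLayer branch is dispatched. A minimal sketch of the resulting behavior in single-card dygraph; the model, tensor shapes, and variable names below are illustrative and not taken from the patch:

```python
import paddle
from paddle.distributed.fleet.utils import recompute

# Hypothetical model for illustration only.
layer = paddle.nn.Sequential(
    paddle.nn.Linear(16, 16),
    paddle.nn.ReLU(),
    paddle.nn.Linear(16, 16),
)

x = paddle.randn([4, 16])
x.stop_gradient = False

# Gradients are being traced here (_has_grad is True), so after this patch
# check_recompute_necessary(args) runs once inside recompute() itself,
# before either PyLayer forward is invoked.
y = recompute(layer, x)
y.sum().backward()

# Under no_grad the dygraph tracer reports _has_grad == False, so the
# check (which flags the case where no input requires grad) is skipped.
with paddle.no_grad():
    y = recompute(layer, x)
```

Hoisting the check into `recompute()` also means it runs before the `in_dygraph_mode()` dispatch, so both the eager and legacy code paths are validated by the same single call.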