diff --git a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py index 7f0d22726e59635f70e7e688bec530ddf2ee7954..2e53c6ee174e063a741e41583b7370ac5e0b8c87 100644 --- a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py +++ b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py @@ -86,7 +86,6 @@ class _HPRecomputeFunction(PyLayer): *args, **kwargs ): - check_recompute_necessary(args) # store for recomputing ctx.run_function = run_function @@ -264,6 +263,9 @@ def recompute_hybrid(ctx, function, *args, **kwargs): offload = ctx.get('offload', False) partition = ctx.get('partition', False) + if framework._dygraph_tracer()._has_grad: + check_recompute_necessary(args) + all_outputs = [] _HPRecomputeFunction.apply( function, all_outputs, mp_group, offload, partition, *args, **kwargs