From 5e97651e40a85a459f9129e5f618cb24c8103dc6 Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Mon, 24 Oct 2022 20:14:15 +0800 Subject: [PATCH] fix warning infos of recompute hybrid in eager mode (#47288) --- python/paddle/distributed/fleet/recompute/recompute_hybrid.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py index 7f0d22726e5..2e53c6ee174 100644 --- a/python/paddle/distributed/fleet/recompute/recompute_hybrid.py +++ b/python/paddle/distributed/fleet/recompute/recompute_hybrid.py @@ -86,7 +86,6 @@ class _HPRecomputeFunction(PyLayer): *args, **kwargs ): - check_recompute_necessary(args) # store for recomputing ctx.run_function = run_function @@ -264,6 +263,9 @@ def recompute_hybrid(ctx, function, *args, **kwargs): offload = ctx.get('offload', False) partition = ctx.get('partition', False) + if framework._dygraph_tracer()._has_grad: + check_recompute_necessary(args) + all_outputs = [] _HPRecomputeFunction.apply( function, all_outputs, mp_group, offload, partition, *args, **kwargs -- GitLab