From ed6f1f90e9e4e910f8fefc54bdbdcbb21359c104 Mon Sep 17 00:00:00 2001 From: Haohongxiang <86215757+haohongxiang@users.noreply.github.com> Date: Tue, 14 Jun 2022 19:20:14 +0800 Subject: [PATCH] fix warning infos of recompute (#43495) --- python/paddle/distributed/fleet/utils/recompute.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/paddle/distributed/fleet/utils/recompute.py b/python/paddle/distributed/fleet/utils/recompute.py index 423536b095a..1f4439cf117 100755 --- a/python/paddle/distributed/fleet/utils/recompute.py +++ b/python/paddle/distributed/fleet/utils/recompute.py @@ -73,8 +73,6 @@ class EagerRecomputeFunction(EagerPyLayer): @staticmethod def forward(ctx, run_function, preserve_rng_state, *args): from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker - if framework._dygraph_tracer()._has_grad: - check_recompute_necessary(args) # store for recomputing ctx.run_function = run_function @@ -211,8 +209,6 @@ class RecomputeFunction(PyLayer): @staticmethod def forward(ctx, run_function, preserve_rng_state, *args): from paddle.distributed.fleet.meta_parallel.parallel_layers.random import get_rng_state_tracker - if framework._dygraph_tracer()._has_grad: - check_recompute_necessary(args) # store for recomputing ctx.run_function = run_function @@ -466,6 +462,9 @@ def recompute(function, *args, **kwargs): raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)) + if framework._dygraph_tracer()._has_grad: + check_recompute_necessary(args) + if in_dygraph_mode(): return EagerRecomputeFunction.apply(function, preserve, *args) else: -- GitLab