From b68c4a1e6dc442a248f8f51650249f0559dbd5bd Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Thu, 27 Oct 2022 19:18:04 +0800 Subject: [PATCH] [Dy2St]Fix abnormal growth of memory in train mode and no_grad for Dy2St (#47398) Fix abnormal growth of memory in train mode and no_grad for Dy2St --- paddle/fluid/eager/to_static/run_program_op_node.h | 4 ++-- python/paddle/fluid/dygraph/io.py | 13 ------------- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h index 4d6e8d9310..db3db215e2 100644 --- a/paddle/fluid/eager/to_static/run_program_op_node.h +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -403,7 +403,7 @@ inline void RunProgramAPI( VLOG(3) << paddle::framework::GenScopeTreeDebugInfo( out_scope_vec->front()); - if (is_test) { + if (is_test || !egr::Controller::Instance().HasGrad()) { VLOG(4) << "is test, set this scope can reused"; global_inner_scope->SetCanReuesd(true); details::GcScope(global_inner_scope); @@ -481,7 +481,7 @@ inline void RunProgramAPI( // Debug info: scope info when run end VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front()); // Step 5. Drop all children scopes while testing. - if (is_test) { + if (is_test || !egr::Controller::Instance().HasGrad()) { out_scope_vec->front()->DropKids(); } VLOG(2) << "The number of sub scopes after forward: " diff --git a/python/paddle/fluid/dygraph/io.py b/python/paddle/fluid/dygraph/io.py index c5eb445cc6..3d85a223f6 100644 --- a/python/paddle/fluid/dygraph/io.py +++ b/python/paddle/fluid/dygraph/io.py @@ -1058,8 +1058,6 @@ def _run_dygraph(instance, input, program_holder): continue persistable_var._set_grad_type(grad_var.type()) - drop_scope_if_no_grad(instance, tmp_scope_vec) - # 3. prepare output, keep same form with inputs outs = output_vars if len(output_vars) == 1: @@ -1067,17 +1065,6 @@ def _run_dygraph(instance, input, program_holder): return outs -def drop_scope_if_no_grad(instance, scope_vec): - tracer = framework._dygraph_tracer() - scope = ( - scope_vec.value().get_scope() - if isinstance(scope_vec, (core.VarBase)) - else scope_vec[0] - ) - if (not instance._is_test) and (not tracer._has_grad): - scope.drop_kids() - - def _run_static_graph(input, program_holder, trace_program): main_program = framework.default_main_program() param_var_names = _get_persistable_var_names(trace_program) -- GitLab