diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h index 4d6e8d93107c33f57d00fc7e441217ba472db604..db3db215e2bca24b0dea48c24db07007be27a01a 100644 --- a/paddle/fluid/eager/to_static/run_program_op_node.h +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -403,7 +403,7 @@ inline void RunProgramAPI( VLOG(3) << paddle::framework::GenScopeTreeDebugInfo( out_scope_vec->front()); - if (is_test) { + if (is_test || !egr::Controller::Instance().HasGrad()) { VLOG(4) << "is test, set this scope can reused"; global_inner_scope->SetCanReuesd(true); details::GcScope(global_inner_scope); @@ -481,7 +481,7 @@ inline void RunProgramAPI( // Debug info: scope info when run end VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front()); // Step 5. Drop all children scopes while testing. - if (is_test) { + if (is_test || !egr::Controller::Instance().HasGrad()) { out_scope_vec->front()->DropKids(); } VLOG(2) << "The number of sub scopes after forward: " diff --git a/python/paddle/fluid/dygraph/io.py b/python/paddle/fluid/dygraph/io.py index c5eb445cc610d2fa3d07f239fa3f69737e7d174d..3d85a223f699682885fad5cc334c1c7360997746 100644 --- a/python/paddle/fluid/dygraph/io.py +++ b/python/paddle/fluid/dygraph/io.py @@ -1058,8 +1058,6 @@ def _run_dygraph(instance, input, program_holder): continue persistable_var._set_grad_type(grad_var.type()) - drop_scope_if_no_grad(instance, tmp_scope_vec) - # 3. prepare output, keep same form with inputs outs = output_vars if len(output_vars) == 1: @@ -1067,17 +1065,6 @@ def _run_dygraph(instance, input, program_holder): return outs -def drop_scope_if_no_grad(instance, scope_vec): - tracer = framework._dygraph_tracer() - scope = ( - scope_vec.value().get_scope() - if isinstance(scope_vec, (core.VarBase)) - else scope_vec[0] - ) - if (not instance._is_test) and (not tracer._has_grad): - scope.drop_kids() - - def _run_static_graph(input, program_holder, trace_program): main_program = framework.default_main_program() param_var_names = _get_persistable_var_names(trace_program)