diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index b30a9806eb19ee12d2a70afe3ca806224b0f75d6..f1eccb351ef68697c748814efb2987041b0da8d9 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -321,7 +321,8 @@ std::vector> Executor::Prepare( } void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, - bool create_local_scope, bool create_vars) { + bool create_local_scope, bool create_vars, + bool keep_kids) { Scope* local_scope = scope; if (create_vars) { if (create_local_scope) { @@ -344,12 +345,20 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, } } platform::DeviceContextPool::Instance().Get(place_)->Wait(); - if (create_vars && create_local_scope) { + if (local_scope != scope) { scope->DeleteScope(local_scope); } else { - // Delete the local scopes created in operators. - scope->DropKids(); + if (!keep_kids) { + // By default, we should delete all kid scopes after run executor because + // some operators may create local scope when running, such as while_op. + // But when while_op also create a local executor to run it's sub block, + // the sub scopes it created should not be dropped immediately, because + // while_grad_op will use some variables created during while_op run, so + // we need to keep the kids and wait for the outer executor to drop them. + scope->DropKids(); + } } + if (FLAGS_benchmark) { VLOG(2) << "-------------------------------------------------------"; VLOG(2) << "Memory used after deleting local scope: " diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 67a0761dac2a9adcdd0ce2b218c4aa505d688d56..3aa5ffef69cd29681f248e915575c5715ad0d3fa 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -78,7 +78,7 @@ class Executor { void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, bool create_local_scope = true, - bool create_vars = true); + bool create_vars = true, bool keep_kids = false); void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, std::map* feed_targets,