From 8232da7c7bd29aa891cb345de3c81c39fd1b971a Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 21 Sep 2022 11:09:29 +0800 Subject: [PATCH] share threadpool of executor in dy2static (#46281) --- .../eager/to_static/run_program_op_node.h | 53 ++++++++++++------- .../framework/new_executor/interpretercore.cc | 2 +- 2 files changed, 34 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h index b5f0278e2d0..7c544cb0c14 100644 --- a/paddle/fluid/eager/to_static/run_program_op_node.h +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -530,16 +530,33 @@ inline void RunProgramGradAPI( auto &interpretercore_info_cache = paddle::framework::InterpreterCoreInfoCache::Instance(); + std::shared_ptr interpreter_core = + nullptr; if (!interpretercore_info_cache.Has(program_id, /*is_grad=*/true)) { VLOG(2) << "No interpretercore cahce, so create a new interpretercore"; details::ShareTensorsIntoScope(out_grad, global_inner_scope); - auto interpreter_core = - paddle::framework::CreateInterpreterCoreInfoToCache( - *backward_program, - place, - /*is_grad=*/true, - program_id, - global_inner_scope); + interpreter_core = paddle::framework::CreateInterpreterCoreInfoToCache( + *backward_program, + place, + /*is_grad=*/true, + program_id, + global_inner_scope); + + // share threadpool + // NOTE(zhiqiu): this only works interpreter_core is executed strictly + // after the related fwd_interpreter_core. + PADDLE_ENFORCE_EQ( + interpretercore_info_cache.Has(program_id, false), + true, + paddle::platform::errors::NotFound( + "The forward interpretercore of program %d is not found", + program_id)); + auto fwd_interpreter_core = + interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/false) + .core_; + interpreter_core->ShareWorkQueueFrom(fwd_interpreter_core); + VLOG(4) << "Share workqueue from " << fwd_interpreter_core.get() << " to " + << interpreter_core.get(); // get all eager gc vars std::set skip_eager_delete_vars; @@ -552,17 +569,12 @@ inline void RunProgramGradAPI( interpretercore_info_cache.UpdateSkipEagerDeleteVars( program_id, /*is_grad=*/true, skip_eager_delete_vars); VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size(); - if (backward_global_block->OpSize() > 0) { - // Debug info: scope info when run end - VLOG(3) << paddle::framework::GenScopeTreeDebugInfo( - out_scope_vec->front()); - interpreter_core->Run({}); - } } else { VLOG(2) << "Get interpretercore cahce by program:" << program_id; auto &cached_value = interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/true); - auto &interpreter_core = cached_value.core_; + interpreter_core = cached_value.core_; + // update scope details::ShareTensorsIntoScope(out_grad, global_inner_scope); if (interpreter_core->GetVariableScope()->GetMutableScope() != @@ -572,14 +584,15 @@ inline void RunProgramGradAPI( global_inner_scope); interpreter_core->reset_scope(global_inner_scope); } + } - if (backward_global_block->OpSize() > 0) { - // Debug info: scope info when run end - VLOG(3) << paddle::framework::GenScopeTreeDebugInfo( - out_scope_vec->front()); - interpreter_core->Run({}); - } + if (backward_global_block->OpSize() > 0) { + // Debug info: scope info when run end + VLOG(3) << paddle::framework::GenScopeTreeDebugInfo( + out_scope_vec->front()); + interpreter_core->Run({}); } + // Step 4. get outputs details::ShareTensorsFromScopeWithPartialBlock(x_grad, *forward_global_block, diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index c379e135b16..854f2da0d22 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -284,7 +284,7 @@ void InterpreterCore::reset_scope(Scope* new_scope) { void InterpreterCore::ShareWorkQueueFrom(std::shared_ptr src) { async_work_queue_ = src->GetWorkQueue(); - VLOG(8) << "Share AsyncWorkQueue from InterpreterCore(" << &src + VLOG(8) << "Share AsyncWorkQueue from InterpreterCore(" << src.get() << ") to InterpreterCore(" << this << ")"; } -- GitLab