未验证 提交 8232da7c 编写于 作者: L Leo Chen 提交者: GitHub

share threadpool of executor in dy2static (#46281)

上级 9e917a1e
...@@ -530,16 +530,33 @@ inline void RunProgramGradAPI( ...@@ -530,16 +530,33 @@ inline void RunProgramGradAPI(
auto &interpretercore_info_cache = auto &interpretercore_info_cache =
paddle::framework::InterpreterCoreInfoCache::Instance(); paddle::framework::InterpreterCoreInfoCache::Instance();
std::shared_ptr<paddle::framework::InterpreterCore> interpreter_core =
nullptr;
if (!interpretercore_info_cache.Has(program_id, /*is_grad=*/true)) { if (!interpretercore_info_cache.Has(program_id, /*is_grad=*/true)) {
VLOG(2) << "No interpretercore cahce, so create a new interpretercore"; VLOG(2) << "No interpretercore cahce, so create a new interpretercore";
details::ShareTensorsIntoScope(out_grad, global_inner_scope); details::ShareTensorsIntoScope(out_grad, global_inner_scope);
auto interpreter_core = interpreter_core = paddle::framework::CreateInterpreterCoreInfoToCache(
paddle::framework::CreateInterpreterCoreInfoToCache( *backward_program,
*backward_program, place,
place, /*is_grad=*/true,
/*is_grad=*/true, program_id,
program_id, global_inner_scope);
global_inner_scope);
// share threadpool
// NOTE(zhiqiu): this only works interpreter_core is executed strictly
// after the related fwd_interpreter_core.
PADDLE_ENFORCE_EQ(
interpretercore_info_cache.Has(program_id, false),
true,
paddle::platform::errors::NotFound(
"The forward interpretercore of program %d is not found",
program_id));
auto fwd_interpreter_core =
interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/false)
.core_;
interpreter_core->ShareWorkQueueFrom(fwd_interpreter_core);
VLOG(4) << "Share workqueue from " << fwd_interpreter_core.get() << " to "
<< interpreter_core.get();
// get all eager gc vars // get all eager gc vars
std::set<std::string> skip_eager_delete_vars; std::set<std::string> skip_eager_delete_vars;
...@@ -552,17 +569,12 @@ inline void RunProgramGradAPI( ...@@ -552,17 +569,12 @@ inline void RunProgramGradAPI(
interpretercore_info_cache.UpdateSkipEagerDeleteVars( interpretercore_info_cache.UpdateSkipEagerDeleteVars(
program_id, /*is_grad=*/true, skip_eager_delete_vars); program_id, /*is_grad=*/true, skip_eager_delete_vars);
VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size(); VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size();
if (backward_global_block->OpSize() > 0) {
// Debug info: scope info when run end
VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(
out_scope_vec->front());
interpreter_core->Run({});
}
} else { } else {
VLOG(2) << "Get interpretercore cahce by program:" << program_id; VLOG(2) << "Get interpretercore cahce by program:" << program_id;
auto &cached_value = auto &cached_value =
interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/true); interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/true);
auto &interpreter_core = cached_value.core_; interpreter_core = cached_value.core_;
// update scope // update scope
details::ShareTensorsIntoScope(out_grad, global_inner_scope); details::ShareTensorsIntoScope(out_grad, global_inner_scope);
if (interpreter_core->GetVariableScope()->GetMutableScope() != if (interpreter_core->GetVariableScope()->GetMutableScope() !=
...@@ -572,14 +584,15 @@ inline void RunProgramGradAPI( ...@@ -572,14 +584,15 @@ inline void RunProgramGradAPI(
global_inner_scope); global_inner_scope);
interpreter_core->reset_scope(global_inner_scope); interpreter_core->reset_scope(global_inner_scope);
} }
}
if (backward_global_block->OpSize() > 0) { if (backward_global_block->OpSize() > 0) {
// Debug info: scope info when run end // Debug info: scope info when run end
VLOG(3) << paddle::framework::GenScopeTreeDebugInfo( VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(
out_scope_vec->front()); out_scope_vec->front());
interpreter_core->Run({}); interpreter_core->Run({});
}
} }
// Step 4. get outputs // Step 4. get outputs
details::ShareTensorsFromScopeWithPartialBlock(x_grad, details::ShareTensorsFromScopeWithPartialBlock(x_grad,
*forward_global_block, *forward_global_block,
......
...@@ -284,7 +284,7 @@ void InterpreterCore::reset_scope(Scope* new_scope) { ...@@ -284,7 +284,7 @@ void InterpreterCore::reset_scope(Scope* new_scope) {
void InterpreterCore::ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src) { void InterpreterCore::ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src) {
async_work_queue_ = src->GetWorkQueue(); async_work_queue_ = src->GetWorkQueue();
VLOG(8) << "Share AsyncWorkQueue from InterpreterCore(" << &src VLOG(8) << "Share AsyncWorkQueue from InterpreterCore(" << src.get()
<< ") to InterpreterCore(" << this << ")"; << ") to InterpreterCore(" << this << ")";
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册