未验证 提交 8232da7c 编写于 作者: L Leo Chen 提交者: GitHub

share threadpool of executor in dy2static (#46281)

上级 9e917a1e
......@@ -530,16 +530,33 @@ inline void RunProgramGradAPI(
auto &interpretercore_info_cache =
paddle::framework::InterpreterCoreInfoCache::Instance();
std::shared_ptr<paddle::framework::InterpreterCore> interpreter_core =
nullptr;
if (!interpretercore_info_cache.Has(program_id, /*is_grad=*/true)) {
VLOG(2) << "No interpretercore cahce, so create a new interpretercore";
details::ShareTensorsIntoScope(out_grad, global_inner_scope);
auto interpreter_core =
paddle::framework::CreateInterpreterCoreInfoToCache(
*backward_program,
place,
/*is_grad=*/true,
program_id,
global_inner_scope);
interpreter_core = paddle::framework::CreateInterpreterCoreInfoToCache(
*backward_program,
place,
/*is_grad=*/true,
program_id,
global_inner_scope);
// share threadpool
// NOTE(zhiqiu): this only works interpreter_core is executed strictly
// after the related fwd_interpreter_core.
PADDLE_ENFORCE_EQ(
interpretercore_info_cache.Has(program_id, false),
true,
paddle::platform::errors::NotFound(
"The forward interpretercore of program %d is not found",
program_id));
auto fwd_interpreter_core =
interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/false)
.core_;
interpreter_core->ShareWorkQueueFrom(fwd_interpreter_core);
VLOG(4) << "Share workqueue from " << fwd_interpreter_core.get() << " to "
<< interpreter_core.get();
// get all eager gc vars
std::set<std::string> skip_eager_delete_vars;
......@@ -552,17 +569,12 @@ inline void RunProgramGradAPI(
interpretercore_info_cache.UpdateSkipEagerDeleteVars(
program_id, /*is_grad=*/true, skip_eager_delete_vars);
VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size();
if (backward_global_block->OpSize() > 0) {
// Debug info: scope info when run end
VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(
out_scope_vec->front());
interpreter_core->Run({});
}
} else {
VLOG(2) << "Get interpretercore cahce by program:" << program_id;
auto &cached_value =
interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/true);
auto &interpreter_core = cached_value.core_;
interpreter_core = cached_value.core_;
// update scope
details::ShareTensorsIntoScope(out_grad, global_inner_scope);
if (interpreter_core->GetVariableScope()->GetMutableScope() !=
......@@ -572,14 +584,15 @@ inline void RunProgramGradAPI(
global_inner_scope);
interpreter_core->reset_scope(global_inner_scope);
}
}
if (backward_global_block->OpSize() > 0) {
// Debug info: scope info when run end
VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(
out_scope_vec->front());
interpreter_core->Run({});
}
if (backward_global_block->OpSize() > 0) {
// Debug info: scope info when run end
VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(
out_scope_vec->front());
interpreter_core->Run({});
}
// Step 4. get outputs
details::ShareTensorsFromScopeWithPartialBlock(x_grad,
*forward_global_block,
......
......@@ -284,7 +284,7 @@ void InterpreterCore::reset_scope(Scope* new_scope) {
void InterpreterCore::ShareWorkQueueFrom(std::shared_ptr<InterpreterCore> src) {
async_work_queue_ = src->GetWorkQueue();
VLOG(8) << "Share AsyncWorkQueue from InterpreterCore(" << &src
VLOG(8) << "Share AsyncWorkQueue from InterpreterCore(" << src.get()
<< ") to InterpreterCore(" << this << ")";
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册