未验证 提交 73c70654 编写于 作者: H hong 提交者: GitHub

[NewIR]Polish new ir interpreter core cache (#56035)

* update

* update cache

* fix compile error
上级 d87d8b02
...@@ -316,8 +316,6 @@ inline void RunProgramAPI( ...@@ -316,8 +316,6 @@ inline void RunProgramAPI(
VLOG(2) << "RunProgramOp use interpretercore to execute program."; VLOG(2) << "RunProgramOp use interpretercore to execute program.";
paddle::framework::Scope *global_inner_scope = out_scope_vec->front(); paddle::framework::Scope *global_inner_scope = out_scope_vec->front();
int64_t scope_i = reinterpret_cast<std::uintptr_t>(global_inner_scope);
program_id += 0x9e3779b9 + (program_id << 6) + (scope_i >> 2);
VLOG(4) << "global_inner_scope:" << global_inner_scope; VLOG(4) << "global_inner_scope:" << global_inner_scope;
...@@ -362,7 +360,8 @@ inline void RunProgramAPI( ...@@ -362,7 +360,8 @@ inline void RunProgramAPI(
paddle::framework::InterpreterCoreInfoCache::Instance(); paddle::framework::InterpreterCoreInfoCache::Instance();
std::shared_ptr<paddle::framework::InterpreterCore> interpreter_core = std::shared_ptr<paddle::framework::InterpreterCore> interpreter_core =
nullptr; nullptr;
if (!interpretercore_info_cache.Has(program_id, /*is_grad=*/false)) { if (!interpretercore_info_cache.Has(
program_id, global_inner_scope, /*is_grad=*/false)) {
paddle::platform::RecordEvent record_event( paddle::platform::RecordEvent record_event(
"create_new_interpretercore", "create_new_interpretercore",
paddle::platform::TracerEventType::UserDefined, paddle::platform::TracerEventType::UserDefined,
...@@ -420,7 +419,7 @@ inline void RunProgramAPI( ...@@ -420,7 +419,7 @@ inline void RunProgramAPI(
} }
interpretercore_info_cache.UpdateSkipEagerDeleteVars( interpretercore_info_cache.UpdateSkipEagerDeleteVars(
program_id, false, skip_eager_delete_vars); program_id, global_inner_scope, false, skip_eager_delete_vars);
VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size(); VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size();
} else { } else {
paddle::platform::RecordEvent record_event( paddle::platform::RecordEvent record_event(
...@@ -429,8 +428,8 @@ inline void RunProgramAPI( ...@@ -429,8 +428,8 @@ inline void RunProgramAPI(
1); 1);
VLOG(2) << "Get interpretercore cahce by program:" << program_id; VLOG(2) << "Get interpretercore cahce by program:" << program_id;
// Step 1. get cache interpretercore // Step 1. get cache interpretercore
auto &cached_value = auto &cached_value = interpretercore_info_cache.GetMutable(
interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/false); program_id, global_inner_scope, /*is_grad=*/false);
interpreter_core = cached_value.core_; interpreter_core = cached_value.core_;
// Step 2. update scope for cache interpretercore // Step 2. update scope for cache interpretercore
details::ShareTensorsIntoScope(x, global_inner_scope); details::ShareTensorsIntoScope(x, global_inner_scope);
...@@ -500,8 +499,6 @@ inline void RunProgramGradAPI( ...@@ -500,8 +499,6 @@ inline void RunProgramGradAPI(
paddle::framework::Scope *global_inner_scope = out_scope_vec->front(); paddle::framework::Scope *global_inner_scope = out_scope_vec->front();
int64_t program_id = PADDLE_GET_CONST(int64_t, attrs.at("program_id")); int64_t program_id = PADDLE_GET_CONST(int64_t, attrs.at("program_id"));
int64_t scope_i = reinterpret_cast<std::uintptr_t>(global_inner_scope);
program_id += 0x9e3779b9 + (program_id << 6) + (scope_i >> 2);
auto place = egr::Controller::Instance().GetExpectedPlace(); auto place = egr::Controller::Instance().GetExpectedPlace();
VLOG(2) << "RunProgramGradOp use interpretercore to execute program."; VLOG(2) << "RunProgramGradOp use interpretercore to execute program.";
...@@ -519,7 +516,8 @@ inline void RunProgramGradAPI( ...@@ -519,7 +516,8 @@ inline void RunProgramGradAPI(
paddle::framework::InterpreterCoreInfoCache::Instance(); paddle::framework::InterpreterCoreInfoCache::Instance();
std::shared_ptr<paddle::framework::InterpreterCore> interpreter_core = std::shared_ptr<paddle::framework::InterpreterCore> interpreter_core =
nullptr; nullptr;
if (!interpretercore_info_cache.Has(program_id, /*is_grad=*/true)) { if (!interpretercore_info_cache.Has(
program_id, global_inner_scope, /*is_grad=*/true)) {
paddle::platform::RecordEvent record_event( paddle::platform::RecordEvent record_event(
"create_new_interpretercore", "create_new_interpretercore",
paddle::platform::TracerEventType::UserDefined, paddle::platform::TracerEventType::UserDefined,
...@@ -555,9 +553,10 @@ inline void RunProgramGradAPI( ...@@ -555,9 +553,10 @@ inline void RunProgramGradAPI(
// share threadpool // share threadpool
// NOTE(zhiqiu): this only works interpreter_core is executed strictly // NOTE(zhiqiu): this only works interpreter_core is executed strictly
// after the related fwd_interpreter_core. // after the related fwd_interpreter_core.
if (interpretercore_info_cache.Has(program_id, false)) { if (interpretercore_info_cache.Has(program_id, global_inner_scope, false)) {
auto fwd_interpreter_core = auto fwd_interpreter_core =
interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/false) interpretercore_info_cache
.GetMutable(program_id, global_inner_scope, /*is_grad=*/false)
.core_; .core_;
interpreter_core->ShareWorkQueueFrom(fwd_interpreter_core); interpreter_core->ShareWorkQueueFrom(fwd_interpreter_core);
VLOG(4) << "Share workqueue from " << fwd_interpreter_core.get() << " to " VLOG(4) << "Share workqueue from " << fwd_interpreter_core.get() << " to "
...@@ -581,7 +580,10 @@ inline void RunProgramGradAPI( ...@@ -581,7 +580,10 @@ inline void RunProgramGradAPI(
&skip_eager_delete_vars); &skip_eager_delete_vars);
interpreter_core->SetSkipGcVars(skip_eager_delete_vars); interpreter_core->SetSkipGcVars(skip_eager_delete_vars);
interpretercore_info_cache.UpdateSkipEagerDeleteVars( interpretercore_info_cache.UpdateSkipEagerDeleteVars(
program_id, /*is_grad=*/true, skip_eager_delete_vars); program_id,
global_inner_scope,
/*is_grad=*/true,
skip_eager_delete_vars);
VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size(); VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size();
} else { } else {
paddle::platform::RecordEvent record_event( paddle::platform::RecordEvent record_event(
...@@ -589,8 +591,8 @@ inline void RunProgramGradAPI( ...@@ -589,8 +591,8 @@ inline void RunProgramGradAPI(
paddle::platform::TracerEventType::UserDefined, paddle::platform::TracerEventType::UserDefined,
1); 1);
VLOG(2) << "Get interpretercore cahce by program:" << program_id; VLOG(2) << "Get interpretercore cahce by program:" << program_id;
auto &cached_value = auto &cached_value = interpretercore_info_cache.GetMutable(
interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/true); program_id, global_inner_scope, /*is_grad=*/true);
interpreter_core = cached_value.core_; interpreter_core = cached_value.core_;
// update scope // update scope
......
...@@ -312,7 +312,7 @@ std::shared_ptr<InterpreterCore> CreateProgramInterpreterCoreInfoToCache( ...@@ -312,7 +312,7 @@ std::shared_ptr<InterpreterCore> CreateProgramInterpreterCoreInfoToCache(
place, program_desc.Block(0), scope, execution_config)); place, program_desc.Block(0), scope, execution_config));
auto &cached_value = auto &cached_value =
interpretercore_info_cache.GetMutable(program_id, is_grad); interpretercore_info_cache.GetMutable(program_id, scope, is_grad);
cached_value.core_ = core; cached_value.core_ = core;
return core; return core;
} }
...@@ -340,7 +340,7 @@ std::shared_ptr<InterpreterCore> CreateNewIRInterpreterCoreInfoToCache( ...@@ -340,7 +340,7 @@ std::shared_ptr<InterpreterCore> CreateNewIRInterpreterCoreInfoToCache(
place, {}, std::move(ir_program), scope, execution_config)); place, {}, std::move(ir_program), scope, execution_config));
auto &cached_value = auto &cached_value =
interpretercore_info_cache.GetMutable(program_id, is_grad); interpretercore_info_cache.GetMutable(program_id, scope, is_grad);
cached_value.core_ = core; cached_value.core_ = core;
return core; return core;
} }
......
...@@ -187,26 +187,33 @@ class InterpreterCoreInfoCache { ...@@ -187,26 +187,33 @@ class InterpreterCoreInfoCache {
public: public:
static InterpreterCoreInfoCache& Instance(); static InterpreterCoreInfoCache& Instance();
// Reports whether the cache holds an entry for the (program_id, scope) pair
// and whether that entry's IsAvailable(is_grad) check passes.
//
// The map key is produced by mixing the scope pointer's address into
// program_id. GetMutable() derives its key with the same expression, so the
// two must be kept in sync.
bool Has(int64_t program_id, const framework::Scope* scope, bool is_grad) {
  int64_t scope_i = reinterpret_cast<std::uintptr_t>(scope);
  program_id += 0x9e3779b9 + (program_id << 6) + (scope_i >> 2);
  // Reuse the iterator returned by find() instead of looking the key up a
  // second time through operator[].
  auto iter = info_map_.find(program_id);
  return iter != info_map_.end() && iter->second.IsAvailable(is_grad);
}
// Returns a mutable reference to the cached value selected by is_grad for the
// (program_id, scope) pair.
//
// The lookup key mixes the scope pointer's address into program_id, using the
// same expression as Has(). The entry is accessed through operator[];
// NOTE(review): presumably info_map_ is a standard associative container, in
// which case a missing key is default-inserted here -- confirm.
InterpreterCoreInfo::CacheValue& GetMutable(int64_t program_id,
                                            const framework::Scope* scope,
                                            bool is_grad) {
  const int64_t scope_addr = reinterpret_cast<std::uintptr_t>(scope);
  const int64_t cache_key =
      program_id + (0x9e3779b9 + (program_id << 6) + (scope_addr >> 2));
  auto& cache_entry = info_map_[cache_key];
  return cache_entry.GetMutable(is_grad);
}
void UpdateSkipEagerDeleteVars(int64_t program_id, void UpdateSkipEagerDeleteVars(int64_t program_id,
const framework::Scope* scope,
bool is_grad, bool is_grad,
const std::set<std::string>& skip_vars) { const std::set<std::string>& skip_vars) {
auto& cached_value = GetMutable(program_id, is_grad); auto& cached_value = GetMutable(program_id, scope, is_grad);
cached_value.skip_eager_delete_vars_ = std::move(skip_vars); cached_value.skip_eager_delete_vars_ = std::move(skip_vars);
} }
std::set<std::string>& GetSkipEagerDeleteVars(int64_t program_id, std::set<std::string>& GetSkipEagerDeleteVars(int64_t program_id,
const framework::Scope* scope,
bool is_grad) { bool is_grad) {
auto& cached_value = GetMutable(program_id, is_grad); auto& cached_value = GetMutable(program_id, scope, is_grad);
return cached_value.skip_eager_delete_vars_; return cached_value.skip_eager_delete_vars_;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册