From 73c706546a2099e974a5a5bceb460cb2aecf502b Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Tue, 8 Aug 2023 16:24:37 +0800 Subject: [PATCH] [NewIR]Polish new ir interpreter core cache (#56035) * update * update cache * fix compile error --- .../eager/to_static/run_program_op_node.h | 30 ++++++++++--------- paddle/fluid/framework/executor_cache.cc | 4 +-- paddle/fluid/framework/executor_cache.h | 13 ++++++-- 3 files changed, 28 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h index a166792d8b8..9638022eb81 100644 --- a/paddle/fluid/eager/to_static/run_program_op_node.h +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -316,8 +316,6 @@ inline void RunProgramAPI( VLOG(2) << "RunProgramOp use interpretercore to execute program."; paddle::framework::Scope *global_inner_scope = out_scope_vec->front(); - int64_t scope_i = reinterpret_cast(global_inner_scope); - program_id += 0x9e3779b9 + (program_id << 6) + (scope_i >> 2); VLOG(4) << "global_inner_scope:" << global_inner_scope; @@ -362,7 +360,8 @@ inline void RunProgramAPI( paddle::framework::InterpreterCoreInfoCache::Instance(); std::shared_ptr interpreter_core = nullptr; - if (!interpretercore_info_cache.Has(program_id, /*is_grad=*/false)) { + if (!interpretercore_info_cache.Has( + program_id, global_inner_scope, /*is_grad=*/false)) { paddle::platform::RecordEvent record_event( "create_new_interpretercore", paddle::platform::TracerEventType::UserDefined, @@ -420,7 +419,7 @@ inline void RunProgramAPI( } interpretercore_info_cache.UpdateSkipEagerDeleteVars( - program_id, false, skip_eager_delete_vars); + program_id, global_inner_scope, false, skip_eager_delete_vars); VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size(); } else { paddle::platform::RecordEvent record_event( @@ -429,8 +428,8 @@ inline void RunProgramAPI( 1); VLOG(2) << "Get interpretercore cahce by program:" << program_id; // Step 1. get cache interpretercore - auto &cached_value = - interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/false); + auto &cached_value = interpretercore_info_cache.GetMutable( + program_id, global_inner_scope, /*is_grad=*/false); interpreter_core = cached_value.core_; // Step 2. update scope for cache interpretercore details::ShareTensorsIntoScope(x, global_inner_scope); @@ -500,8 +499,6 @@ inline void RunProgramGradAPI( paddle::framework::Scope *global_inner_scope = out_scope_vec->front(); int64_t program_id = PADDLE_GET_CONST(int64_t, attrs.at("program_id")); - int64_t scope_i = reinterpret_cast(global_inner_scope); - program_id += 0x9e3779b9 + (program_id << 6) + (scope_i >> 2); auto place = egr::Controller::Instance().GetExpectedPlace(); VLOG(2) << "RunProgramGradOp use interpretercore to execute program."; @@ -519,7 +516,8 @@ inline void RunProgramGradAPI( paddle::framework::InterpreterCoreInfoCache::Instance(); std::shared_ptr interpreter_core = nullptr; - if (!interpretercore_info_cache.Has(program_id, /*is_grad=*/true)) { + if (!interpretercore_info_cache.Has( + program_id, global_inner_scope, /*is_grad=*/true)) { paddle::platform::RecordEvent record_event( "create_new_interpretercore", paddle::platform::TracerEventType::UserDefined, @@ -555,9 +553,10 @@ inline void RunProgramGradAPI( // share threadpool // NOTE(zhiqiu): this only works interpreter_core is executed strictly // after the related fwd_interpreter_core. - if (interpretercore_info_cache.Has(program_id, false)) { + if (interpretercore_info_cache.Has(program_id, global_inner_scope, false)) { auto fwd_interpreter_core = - interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/false) + interpretercore_info_cache + .GetMutable(program_id, global_inner_scope, /*is_grad=*/false) .core_; interpreter_core->ShareWorkQueueFrom(fwd_interpreter_core); VLOG(4) << "Share workqueue from " << fwd_interpreter_core.get() << " to " @@ -581,7 +580,10 @@ inline void RunProgramGradAPI( &skip_eager_delete_vars); interpreter_core->SetSkipGcVars(skip_eager_delete_vars); interpretercore_info_cache.UpdateSkipEagerDeleteVars( - program_id, /*is_grad=*/true, skip_eager_delete_vars); + program_id, + global_inner_scope, + /*is_grad=*/true, + skip_eager_delete_vars); VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size(); } else { paddle::platform::RecordEvent record_event( @@ -589,8 +591,8 @@ inline void RunProgramGradAPI( paddle::platform::TracerEventType::UserDefined, 1); VLOG(2) << "Get interpretercore cahce by program:" << program_id; - auto &cached_value = - interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/true); + auto &cached_value = interpretercore_info_cache.GetMutable( + program_id, global_inner_scope, /*is_grad=*/true); interpreter_core = cached_value.core_; // update scope diff --git a/paddle/fluid/framework/executor_cache.cc b/paddle/fluid/framework/executor_cache.cc index 4d39c1a533d..67646dec6db 100644 --- a/paddle/fluid/framework/executor_cache.cc +++ b/paddle/fluid/framework/executor_cache.cc @@ -312,7 +312,7 @@ std::shared_ptr CreateProgramInterpreterCoreInfoToCache( place, program_desc.Block(0), scope, execution_config)); auto &cached_value = - interpretercore_info_cache.GetMutable(program_id, is_grad); + interpretercore_info_cache.GetMutable(program_id, scope, is_grad); cached_value.core_ = core; return core; } @@ -340,7 +340,7 @@ std::shared_ptr CreateNewIRInterpreterCoreInfoToCache( place, {}, std::move(ir_program), scope, execution_config)); auto &cached_value = - interpretercore_info_cache.GetMutable(program_id, is_grad); + interpretercore_info_cache.GetMutable(program_id, scope, is_grad); cached_value.core_ = core; return core; } diff --git a/paddle/fluid/framework/executor_cache.h b/paddle/fluid/framework/executor_cache.h index d6610a89500..3a5930c0c47 100644 --- a/paddle/fluid/framework/executor_cache.h +++ b/paddle/fluid/framework/executor_cache.h @@ -187,26 +187,33 @@ class InterpreterCoreInfoCache { public: static InterpreterCoreInfoCache& Instance(); - bool Has(int64_t program_id, bool is_grad) { + bool Has(int64_t program_id, const framework::Scope* scope, bool is_grad) { + int64_t scope_i = reinterpret_cast(scope); + program_id += 0x9e3779b9 + (program_id << 6) + (scope_i >> 2); return info_map_.find(program_id) != info_map_.end() && info_map_[program_id].IsAvailable(is_grad); } InterpreterCoreInfo::CacheValue& GetMutable(int64_t program_id, + const framework::Scope* scope, bool is_grad) { + int64_t scope_i = reinterpret_cast(scope); + program_id += 0x9e3779b9 + (program_id << 6) + (scope_i >> 2); return info_map_[program_id].GetMutable(is_grad); } void UpdateSkipEagerDeleteVars(int64_t program_id, + const framework::Scope* scope, bool is_grad, const std::set& skip_vars) { - auto& cached_value = GetMutable(program_id, is_grad); + auto& cached_value = GetMutable(program_id, scope, is_grad); cached_value.skip_eager_delete_vars_ = std::move(skip_vars); } std::set& GetSkipEagerDeleteVars(int64_t program_id, + const framework::Scope* scope, bool is_grad) { - auto& cached_value = GetMutable(program_id, is_grad); + auto& cached_value = GetMutable(program_id, scope, is_grad); return cached_value.skip_eager_delete_vars_; } -- GitLab