未验证 提交 73c70654 编写于 作者: H hong 提交者: GitHub

[NewIR]Polish new ir interpreter core cache (#56035)

* update

* update cache

* fix compile error
上级 d87d8b02
......@@ -316,8 +316,6 @@ inline void RunProgramAPI(
VLOG(2) << "RunProgramOp use interpretercore to execute program.";
paddle::framework::Scope *global_inner_scope = out_scope_vec->front();
int64_t scope_i = reinterpret_cast<std::uintptr_t>(global_inner_scope);
program_id += 0x9e3779b9 + (program_id << 6) + (scope_i >> 2);
VLOG(4) << "global_inner_scope:" << global_inner_scope;
......@@ -362,7 +360,8 @@ inline void RunProgramAPI(
paddle::framework::InterpreterCoreInfoCache::Instance();
std::shared_ptr<paddle::framework::InterpreterCore> interpreter_core =
nullptr;
if (!interpretercore_info_cache.Has(program_id, /*is_grad=*/false)) {
if (!interpretercore_info_cache.Has(
program_id, global_inner_scope, /*is_grad=*/false)) {
paddle::platform::RecordEvent record_event(
"create_new_interpretercore",
paddle::platform::TracerEventType::UserDefined,
......@@ -420,7 +419,7 @@ inline void RunProgramAPI(
}
interpretercore_info_cache.UpdateSkipEagerDeleteVars(
program_id, false, skip_eager_delete_vars);
program_id, global_inner_scope, false, skip_eager_delete_vars);
VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size();
} else {
paddle::platform::RecordEvent record_event(
......@@ -429,8 +428,8 @@ inline void RunProgramAPI(
1);
VLOG(2) << "Get interpretercore cahce by program:" << program_id;
// Step 1. get cache interpretercore
auto &cached_value =
interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/false);
auto &cached_value = interpretercore_info_cache.GetMutable(
program_id, global_inner_scope, /*is_grad=*/false);
interpreter_core = cached_value.core_;
// Step 2. update scope for cache interpretercore
details::ShareTensorsIntoScope(x, global_inner_scope);
......@@ -500,8 +499,6 @@ inline void RunProgramGradAPI(
paddle::framework::Scope *global_inner_scope = out_scope_vec->front();
int64_t program_id = PADDLE_GET_CONST(int64_t, attrs.at("program_id"));
int64_t scope_i = reinterpret_cast<std::uintptr_t>(global_inner_scope);
program_id += 0x9e3779b9 + (program_id << 6) + (scope_i >> 2);
auto place = egr::Controller::Instance().GetExpectedPlace();
VLOG(2) << "RunProgramGradOp use interpretercore to execute program.";
......@@ -519,7 +516,8 @@ inline void RunProgramGradAPI(
paddle::framework::InterpreterCoreInfoCache::Instance();
std::shared_ptr<paddle::framework::InterpreterCore> interpreter_core =
nullptr;
if (!interpretercore_info_cache.Has(program_id, /*is_grad=*/true)) {
if (!interpretercore_info_cache.Has(
program_id, global_inner_scope, /*is_grad=*/true)) {
paddle::platform::RecordEvent record_event(
"create_new_interpretercore",
paddle::platform::TracerEventType::UserDefined,
......@@ -555,9 +553,10 @@ inline void RunProgramGradAPI(
// share threadpool
// NOTE(zhiqiu): this only works interpreter_core is executed strictly
// after the related fwd_interpreter_core.
if (interpretercore_info_cache.Has(program_id, false)) {
if (interpretercore_info_cache.Has(program_id, global_inner_scope, false)) {
auto fwd_interpreter_core =
interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/false)
interpretercore_info_cache
.GetMutable(program_id, global_inner_scope, /*is_grad=*/false)
.core_;
interpreter_core->ShareWorkQueueFrom(fwd_interpreter_core);
VLOG(4) << "Share workqueue from " << fwd_interpreter_core.get() << " to "
......@@ -581,7 +580,10 @@ inline void RunProgramGradAPI(
&skip_eager_delete_vars);
interpreter_core->SetSkipGcVars(skip_eager_delete_vars);
interpretercore_info_cache.UpdateSkipEagerDeleteVars(
program_id, /*is_grad=*/true, skip_eager_delete_vars);
program_id,
global_inner_scope,
/*is_grad=*/true,
skip_eager_delete_vars);
VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size();
} else {
paddle::platform::RecordEvent record_event(
......@@ -589,8 +591,8 @@ inline void RunProgramGradAPI(
paddle::platform::TracerEventType::UserDefined,
1);
VLOG(2) << "Get interpretercore cahce by program:" << program_id;
auto &cached_value =
interpretercore_info_cache.GetMutable(program_id, /*is_grad=*/true);
auto &cached_value = interpretercore_info_cache.GetMutable(
program_id, global_inner_scope, /*is_grad=*/true);
interpreter_core = cached_value.core_;
// update scope
......
......@@ -312,7 +312,7 @@ std::shared_ptr<InterpreterCore> CreateProgramInterpreterCoreInfoToCache(
place, program_desc.Block(0), scope, execution_config));
auto &cached_value =
interpretercore_info_cache.GetMutable(program_id, is_grad);
interpretercore_info_cache.GetMutable(program_id, scope, is_grad);
cached_value.core_ = core;
return core;
}
......@@ -340,7 +340,7 @@ std::shared_ptr<InterpreterCore> CreateNewIRInterpreterCoreInfoToCache(
place, {}, std::move(ir_program), scope, execution_config));
auto &cached_value =
interpretercore_info_cache.GetMutable(program_id, is_grad);
interpretercore_info_cache.GetMutable(program_id, scope, is_grad);
cached_value.core_ = core;
return core;
}
......
......@@ -187,26 +187,33 @@ class InterpreterCoreInfoCache {
public:
static InterpreterCoreInfoCache& Instance();
bool Has(int64_t program_id, bool is_grad) {
bool Has(int64_t program_id, const framework::Scope* scope, bool is_grad) {
int64_t scope_i = reinterpret_cast<std::uintptr_t>(scope);
program_id += 0x9e3779b9 + (program_id << 6) + (scope_i >> 2);
return info_map_.find(program_id) != info_map_.end() &&
info_map_[program_id].IsAvailable(is_grad);
}
InterpreterCoreInfo::CacheValue& GetMutable(int64_t program_id,
const framework::Scope* scope,
bool is_grad) {
int64_t scope_i = reinterpret_cast<std::uintptr_t>(scope);
program_id += 0x9e3779b9 + (program_id << 6) + (scope_i >> 2);
return info_map_[program_id].GetMutable(is_grad);
}
void UpdateSkipEagerDeleteVars(int64_t program_id,
const framework::Scope* scope,
bool is_grad,
const std::set<std::string>& skip_vars) {
auto& cached_value = GetMutable(program_id, is_grad);
auto& cached_value = GetMutable(program_id, scope, is_grad);
cached_value.skip_eager_delete_vars_ = std::move(skip_vars);
}
std::set<std::string>& GetSkipEagerDeleteVars(int64_t program_id,
const framework::Scope* scope,
bool is_grad) {
auto& cached_value = GetMutable(program_id, is_grad);
auto& cached_value = GetMutable(program_id, scope, is_grad);
return cached_value.skip_eager_delete_vars_;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册