From 8f5fffca0a4b50f12bf2dae374a4ad49402a686f Mon Sep 17 00:00:00 2001 From: Leo Zhao <48052473+LeoZhao-Intel@users.noreply.github.com> Date: Tue, 2 Jul 2019 22:57:44 +0800 Subject: [PATCH] rename mkldnn set/get_cur_thread_id() to set/get_cur_mkldnn_session_id() (#18453) * rename mkldnn set/get_cur_thread_id() to set/get_cur_mkldnn_session_id() test=develop * update session id definition and adjust logic for default behavior test=develop * reset logic in mkldnn reuse as most of cases work in default. test=develop --- paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc | 3 ++- paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc | 3 ++- paddle/fluid/platform/device_context.cc | 12 ++++++------ paddle/fluid/platform/device_context.h | 9 +++++++-- paddle/fluid/platform/mkldnn_reuse.h | 3 ++- 5 files changed, 19 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc index ac9164a77f8..be19293e695 100644 --- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc @@ -81,7 +81,8 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx, platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt)); platform::MKLDNNHandler::AppendKey(&key, std::to_string(multi_input[0]->format())); - if (platform::get_cur_thread_id() != -1) { + if (platform::get_cur_mkldnn_session_id() == + platform::kMKLDNNSessionID_Default) { auto tid = std::this_thread::get_id(); std::stringstream ss; ss << tid; diff --git a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc index a53471d57ea..ea0abf930e7 100644 --- a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc @@ -48,7 +48,8 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx, platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt)); platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt)); platform::MKLDNNHandler::AppendKey(&key, suffix); - if (platform::get_cur_thread_id() != -1) { + if (platform::get_cur_mkldnn_session_id() == + platform::kMKLDNNSessionID_Default) { auto tid = std::this_thread::get_id(); std::stringstream ss; ss << tid; diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 59ba3b63519..88829d7a207 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -401,12 +401,12 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place) } namespace { -// Current thread's id. -thread_local int cur_thread_id = 0; +// Current mkldnn session id. +thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default; } -void set_cur_thread_id(int tid) { cur_thread_id = tid; } -int get_cur_thread_id(void) { return cur_thread_id; } +void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; } +size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; } void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); } @@ -415,7 +415,7 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, BlobMap* pMap = p_blobmap_.get(); std::shared_ptr pBlob = nullptr; - int tid = platform::get_cur_thread_id(); + int tid = platform::get_cur_mkldnn_session_id(); std::lock_guard lock(*p_mutex_); @@ -448,7 +448,7 @@ std::shared_ptr MKLDNNDeviceContext::GetBlob( BlobMap* pMap = p_blobmap_.get(); std::shared_ptr pBlob = nullptr; - int tid = platform::get_cur_thread_id(); + int tid = platform::get_cur_mkldnn_session_id(); std::lock_guard lock(*p_mutex_); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 0da64aea429..91dc8a6887a 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -381,8 +381,13 @@ struct DefaultDeviceContextType { using KeyBlob = std::unordered_map>; using BlobMap = std::unordered_map>; -void set_cur_thread_id(int); -int get_cur_thread_id(void); +// default mkldnn session id +constexpr size_t kMKLDNNSessionID_Default = 0; +// mkldnn session id for cache clearing mode +constexpr size_t kMKLDNNSessionID_CacheClearing = -1; + +void set_cur_mkldnn_session_id(size_t); +size_t get_cur_mkldnn_session_id(void); class MKLDNNDeviceContext : public CPUDeviceContext { public: diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index e7a7fa2ca36..ff4533d9c52 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -39,7 +39,8 @@ class MKLDNNHandler { std::stringstream ss; ss << tid; key_ = key_common_ + "-t:" + ss.str(); - if (platform::get_cur_thread_id() == -1) { + if (platform::get_cur_mkldnn_session_id() != + platform::kMKLDNNSessionID_Default) { key_ = key_common_; } } -- GitLab