diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc index ac9164a77f893caed3f45dcb39b09b5a78fb5522..be19293e6959a0a6b94ec57270dfc137cdd4ba69 100644 --- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc @@ -81,7 +81,8 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx, platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt)); platform::MKLDNNHandler::AppendKey(&key, std::to_string(multi_input[0]->format())); - if (platform::get_cur_thread_id() != -1) { + if (platform::get_cur_mkldnn_session_id() == + platform::kMKLDNNSessionID_Default) { auto tid = std::this_thread::get_id(); std::stringstream ss; ss << tid; diff --git a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc index a53471d57eab1ddb23f9da111745c743dc3dfafb..ea0abf930e7f548b93afca937c27fa8d25a35e94 100644 --- a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc @@ -48,7 +48,8 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx, platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt)); platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt)); platform::MKLDNNHandler::AppendKey(&key, suffix); - if (platform::get_cur_thread_id() != -1) { + if (platform::get_cur_mkldnn_session_id() == + platform::kMKLDNNSessionID_Default) { auto tid = std::this_thread::get_id(); std::stringstream ss; ss << tid; diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 59ba3b63519625fc74fa1a37e5eec2e72e13995a..88829d7a207bdd642ae0af415d69362195433d4b 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -401,12 +401,12 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place) } namespace { -// Current thread's id. -thread_local int cur_thread_id = 0; +// Current mkldnn session id. +thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default; } -void set_cur_thread_id(int tid) { cur_thread_id = tid; } -int get_cur_thread_id(void) { return cur_thread_id; } +void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; } +size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; } void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); } @@ -415,7 +415,7 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, BlobMap* pMap = p_blobmap_.get(); std::shared_ptr pBlob = nullptr; - int tid = platform::get_cur_thread_id(); + int tid = platform::get_cur_mkldnn_session_id(); std::lock_guard lock(*p_mutex_); @@ -448,7 +448,7 @@ std::shared_ptr MKLDNNDeviceContext::GetBlob( BlobMap* pMap = p_blobmap_.get(); std::shared_ptr pBlob = nullptr; - int tid = platform::get_cur_thread_id(); + int tid = platform::get_cur_mkldnn_session_id(); std::lock_guard lock(*p_mutex_); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 0da64aea4297d1b7df0b003d0fdae864d19102b0..91dc8a6887a5812b542407656432d2673e3fdb62 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -381,8 +381,13 @@ struct DefaultDeviceContextType { using KeyBlob = std::unordered_map>; using BlobMap = std::unordered_map>; -void set_cur_thread_id(int); -int get_cur_thread_id(void); +// default mkldnn session id +constexpr size_t kMKLDNNSessionID_Default = 0; +// mkldnn session id for cache clearing mode +constexpr size_t kMKLDNNSessionID_CacheClearing = -1; + +void set_cur_mkldnn_session_id(size_t); +size_t get_cur_mkldnn_session_id(void); class MKLDNNDeviceContext : public CPUDeviceContext { public: diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index e7a7fa2ca36071033a2338aa51d2744e0f6de707..ff4533d9c525c2f5887519edd153d90147115c6e 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -39,7 +39,8 @@ class MKLDNNHandler { std::stringstream ss; ss << tid; key_ = key_common_ + "-t:" + ss.str(); - if (platform::get_cur_thread_id() == -1) { + if (platform::get_cur_mkldnn_session_id() != + platform::kMKLDNNSessionID_Default) { key_ = key_common_; } }