diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index d40a898b859e3f90e96c89f139c0260b949310b8..b4d8c3e32a051a147ff3791260f9e3e60db47ceb 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -528,6 +528,8 @@ class MKLDNNDeviceContextThreadLocals {
     // Recently registered data_format. This is needed to
     // know for converting MKL-DNN Tensor to non MKL-DNN
     paddle::framework::DataLayout cur_paddle_data_layout;
+    std::string key_suffix;  // Key identifying current Executor
+    bool key_attach_thread_id = true;
 
     Body();
     void set_cur_mkldnn_session_id(size_t sid);
@@ -537,6 +539,12 @@ class MKLDNNDeviceContextThreadLocals {
     void set_cur_paddle_data_layout(framework::DataLayout dl);
     framework::DataLayout get_cur_paddle_data_layout(void);
     void log_lib_version(void);
+    // Suffix that CreateKey() appends to every key it builds
+    void set_key_suffix(const std::string& suffix) { key_suffix = suffix; }
+    const std::string& get_key_suffix(void) const { return key_suffix; }
+    // Controls whether ExtendKeyWithThreadInfoIfNeeded() may add "-t:<tid>"
+    void disable_tid_in_key(void) { key_attach_thread_id = false; }
+    bool is_tid_used_in_key(void) const { return key_attach_thread_id; }
   };
   MKLDNNDeviceContextThreadLocals() = default;
   MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
@@ -580,14 +588,6 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
   // Remove all entries from the blob map
   void ResetBlobMap();
 
-  // Set a suffix to be added to key
-  void SetKeySuffix(const std::string& suffix) { key_suffix_ = suffix; }
-  const std::string& GetKeySuffix(void) const { return key_suffix_; }
-
-  // Disable adding thread ID to the key
-  void DisableThreadInfoInKey(void) { key_attach_thread_id_ = false; }
-  bool IsThreadIdUsedInKey(void) const { return key_attach_thread_id_; }
-
   // Prevent next ResetBlobMap()
   void BlockNextCacheClearing();
 
@@ -609,8 +609,6 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
   std::shared_ptr p_blobmap_;
   std::shared_ptr p_mutex_;
   bool block_next_cache_clearing_ = false;
-  std::string key_suffix_;  // Key identifying current Executor
-  bool key_attach_thread_id_ = true;
 };
 #endif
 
diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h
index 59a95e34c5478ef0d2f2e7fd5a97a6198d167790..935e856f177f017b565431f1f5a88b9b6f93dc6c 100644
--- a/paddle/fluid/platform/mkldnn_helper.h
+++ b/paddle/fluid/platform/mkldnn_helper.h
@@ -432,14 +432,26 @@ inline void AppendKey(std::string* key, const std::vector& dims) {
 inline void AttachPointerHashToMKLDNNKey(void* ptr,
                                          const platform::Place& place) {
   if (platform::is_cpu_place(place)) {
-    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-    platform::MKLDNNDeviceContext* dev_ctx =
-        (platform::MKLDNNDeviceContext*)pool.Get(place);
-    dev_ctx->SetKeySuffix("E" +
-                          std::to_string(reinterpret_cast(ptr)));
-    // When NaiveExecutor/Executor is used no info on thread id is needed in a
-    // key
-    dev_ctx->DisableThreadInfoInKey();
+    // Static vars will remember first executor and its thread
+    // so both of them need to be processed by the same thread within
+    // critical section
+    static std::mutex static_vars_barrier;
+    // Scoped guard instead of manual lock()/unlock(): the mutex is released
+    // even if initialising one of the statics below throws.
+    std::unique_lock<std::mutex> guard(static_vars_barrier);
+    static auto first_exec = ptr;
+    static auto first_thread = ThreadIDasStr();
+    // Both statics are only read from here on, so leave the critical section.
+    guard.unlock();
+
+    if (first_exec != ptr) {
+      paddle::platform::MKLDNNDeviceContext::tls().set_key_suffix(
+          "E" + std::to_string(reinterpret_cast<uintptr_t>(ptr)));
+    }
+    // For first thread
+    if (first_thread == ThreadIDasStr()) {
+      paddle::platform::MKLDNNDeviceContext::tls().disable_tid_in_key();
+    }
   }
 }
 
@@ -450,13 +462,14 @@ inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx,
   key.reserve(64);
   using expand_type = int[];
   expand_type{0, (AppendKey(&key, std::forward(args)), 0)...};
-  key += dev_ctx.GetKeySuffix();
+  key += paddle::platform::MKLDNNDeviceContext::tls().get_key_suffix();
   return key;
 }
 
 inline std::string ExtendKeyWithThreadInfoIfNeeded(
     const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) {
-  return ((dev_ctx.IsThreadIdUsedInKey() == true) &&
+  return ((paddle::platform::MKLDNNDeviceContext::tls().is_tid_used_in_key() ==
+           true) &&
           (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() ==
            platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default))
              ? key + "-t:" + ThreadIDasStr()