未验证 提交 67911261 编写于 作者: J Jacek Czaja 提交者: GitHub

[cherry-pick] Fix to 31992 for 2.0 (#32163)

* - Candidate fix to #31992

- Fix to #31992 for 2.0
上级 5c7ad3bc
...@@ -528,6 +528,8 @@ class MKLDNNDeviceContextThreadLocals { ...@@ -528,6 +528,8 @@ class MKLDNNDeviceContextThreadLocals {
// Recently registered data_format. This is needed to // Recently registered data_format. This is needed to
// know for converting MKL-DNN Tensor to non MKL-DNN // know for converting MKL-DNN Tensor to non MKL-DNN
paddle::framework::DataLayout cur_paddle_data_layout; paddle::framework::DataLayout cur_paddle_data_layout;
std::string key_suffix; // Key identifying current Executor
bool key_attach_thread_id = true;
Body(); Body();
void set_cur_mkldnn_session_id(size_t sid); void set_cur_mkldnn_session_id(size_t sid);
...@@ -537,6 +539,10 @@ class MKLDNNDeviceContextThreadLocals { ...@@ -537,6 +539,10 @@ class MKLDNNDeviceContextThreadLocals {
void set_cur_paddle_data_layout(framework::DataLayout dl); void set_cur_paddle_data_layout(framework::DataLayout dl);
framework::DataLayout get_cur_paddle_data_layout(void); framework::DataLayout get_cur_paddle_data_layout(void);
void log_lib_version(void); void log_lib_version(void);
void set_key_suffix(const std::string& suffix) { key_suffix = suffix; }
const std::string& get_key_suffix(void) const { return key_suffix; }
void disable_tid_in_key(void) { key_attach_thread_id = false; }
bool is_tid_used_in_key(void) const { return key_attach_thread_id; }
}; };
MKLDNNDeviceContextThreadLocals() = default; MKLDNNDeviceContextThreadLocals() = default;
MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) = MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
...@@ -580,14 +586,6 @@ class MKLDNNDeviceContext : public CPUDeviceContext { ...@@ -580,14 +586,6 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
// Remove all entries from the blob map // Remove all entries from the blob map
void ResetBlobMap(); void ResetBlobMap();
// Set a suffix to be added to key
void SetKeySuffix(const std::string& suffix) { key_suffix_ = suffix; }
const std::string& GetKeySuffix(void) const { return key_suffix_; }
// Disable adding thread ID to the key
void DisableThreadInfoInKey(void) { key_attach_thread_id_ = false; }
bool IsThreadIdUsedInKey(void) const { return key_attach_thread_id_; }
// Prevent next ResetBlobMap() // Prevent next ResetBlobMap()
void BlockNextCacheClearing(); void BlockNextCacheClearing();
...@@ -609,8 +607,6 @@ class MKLDNNDeviceContext : public CPUDeviceContext { ...@@ -609,8 +607,6 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
std::shared_ptr<BlobMap> p_blobmap_; std::shared_ptr<BlobMap> p_blobmap_;
std::shared_ptr<std::mutex> p_mutex_; std::shared_ptr<std::mutex> p_mutex_;
bool block_next_cache_clearing_ = false; bool block_next_cache_clearing_ = false;
std::string key_suffix_; // Key identifying current Executor
bool key_attach_thread_id_ = true;
}; };
#endif #endif
......
...@@ -432,14 +432,23 @@ inline void AppendKey(std::string* key, const std::vector<T>& dims) { ...@@ -432,14 +432,23 @@ inline void AppendKey(std::string* key, const std::vector<T>& dims) {
inline void AttachPointerHashToMKLDNNKey(void* ptr, inline void AttachPointerHashToMKLDNNKey(void* ptr,
const platform::Place& place) { const platform::Place& place) {
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); // Static vars will remember first executor and its thread
platform::MKLDNNDeviceContext* dev_ctx = // so both of them need to be processed by the same thread within
(platform::MKLDNNDeviceContext*)pool.Get(place); // critical section
dev_ctx->SetKeySuffix("E" + static std::mutex static_vars_barrier;
std::to_string(reinterpret_cast<uintptr_t>(ptr))); static_vars_barrier.lock();
// When NaiveExecutor/Executor is used no info on thread id is needed in a static auto first_exec = ptr;
// key static auto first_thread = ThreadIDasStr();
dev_ctx->DisableThreadInfoInKey(); static_vars_barrier.unlock();
if (first_exec != ptr) {
paddle::platform::MKLDNNDeviceContext::tls().set_key_suffix(
"E" + std::to_string(reinterpret_cast<uintptr_t>(ptr)));
}
// For first thread
if (first_thread == ThreadIDasStr()) {
paddle::platform::MKLDNNDeviceContext::tls().disable_tid_in_key();
}
} }
} }
...@@ -450,13 +459,14 @@ inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx, ...@@ -450,13 +459,14 @@ inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx,
key.reserve(64); key.reserve(64);
using expand_type = int[]; using expand_type = int[];
expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...}; expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...};
key += dev_ctx.GetKeySuffix(); key += paddle::platform::MKLDNNDeviceContext::tls().get_key_suffix();
return key; return key;
} }
inline std::string ExtendKeyWithThreadInfoIfNeeded( inline std::string ExtendKeyWithThreadInfoIfNeeded(
const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) { const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) {
return ((dev_ctx.IsThreadIdUsedInKey() == true) && return ((paddle::platform::MKLDNNDeviceContext::tls().is_tid_used_in_key() ==
true) &&
(platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() == (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() ==
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default)) platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default))
? key + "-t:" + ThreadIDasStr() ? key + "-t:" + ThreadIDasStr()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册