未验证 提交 67911261 编写于 作者: J Jacek Czaja 提交者: GitHub

[cherry-pick] Fix to 31992 for 2.0 (#32163)

* - Candidate fix to #31992

- Fix to #31992 for 2.0
上级 5c7ad3bc
......@@ -528,6 +528,8 @@ class MKLDNNDeviceContextThreadLocals {
// Recently registered data_format. This is needed
// when converting an MKL-DNN Tensor to a non MKL-DNN one
paddle::framework::DataLayout cur_paddle_data_layout;
std::string key_suffix; // Key identifying current Executor
bool key_attach_thread_id = true;
Body();
void set_cur_mkldnn_session_id(size_t sid);
......@@ -537,6 +539,10 @@ class MKLDNNDeviceContextThreadLocals {
void set_cur_paddle_data_layout(framework::DataLayout dl);
framework::DataLayout get_cur_paddle_data_layout(void);
void log_lib_version(void);
// Replaces the stored key suffix with a copy of `suffix`.
void set_key_suffix(const std::string& suffix) {
  key_suffix.assign(suffix);
}
// Returns a reference to the currently stored key suffix.
const std::string& get_key_suffix(void) const {
  return key_suffix;
}
// Clears the flag reported by is_tid_used_in_key(); nothing in this
// interface sets it back to true.
void disable_tid_in_key(void) {
  key_attach_thread_id = false;
}
// Reports the key_attach_thread_id flag (initially true, cleared by
// disable_tid_in_key()).
bool is_tid_used_in_key(void) const {
  return key_attach_thread_id;
}
};
MKLDNNDeviceContextThreadLocals() = default;
MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
......@@ -580,14 +586,6 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
// Remove all entries from the blob map
void ResetBlobMap();
// Set a suffix to be added to key.
// Stores a copy of `suffix` in key_suffix_.
void SetKeySuffix(const std::string& suffix) {
  key_suffix_.assign(suffix);
}
// Returns a reference to the suffix last stored by SetKeySuffix().
const std::string& GetKeySuffix(void) const {
  return key_suffix_;
}
// Disable adding thread ID to the key.
// Clears the flag reported by IsThreadIdUsedInKey().
void DisableThreadInfoInKey(void) {
  key_attach_thread_id_ = false;
}
// Reports the key_attach_thread_id_ flag (initially true, cleared by
// DisableThreadInfoInKey()).
bool IsThreadIdUsedInKey(void) const {
  return key_attach_thread_id_;
}
// Prevent next ResetBlobMap()
void BlockNextCacheClearing();
......@@ -609,8 +607,6 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
std::shared_ptr<BlobMap> p_blobmap_;
std::shared_ptr<std::mutex> p_mutex_;
bool block_next_cache_clearing_ = false;
std::string key_suffix_; // Key identifying current Executor
bool key_attach_thread_id_ = true;
};
#endif
......
......@@ -432,14 +432,23 @@ inline void AppendKey(std::string* key, const std::vector<T>& dims) {
// Configures MKL-DNN key generation for the executor identified by `ptr`.
// Nothing happens unless `place` is a CPU place.
//
// ptr   - pointer whose integer value is used to build an "E<value>" key
//         suffix.
// place - place checked with platform::is_cpu_place().
//
// NOTE(review): this listing is a diff shown without +/- markers, so the body
// appears to contain both the device-context based calls (SetKeySuffix /
// DisableThreadInfoInKey) and the tls() based logic that follows them --
// confirm which statements belong to the final version.
inline void AttachPointerHashToMKLDNNKey(void* ptr,
const platform::Place& place) {
if (platform::is_cpu_place(place)) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(place);
// The suffix is "E" followed by the pointer value printed in decimal.
dev_ctx->SetKeySuffix("E" +
std::to_string(reinterpret_cast<uintptr_t>(ptr)));
// When NaiveExecutor/Executor is used no info on thread id is needed in a
// key
dev_ctx->DisableThreadInfoInKey();
// Static vars will remember first executor and its thread
// so both of them need to be processed by the same thread within
// critical section
static std::mutex static_vars_barrier;
// NOTE(review): lock()/unlock() are paired manually; an exception thrown
// while the statics are being initialized would leave the mutex locked.
static_vars_barrier.lock();
// Function-local statics are initialized only by the first call that reaches
// them; later calls keep the values captured then.
static auto first_exec = ptr;
static auto first_thread = ThreadIDasStr();
static_vars_barrier.unlock();
// Only a `ptr` different from the first recorded one gets an "E<value>"
// suffix through tls(); no tls() suffix is set here for the first one.
if (first_exec != ptr) {
paddle::platform::MKLDNNDeviceContext::tls().set_key_suffix(
"E" + std::to_string(reinterpret_cast<uintptr_t>(ptr)));
}
// For first thread
// (the caller whose ThreadIDasStr() equals the value recorded above):
// turn off the thread-id flag through tls().
if (first_thread == ThreadIDasStr()) {
paddle::platform::MKLDNNDeviceContext::tls().disable_tid_in_key();
}
}
}
......@@ -450,13 +459,14 @@ inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx,
key.reserve(64);
using expand_type = int[];
expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...};
key += dev_ctx.GetKeySuffix();
key += paddle::platform::MKLDNNDeviceContext::tls().get_key_suffix();
return key;
}
inline std::string ExtendKeyWithThreadInfoIfNeeded(
const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) {
return ((dev_ctx.IsThreadIdUsedInKey() == true) &&
return ((paddle::platform::MKLDNNDeviceContext::tls().is_tid_used_in_key() ==
true) &&
(platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() ==
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default))
? key + "-t:" + ThreadIDasStr()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册