未验证 提交 dabaca00 编写于 作者: J Jacek Czaja 提交者: GitHub

Candidate fix to #31992 (#32136)

上级 3822247f
......@@ -600,6 +600,8 @@ class MKLDNNDeviceContextThreadLocals {
// MKL-DNN stream used for execution of primitives (per-thread)
mkldnn::engine cur_engine;
mkldnn::stream cur_stream;
std::string key_suffix; // Key identifying current Executor
bool key_attach_thread_id = true;
Body();
~Body();
......@@ -612,6 +614,10 @@ class MKLDNNDeviceContextThreadLocals {
void log_lib_version(void);
const mkldnn::engine& get_engine(void);
mkldnn::stream& get_stream(void);
void set_key_suffix(const std::string& suffix) { key_suffix = suffix; }
const std::string& get_key_suffix(void) const { return key_suffix; }
void disable_tid_in_key(void) { key_attach_thread_id = false; }
bool is_tid_used_in_key(void) const { return key_attach_thread_id; }
};
MKLDNNDeviceContextThreadLocals() = default;
MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
......@@ -655,14 +661,6 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
// Remove all entries from the blob map
void ResetBlobMap();
// Set a suffix to be added to key
void SetKeySuffix(const std::string& suffix) { key_suffix_ = suffix; }
const std::string& GetKeySuffix(void) const { return key_suffix_; }
// Disable adding thread ID to the key
void DisableThreadInfoInKey(void) { key_attach_thread_id_ = false; }
bool IsThreadIdUsedInKey(void) const { return key_attach_thread_id_; }
// Prevent next ResetBlobMap()
void BlockNextCacheClearing();
......@@ -686,8 +684,6 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
std::shared_ptr<BlobMap> p_blobmap_;
std::shared_ptr<std::mutex> p_mutex_;
bool block_next_cache_clearing_ = false;
std::string key_suffix_; // Key identifying current Executor
bool key_attach_thread_id_ = true;
};
#endif
......
......@@ -439,14 +439,23 @@ inline void AppendKey(std::string* key, const std::vector<T>& dims) {
inline void AttachPointerHashToMKLDNNKey(void* ptr,
const platform::Place& place) {
if (platform::is_cpu_place(place)) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(place);
dev_ctx->SetKeySuffix("E" +
std::to_string(reinterpret_cast<uintptr_t>(ptr)));
// When NaiveExecutor/Executor is used no info on thread id is needed in a
// key
dev_ctx->DisableThreadInfoInKey();
// Static vars will remember first executor and its thread
// so both of them need to be processed by the same thread within
// critical section
static std::mutex static_vars_barrier;
static_vars_barrier.lock();
static auto first_exec = ptr;
static auto first_thread = ThreadIDasStr();
static_vars_barrier.unlock();
if (first_exec != ptr) {
paddle::platform::MKLDNNDeviceContext::tls().set_key_suffix(
"E" + std::to_string(reinterpret_cast<uintptr_t>(ptr)));
}
// For first thread
if (first_thread == ThreadIDasStr()) {
paddle::platform::MKLDNNDeviceContext::tls().disable_tid_in_key();
}
}
}
......@@ -457,13 +466,14 @@ inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx,
key.reserve(64);
using expand_type = int[];
expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...};
key += dev_ctx.GetKeySuffix();
key += paddle::platform::MKLDNNDeviceContext::tls().get_key_suffix();
return key;
}
inline std::string ExtendKeyWithThreadInfoIfNeeded(
const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) {
return ((dev_ctx.IsThreadIdUsedInKey() == true) &&
return ((paddle::platform::MKLDNNDeviceContext::tls().is_tid_used_in_key() ==
true) &&
(platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() ==
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default))
? key + "-t:" + ThreadIDasStr()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册