diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index e5bfbf4a8f779a4a1baf9f23c894eadd1d1c4902..de007c128d7543c1433426e80abcbd80ee47dee8 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -72,7 +72,7 @@ Executor::~Executor() {
 #ifdef PADDLE_WITH_MKLDNN
   // Clear mkl-dnn cache,
   // this is needed to have mkl-dnn unit tests working
-  ClearMKLDNNCache(place_);
+  ClearMKLDNNCache(place_, this);
 #endif
 }
 
@@ -169,6 +169,9 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
                    bool force_disable_gc, bool keep_kid_scopes) {
   platform::RecordBlock b(block_id);
   if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
+#ifdef PADDLE_WITH_MKLDNN
+  platform::AttachPointerHashToMKLDNNKey(this, place_);
+#endif
   auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
   RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars,
                      keep_kid_scopes);
@@ -294,6 +297,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
                    const std::string& fetch_holder_name) {
   platform::RecordBlock b(kProgramId);
   if (FLAGS_use_mkldnn) EnableMKLDNN(program);
+#ifdef PADDLE_WITH_MKLDNN
+  platform::AttachPointerHashToMKLDNNKey(this, place_);
+#endif
   bool has_feed_ops =
       has_feed_operators(program.Block(0), *feed_targets, feed_holder_name);
   bool has_fetch_ops =
@@ -576,7 +582,6 @@ void Executor::EnableMKLDNN(const ProgramDesc& program) {
       }
     }
   }
-  platform::AttachPointerHashToMKLDNNKey(this, place_);
 #else
   LOG(WARNING)
       << "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc
index f107321958ba7be4d3ba31bd128f0cbbad694b85..7d55d8c41e3e92349dc9986b3d236db2ebdac01b 100644
--- a/paddle/fluid/framework/naive_executor.cc
+++ b/paddle/fluid/framework/naive_executor.cc
@@ -128,7 +128,7 @@ NaiveExecutor::~NaiveExecutor() {
 #ifdef PADDLE_WITH_MKLDNN
   // Clear mkl-dnn cache,
   // this is needed to have mkl-dnn unit tests working
-  ClearMKLDNNCache(place_);
+  ClearMKLDNNCache(place_, this);
 #endif
 }
 
diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc
index 793fc53d90b768050572a3dd0a080a5d30e959a2..f6cdbb00b50453d4c4ff7fc06ba82aa042dd194a 100644
--- a/paddle/fluid/inference/api/mkldnn_quantizer.cc
+++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc
@@ -411,7 +411,8 @@ void AnalysisPredictor::MkldnnQuantizer::ClearDeviceContext() const {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   platform::MKLDNNDeviceContext* dev_ctx =
       (platform::MKLDNNDeviceContext*)pool.Get(predictor_.place_);
-  dev_ctx->ResetBlobMap();
+  dev_ctx->ResetBlobMap(
+      paddle::platform::MKLDNNDeviceContext::tls().get_curr_exec());
 }
 
 void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
diff --git a/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc b/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc
index aafff5248a0244e9090b10f6dc466c93eaa06888..d6cd76b697f5189a60d11a546abed04294f02326 100644
--- a/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc
+++ b/paddle/fluid/operators/mkldnn/test_mkldnn_caching.cc
@@ -50,7 +50,7 @@ class CacheTester {
     platform::CPUPlace place;
     onednn_dev_ctx_ =
         dynamic_cast<platform::MKLDNNDeviceContext*>(pool.Get(place));
-    onednn_dev_ctx_->ResetBlobMap();
+    onednn_dev_ctx_->ResetBlobMap(nullptr);
   }
 
   bool Analyze(unsigned short int num_entries) {
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 50bb64d5574440a9565793e578322f171b6586a1..9a47ac45462ed7080d34404891fb8410a71d3938 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -537,6 +537,7 @@ Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
 MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
     : CPUDeviceContext(place), p_blobmap_() {
   p_blobmap_.reset(new BlobMap());
+  p_exec_items_.reset(new ExecMap());
   p_mutex_.reset(new std::mutex());
 }
 
@@ -560,7 +561,7 @@ MKLDNNDeviceContextThreadLocals::Body::~Body() {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   platform::MKLDNNDeviceContext* dev_ctx =
       (platform::MKLDNNDeviceContext*)pool.Get(cpu_place);
-  dev_ctx->ResetBlobMap();
+  dev_ctx->ResetBlobMap(exec_ptr_);
 }
 
 void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
@@ -607,17 +608,34 @@ mkldnn::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
   return cur_stream;
 }
 
-void MKLDNNDeviceContext::ResetBlobMap() {
+void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
   std::lock_guard<std::mutex> lock(*p_mutex_);
   if (!block_next_cache_clearing_) {
     VLOG(3) << "Clearing DNNL cache.";
-    p_blobmap_->clear();
+    // If no specific executor pointer is given then clear
+    // everything; for a given executor pointer clear only
+    // the objects allocated while using that executor
+    if (ptr == nullptr) {
+      p_blobmap_->clear();
+    } else {
+      for (auto& v : (*p_exec_items_)[ptr]) {
+        (v.first)->erase(v.second);
+      }
+      p_exec_items_->erase(ptr);
+    }
   } else {
     VLOG(3) << "Prevented Clearing DNNL cache.";
     block_next_cache_clearing_ = false;
   }
 }
 
+void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
+                                                KeyBlob::iterator it) const {
+  // Take the current executor's address from TLS
+  // and add the entry given by the arguments to that executor's items
+  (*p_exec_items_)[tls().get_curr_exec()].push_back(std::make_pair(pblob, it));
+}
+
 void MKLDNNDeviceContext::BlockNextCacheClearing() {
   std::lock_guard<std::mutex> lock(*p_mutex_);
   VLOG(3) << "Next DNNL cache clearing has been blocked.";
@@ -682,7 +700,11 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
   // Find Blob via name
   auto blob_it = pBlob->find(name);
   if (blob_it == pBlob->end()) {
-    (*pBlob)[name] = data;
+    auto el =
+        pBlob->insert(std::make_pair(name, data));  //  (*pBlob)[name] = data;
+    // Register the new element in the per-executor map
+    // so it can easily be erased when the executor terminates
+    LinkEntryWithExecutor(pBlob, el.first);
   } else {
     blob_it->second = data;  // set data to existing blob
   }
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index f79cb1ab94788126a562764ac6ff7efc4b302d2e..d91e14ec3aa923b81976f953d9673175d5217b21 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -673,6 +673,7 @@ class MKLDNNDeviceContextThreadLocals {
     mkldnn::stream cur_stream;
     std::string key_suffix;  // Key identifying current Executor
     bool key_attach_thread_id = true;
+    void* exec_ptr_ = nullptr;
 
     Body();
     ~Body();
@@ -689,6 +690,8 @@
     const std::string& get_key_suffix(void) const { return key_suffix; }
     void disable_tid_in_key(void) { key_attach_thread_id = false; }
     bool is_tid_used_in_key(void) const { return key_attach_thread_id; }
+    void set_curr_exec(void* exec_ptr) { exec_ptr_ = exec_ptr; }
+    void* get_curr_exec(void) const { return exec_ptr_; }
   };
   MKLDNNDeviceContextThreadLocals() = default;
   MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
@@ -724,13 +727,19 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
   using ShapeBlob = umap_key_string_t<KeyBlob>;
   using BlobMap = umap_value_smart_t<int, ShapeBlob>;
 
+  using ExecMap = std::unordered_map<
+      void*, std::vector<std::pair<BlobPtr_t<KeyBlob>, KeyBlob::iterator>>>;
+
   explicit MKLDNNDeviceContext(CPUPlace place);
 
   /* \brief  Get the active engine */
   const mkldnn::engine& GetEngine() const { return tls().get_engine(); }
 
+  // Register an object in the map of the currently used executor
+  void LinkEntryWithExecutor(BlobPtr_t<KeyBlob>, KeyBlob::iterator) const;
+
   // Remove all entries from the blob map
-  void ResetBlobMap();
+  void ResetBlobMap(void* ptr);
 
   // Prevent next ResetBlobMap()
   void BlockNextCacheClearing();
@@ -753,6 +762,9 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
 
  private:
   std::shared_ptr<BlobMap> p_blobmap_;
+  // Map key is a pointer to an executor and the value is the data (iterators
+  // into the blob maps) needed to erase that executor's cached entries
+  std::shared_ptr<ExecMap> p_exec_items_;
   std::shared_ptr<std::mutex> p_mutex_;
   bool block_next_cache_clearing_ = false;
 };
diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h
index 35776b9f1e6b88658fcefed015f0dc152a51d8bc..0b683a742c9fd8094e91c54d4f323120bad1eaca 100644
--- a/paddle/fluid/platform/mkldnn_helper.h
+++ b/paddle/fluid/platform/mkldnn_helper.h
@@ -135,13 +135,14 @@ inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
   return mkldnn::memory::desc({dims}, data_type, format);
 }
 
-inline void ClearMKLDNNCache(const platform::Place& place) {
+inline void ClearMKLDNNCache(const platform::Place& place,
+                             void* ptr = nullptr) {
   // Clear mkl-dnn cache,
   if (platform::is_cpu_place(place)) {
     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     platform::MKLDNNDeviceContext* dev_ctx =
         (platform::MKLDNNDeviceContext*)pool.Get(place);
-    dev_ctx->ResetBlobMap();
+    dev_ctx->ResetBlobMap(ptr);
     platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
         paddle::framework::DataLayout::kNCHW);
   }
@@ -452,6 +453,9 @@ inline void AttachPointerHashToMKLDNNKey(void* ptr,
       paddle::platform::MKLDNNDeviceContext::tls().set_key_suffix(
           "E" + std::to_string(reinterpret_cast<uintptr_t>(ptr)));
     }
+    // Let's register the address of the current executor
+    paddle::platform::MKLDNNDeviceContext::tls().set_curr_exec(ptr);
+
     // For first thread
     if (first_thread == ThreadIDasStr()) {
      paddle::platform::MKLDNNDeviceContext::tls().disable_tid_in_key();
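The mechanism introduced by the patch: every blob newly inserted into the oneDNN cache is also recorded in a per-executor side map (p_exec_items_), keyed by the executor address held in thread-local storage, so ResetBlobMap(ptr) can erase only that executor's entries, while ResetBlobMap(nullptr) keeps the old clear-everything behaviour. The sketch below is not part of the patch and uses invented names (PerExecutorCache, Blob, the string payloads); it only illustrates the same (container, iterator) bookkeeping pattern in isolation. std::map is used for the blob container here so the stored iterators remain valid until their elements are erased.

```cpp
#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for a cached oneDNN object.
using Blob = std::string;

class PerExecutorCache {
 public:
  using KeyBlob = std::map<std::string, Blob>;  // name -> cached object

  // Insert (or overwrite) a cache entry; 'exec' identifies the executor that
  // created it (in the patch this address comes from thread-local storage).
  void SetBlob(void* exec, const std::string& name, Blob data) {
    auto it = blobs_.find(name);
    if (it == blobs_.end()) {
      auto inserted = blobs_.emplace(name, std::move(data));
      // Remember the iterator so this executor's entries can be erased later.
      exec_items_[exec].push_back(inserted.first);
    } else {
      it->second = std::move(data);  // overwrite existing entry
    }
  }

  // nullptr wipes the whole cache (old behaviour); a non-null executor
  // pointer removes only the entries that executor created.
  void Reset(void* exec) {
    if (exec == nullptr) {
      blobs_.clear();
      exec_items_.clear();
      return;
    }
    for (auto it : exec_items_[exec]) blobs_.erase(it);
    exec_items_.erase(exec);
  }

  std::size_t Size() const { return blobs_.size(); }

 private:
  KeyBlob blobs_;
  // executor address -> iterators of the entries it created
  std::unordered_map<void*, std::vector<KeyBlob::iterator>> exec_items_;
};

int main() {
  PerExecutorCache cache;
  int exec_a = 0, exec_b = 0;  // stand-ins for two executor instances
  cache.SetBlob(&exec_a, "conv_fwd_pd", "primitive A");
  cache.SetBlob(&exec_b, "pool_fwd_pd", "primitive B");
  cache.Reset(&exec_a);               // drops only exec_a's entry
  std::cout << cache.Size() << "\n";  // prints 1
  cache.Reset(nullptr);               // clear-all path
  std::cout << cache.Size() << "\n";  // prints 0
  return 0;
}
```

Keeping nullptr as the "clear everything" case mirrors how the patch leaves existing call sites (quantizer, tests, thread-local teardown) working unchanged while executors that pass their own address get selective clearing.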