未验证 提交 ba610761 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Added clearing oneDNN cache per executor (#32499)

* - Added clearing oneDNN per executor

* - Executor is nt always having FLAGS_use_mkldnn set to true
上级 0dc02dc7
...@@ -72,7 +72,7 @@ Executor::~Executor() { ...@@ -72,7 +72,7 @@ Executor::~Executor() {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache, // Clear mkl-dnn cache,
// this is needed to have mkl-dnn unit tests working // this is needed to have mkl-dnn unit tests working
ClearMKLDNNCache(place_); ClearMKLDNNCache(place_, this);
#endif #endif
} }
...@@ -169,6 +169,9 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, ...@@ -169,6 +169,9 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool force_disable_gc, bool keep_kid_scopes) { bool force_disable_gc, bool keep_kid_scopes) {
platform::RecordBlock b(block_id); platform::RecordBlock b(block_id);
if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc); if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
#ifdef PADDLE_WITH_MKLDNN
platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif
auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc); auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars, RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars,
keep_kid_scopes); keep_kid_scopes);
...@@ -294,6 +297,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -294,6 +297,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
const std::string& fetch_holder_name) { const std::string& fetch_holder_name) {
platform::RecordBlock b(kProgramId); platform::RecordBlock b(kProgramId);
if (FLAGS_use_mkldnn) EnableMKLDNN(program); if (FLAGS_use_mkldnn) EnableMKLDNN(program);
#ifdef PADDLE_WITH_MKLDNN
platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif
bool has_feed_ops = bool has_feed_ops =
has_feed_operators(program.Block(0), *feed_targets, feed_holder_name); has_feed_operators(program.Block(0), *feed_targets, feed_holder_name);
bool has_fetch_ops = bool has_fetch_ops =
...@@ -576,7 +582,6 @@ void Executor::EnableMKLDNN(const ProgramDesc& program) { ...@@ -576,7 +582,6 @@ void Executor::EnableMKLDNN(const ProgramDesc& program) {
} }
} }
} }
platform::AttachPointerHashToMKLDNNKey(this, place_);
#else #else
LOG(WARNING) LOG(WARNING)
<< "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option"; << "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
......
...@@ -128,7 +128,7 @@ NaiveExecutor::~NaiveExecutor() { ...@@ -128,7 +128,7 @@ NaiveExecutor::~NaiveExecutor() {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache, // Clear mkl-dnn cache,
// this is needed to have mkl-dnn unit tests working // this is needed to have mkl-dnn unit tests working
ClearMKLDNNCache(place_); ClearMKLDNNCache(place_, this);
#endif #endif
} }
......
...@@ -411,7 +411,8 @@ void AnalysisPredictor::MkldnnQuantizer::ClearDeviceContext() const { ...@@ -411,7 +411,8 @@ void AnalysisPredictor::MkldnnQuantizer::ClearDeviceContext() const {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx = platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(predictor_.place_); (platform::MKLDNNDeviceContext*)pool.Get(predictor_.place_);
dev_ctx->ResetBlobMap(); dev_ctx->ResetBlobMap(
paddle::platform::MKLDNNDeviceContext::tls().get_curr_exec());
} }
void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const { void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
......
...@@ -50,7 +50,7 @@ class CacheTester { ...@@ -50,7 +50,7 @@ class CacheTester {
platform::CPUPlace place; platform::CPUPlace place;
onednn_dev_ctx_ = onednn_dev_ctx_ =
dynamic_cast<platform::MKLDNNDeviceContext *>(pool.Get(place)); dynamic_cast<platform::MKLDNNDeviceContext *>(pool.Get(place));
onednn_dev_ctx_->ResetBlobMap(); onednn_dev_ctx_->ResetBlobMap(nullptr);
} }
bool Analyze(unsigned short int num_entries) { bool Analyze(unsigned short int num_entries) {
......
...@@ -537,6 +537,7 @@ Place CUDAPinnedDeviceContext::GetPlace() const { return place_; } ...@@ -537,6 +537,7 @@ Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place) MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
: CPUDeviceContext(place), p_blobmap_() { : CPUDeviceContext(place), p_blobmap_() {
p_blobmap_.reset(new BlobMap()); p_blobmap_.reset(new BlobMap());
p_exec_items_.reset(new ExecMap());
p_mutex_.reset(new std::mutex()); p_mutex_.reset(new std::mutex());
} }
...@@ -560,7 +561,7 @@ MKLDNNDeviceContextThreadLocals::Body::~Body() { ...@@ -560,7 +561,7 @@ MKLDNNDeviceContextThreadLocals::Body::~Body() {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx = platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(cpu_place); (platform::MKLDNNDeviceContext*)pool.Get(cpu_place);
dev_ctx->ResetBlobMap(); dev_ctx->ResetBlobMap(exec_ptr_);
} }
void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id( void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
...@@ -607,17 +608,34 @@ mkldnn::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) { ...@@ -607,17 +608,34 @@ mkldnn::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
return cur_stream; return cur_stream;
} }
void MKLDNNDeviceContext::ResetBlobMap() { void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_); std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
if (!block_next_cache_clearing_) { if (!block_next_cache_clearing_) {
VLOG(3) << "Clearing DNNL cache."; VLOG(3) << "Clearing DNNL cache.";
p_blobmap_->clear(); // If no specific executor pointer then clear
// everything. For executor pointer then clear only
// objects allocated when using given executor
if (ptr == nullptr) {
p_blobmap_->clear();
} else {
for (auto& v : (*p_exec_items_)[ptr]) {
(v.first)->erase(v.second);
}
p_exec_items_->erase(ptr);
}
} else { } else {
VLOG(3) << "Prevented Clearing DNNL cache."; VLOG(3) << "Prevented Clearing DNNL cache.";
block_next_cache_clearing_ = false; block_next_cache_clearing_ = false;
} }
} }
void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
KeyBlob::iterator it) const {
// Take current executor addess from TLS
// and for this executor's items add the one defined with arguments
(*p_exec_items_)[tls().get_curr_exec()].push_back(std::make_pair(pblob, it));
}
void MKLDNNDeviceContext::BlockNextCacheClearing() { void MKLDNNDeviceContext::BlockNextCacheClearing() {
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_); std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
VLOG(3) << "Next DNNL cache clearing has been blocked."; VLOG(3) << "Next DNNL cache clearing has been blocked.";
...@@ -682,7 +700,11 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, ...@@ -682,7 +700,11 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
// Find Blob via name // Find Blob via name
auto blob_it = pBlob->find(name); auto blob_it = pBlob->find(name);
if (blob_it == pBlob->end()) { if (blob_it == pBlob->end()) {
(*pBlob)[name] = data; auto el =
pBlob->insert(std::make_pair(name, data)); // (*pBlob)[name] = data;
// Register new element in per executor map
// to have easily erased when executor terminated
LinkEntryWithExecutor(pBlob, el.first);
} else { } else {
blob_it->second = data; // set data to existing blob blob_it->second = data; // set data to existing blob
} }
......
...@@ -673,6 +673,7 @@ class MKLDNNDeviceContextThreadLocals { ...@@ -673,6 +673,7 @@ class MKLDNNDeviceContextThreadLocals {
mkldnn::stream cur_stream; mkldnn::stream cur_stream;
std::string key_suffix; // Key identifying current Executor std::string key_suffix; // Key identifying current Executor
bool key_attach_thread_id = true; bool key_attach_thread_id = true;
void* exec_ptr_ = nullptr;
Body(); Body();
~Body(); ~Body();
...@@ -689,6 +690,8 @@ class MKLDNNDeviceContextThreadLocals { ...@@ -689,6 +690,8 @@ class MKLDNNDeviceContextThreadLocals {
const std::string& get_key_suffix(void) const { return key_suffix; } const std::string& get_key_suffix(void) const { return key_suffix; }
void disable_tid_in_key(void) { key_attach_thread_id = false; } void disable_tid_in_key(void) { key_attach_thread_id = false; }
bool is_tid_used_in_key(void) const { return key_attach_thread_id; } bool is_tid_used_in_key(void) const { return key_attach_thread_id; }
void set_curr_exec(void* exec_ptr) { exec_ptr_ = exec_ptr; }
void* get_curr_exec(void) const { return exec_ptr_; }
}; };
MKLDNNDeviceContextThreadLocals() = default; MKLDNNDeviceContextThreadLocals() = default;
MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) = MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
...@@ -724,13 +727,19 @@ class MKLDNNDeviceContext : public CPUDeviceContext { ...@@ -724,13 +727,19 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
using ShapeBlob = umap_key_string_t<KeyBlob>; using ShapeBlob = umap_key_string_t<KeyBlob>;
using BlobMap = umap_value_smart_t<int, ShapeBlob>; using BlobMap = umap_value_smart_t<int, ShapeBlob>;
using ExecMap = std::unordered_map<
void*, std::vector<std::pair<BlobPtr_t<KeyBlob>, KeyBlob::iterator>>>;
explicit MKLDNNDeviceContext(CPUPlace place); explicit MKLDNNDeviceContext(CPUPlace place);
/* \brief Get the active engine */ /* \brief Get the active engine */
const mkldnn::engine& GetEngine() const { return tls().get_engine(); } const mkldnn::engine& GetEngine() const { return tls().get_engine(); }
// Register object to currently used executor's map
void LinkEntryWithExecutor(BlobPtr_t<KeyBlob>, KeyBlob::iterator) const;
// Remove all entries from the blob map // Remove all entries from the blob map
void ResetBlobMap(); void ResetBlobMap(void* ptr);
// Prevent next ResetBlobMap() // Prevent next ResetBlobMap()
void BlockNextCacheClearing(); void BlockNextCacheClearing();
...@@ -753,6 +762,9 @@ class MKLDNNDeviceContext : public CPUDeviceContext { ...@@ -753,6 +762,9 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
private: private:
std::shared_ptr<BlobMap> p_blobmap_; std::shared_ptr<BlobMap> p_blobmap_;
// Map key is pointer of executor and value is a data(iterator in map) needed
// to erase
std::shared_ptr<ExecMap> p_exec_items_;
std::shared_ptr<std::mutex> p_mutex_; std::shared_ptr<std::mutex> p_mutex_;
bool block_next_cache_clearing_ = false; bool block_next_cache_clearing_ = false;
}; };
......
...@@ -135,13 +135,14 @@ inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims, ...@@ -135,13 +135,14 @@ inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
return mkldnn::memory::desc({dims}, data_type, format); return mkldnn::memory::desc({dims}, data_type, format);
} }
inline void ClearMKLDNNCache(const platform::Place& place) { inline void ClearMKLDNNCache(const platform::Place& place,
void* ptr = nullptr) {
// Clear mkl-dnn cache, // Clear mkl-dnn cache,
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx = platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(place); (platform::MKLDNNDeviceContext*)pool.Get(place);
dev_ctx->ResetBlobMap(); dev_ctx->ResetBlobMap(ptr);
platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout( platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
paddle::framework::DataLayout::kNCHW); paddle::framework::DataLayout::kNCHW);
} }
...@@ -452,6 +453,9 @@ inline void AttachPointerHashToMKLDNNKey(void* ptr, ...@@ -452,6 +453,9 @@ inline void AttachPointerHashToMKLDNNKey(void* ptr,
paddle::platform::MKLDNNDeviceContext::tls().set_key_suffix( paddle::platform::MKLDNNDeviceContext::tls().set_key_suffix(
"E" + std::to_string(reinterpret_cast<uintptr_t>(ptr))); "E" + std::to_string(reinterpret_cast<uintptr_t>(ptr)));
} }
// Let's register adress of current executor
paddle::platform::MKLDNNDeviceContext::tls().set_curr_exec(ptr);
// For first thread // For first thread
if (first_thread == ThreadIDasStr()) { if (first_thread == ThreadIDasStr()) {
paddle::platform::MKLDNNDeviceContext::tls().disable_tid_in_key(); paddle::platform::MKLDNNDeviceContext::tls().disable_tid_in_key();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册