未验证 提交 ba610761 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Added clearing oneDNN cache per executor (#32499)

* - Added clearing oneDNN per executor

* - Executor is nt always having FLAGS_use_mkldnn set to true
上级 0dc02dc7
......@@ -72,7 +72,7 @@ Executor::~Executor() {
#ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache,
// this is needed to have mkl-dnn unit tests working
ClearMKLDNNCache(place_);
ClearMKLDNNCache(place_, this);
#endif
}
......@@ -169,6 +169,9 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool force_disable_gc, bool keep_kid_scopes) {
platform::RecordBlock b(block_id);
if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
#ifdef PADDLE_WITH_MKLDNN
platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif
auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars,
keep_kid_scopes);
......@@ -294,6 +297,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
const std::string& fetch_holder_name) {
platform::RecordBlock b(kProgramId);
if (FLAGS_use_mkldnn) EnableMKLDNN(program);
#ifdef PADDLE_WITH_MKLDNN
platform::AttachPointerHashToMKLDNNKey(this, place_);
#endif
bool has_feed_ops =
has_feed_operators(program.Block(0), *feed_targets, feed_holder_name);
bool has_fetch_ops =
......@@ -576,7 +582,6 @@ void Executor::EnableMKLDNN(const ProgramDesc& program) {
}
}
}
platform::AttachPointerHashToMKLDNNKey(this, place_);
#else
LOG(WARNING)
<< "'MKLDNN' is not supported, Please re-compile with WITH_MKLDNN option";
......
......@@ -128,7 +128,7 @@ NaiveExecutor::~NaiveExecutor() {
#ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache,
// this is needed to have mkl-dnn unit tests working
ClearMKLDNNCache(place_);
ClearMKLDNNCache(place_, this);
#endif
}
......
......@@ -411,7 +411,8 @@ void AnalysisPredictor::MkldnnQuantizer::ClearDeviceContext() const {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(predictor_.place_);
dev_ctx->ResetBlobMap();
dev_ctx->ResetBlobMap(
paddle::platform::MKLDNNDeviceContext::tls().get_curr_exec());
}
void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
......
......@@ -50,7 +50,7 @@ class CacheTester {
platform::CPUPlace place;
onednn_dev_ctx_ =
dynamic_cast<platform::MKLDNNDeviceContext *>(pool.Get(place));
onednn_dev_ctx_->ResetBlobMap();
onednn_dev_ctx_->ResetBlobMap(nullptr);
}
bool Analyze(unsigned short int num_entries) {
......
......@@ -537,6 +537,7 @@ Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
: CPUDeviceContext(place), p_blobmap_() {
p_blobmap_.reset(new BlobMap());
p_exec_items_.reset(new ExecMap());
p_mutex_.reset(new std::mutex());
}
......@@ -560,7 +561,7 @@ MKLDNNDeviceContextThreadLocals::Body::~Body() {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(cpu_place);
dev_ctx->ResetBlobMap();
dev_ctx->ResetBlobMap(exec_ptr_);
}
void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
......@@ -607,17 +608,34 @@ mkldnn::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
return cur_stream;
}
void MKLDNNDeviceContext::ResetBlobMap() {
void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
if (!block_next_cache_clearing_) {
VLOG(3) << "Clearing DNNL cache.";
// If no specific executor pointer then clear
// everything. For executor pointer then clear only
// objects allocated when using given executor
if (ptr == nullptr) {
p_blobmap_->clear();
} else {
for (auto& v : (*p_exec_items_)[ptr]) {
(v.first)->erase(v.second);
}
p_exec_items_->erase(ptr);
}
} else {
VLOG(3) << "Prevented Clearing DNNL cache.";
block_next_cache_clearing_ = false;
}
}
void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
KeyBlob::iterator it) const {
// Take current executor addess from TLS
// and for this executor's items add the one defined with arguments
(*p_exec_items_)[tls().get_curr_exec()].push_back(std::make_pair(pblob, it));
}
void MKLDNNDeviceContext::BlockNextCacheClearing() {
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
VLOG(3) << "Next DNNL cache clearing has been blocked.";
......@@ -682,7 +700,11 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
// Find Blob via name
auto blob_it = pBlob->find(name);
if (blob_it == pBlob->end()) {
(*pBlob)[name] = data;
auto el =
pBlob->insert(std::make_pair(name, data)); // (*pBlob)[name] = data;
// Register new element in per executor map
// to have easily erased when executor terminated
LinkEntryWithExecutor(pBlob, el.first);
} else {
blob_it->second = data; // set data to existing blob
}
......
......@@ -673,6 +673,7 @@ class MKLDNNDeviceContextThreadLocals {
mkldnn::stream cur_stream;
std::string key_suffix; // Key identifying current Executor
bool key_attach_thread_id = true;
void* exec_ptr_ = nullptr;
Body();
~Body();
......@@ -689,6 +690,8 @@ class MKLDNNDeviceContextThreadLocals {
const std::string& get_key_suffix(void) const { return key_suffix; }
void disable_tid_in_key(void) { key_attach_thread_id = false; }
bool is_tid_used_in_key(void) const { return key_attach_thread_id; }
void set_curr_exec(void* exec_ptr) { exec_ptr_ = exec_ptr; }
void* get_curr_exec(void) const { return exec_ptr_; }
};
MKLDNNDeviceContextThreadLocals() = default;
MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
......@@ -724,13 +727,19 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
using ShapeBlob = umap_key_string_t<KeyBlob>;
using BlobMap = umap_value_smart_t<int, ShapeBlob>;
using ExecMap = std::unordered_map<
void*, std::vector<std::pair<BlobPtr_t<KeyBlob>, KeyBlob::iterator>>>;
explicit MKLDNNDeviceContext(CPUPlace place);
/* \brief Get the active engine */
const mkldnn::engine& GetEngine() const { return tls().get_engine(); }
// Register object to currently used executor's map
void LinkEntryWithExecutor(BlobPtr_t<KeyBlob>, KeyBlob::iterator) const;
// Remove all entries from the blob map
void ResetBlobMap();
void ResetBlobMap(void* ptr);
// Prevent next ResetBlobMap()
void BlockNextCacheClearing();
......@@ -753,6 +762,9 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
private:
std::shared_ptr<BlobMap> p_blobmap_;
// Map key is pointer of executor and value is a data(iterator in map) needed
// to erase
std::shared_ptr<ExecMap> p_exec_items_;
std::shared_ptr<std::mutex> p_mutex_;
bool block_next_cache_clearing_ = false;
};
......
......@@ -135,13 +135,14 @@ inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
return mkldnn::memory::desc({dims}, data_type, format);
}
inline void ClearMKLDNNCache(const platform::Place& place) {
inline void ClearMKLDNNCache(const platform::Place& place,
void* ptr = nullptr) {
// Clear mkl-dnn cache,
if (platform::is_cpu_place(place)) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(place);
dev_ctx->ResetBlobMap();
dev_ctx->ResetBlobMap(ptr);
platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
paddle::framework::DataLayout::kNCHW);
}
......@@ -452,6 +453,9 @@ inline void AttachPointerHashToMKLDNNKey(void* ptr,
paddle::platform::MKLDNNDeviceContext::tls().set_key_suffix(
"E" + std::to_string(reinterpret_cast<uintptr_t>(ptr)));
}
// Let's register adress of current executor
paddle::platform::MKLDNNDeviceContext::tls().set_curr_exec(ptr);
// For first thread
if (first_thread == ThreadIDasStr()) {
paddle::platform::MKLDNNDeviceContext::tls().disable_tid_in_key();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册