diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 88829d7a207bdd642ae0af415d69362195433d4b..0dabe2ed3d92c141f03b07a26139a7997b01e478 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -403,42 +403,62 @@ MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place) namespace { // Current mkldnn session id. thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default; -} +// Current data input shape string. +// - For fixed-shape, it's a null string in default. +// - For dynamic-shape, it's user specific. +thread_local std::string cur_input_shape_str = ""; +} // namespace void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; } size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; } +void set_cur_input_shape_str(std::string input_shape_str) { + cur_input_shape_str = input_shape_str; +} +std::string get_cur_input_shape_str(void) { return cur_input_shape_str; } void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); } void MKLDNNDeviceContext::SetBlob(const std::string& name, std::shared_ptr data) const { BlobMap* pMap = p_blobmap_.get(); + std::shared_ptr sBlob = nullptr; std::shared_ptr pBlob = nullptr; int tid = platform::get_cur_mkldnn_session_id(); std::lock_guard lock(*p_mutex_); - // Find KeyBlob for current thread + // Find ShapeBlob for current thread auto map_it = pMap->find(tid); if (map_it == pMap->end()) { // 1st time to set blob in current thread - pBlob = std::shared_ptr(new KeyBlob()); - (*pMap)[tid] = pBlob; + sBlob = std::shared_ptr(new ShapeBlob()); + (*pMap)[tid] = sBlob; + VLOG(2) << "SetBlob: tid=" << tid << ", add new tid\n"; } else { - pBlob = map_it->second; + sBlob = map_it->second; } - // Find Key in found (or newly created) KeyBlob - auto key_it = pBlob->find(name); + // Find KeyBlob for current input shape + std::string cur_input_shape_str = platform::get_cur_input_shape_str(); + auto key_it = sBlob->find(cur_input_shape_str); - if (key_it == pBlob->end()) { - (*pBlob)[name] = data; // create new blob + if (key_it == sBlob->end()) { + pBlob = std::shared_ptr(new KeyBlob()); + (*sBlob)[cur_input_shape_str] = pBlob; } else { - key_it->second = data; // set data to existing blob + pBlob = key_it->second; } + // Find Blob via name + auto blob_it = pBlob->find(name); + if (blob_it == pBlob->end()) { + (*pBlob)[name] = data; + } else { + blob_it->second = data; // set data to existing blob + } + VLOG(2) << "SetBlob: tid=" << tid << ", add blob=" << name << "\n"; // lock will be automatically released when out of scope return; } @@ -446,22 +466,40 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, std::shared_ptr MKLDNNDeviceContext::GetBlob( const std::string& name) const { BlobMap* pMap = p_blobmap_.get(); + std::shared_ptr sBlob = nullptr; std::shared_ptr pBlob = nullptr; int tid = platform::get_cur_mkldnn_session_id(); std::lock_guard lock(*p_mutex_); - // Find KeyBlob for current thread firstly + // Find ShapeBlob for current thread firstly auto map_it = pMap->find(tid); - if (map_it == pMap->end()) return nullptr; - pBlob = map_it->second; + if (map_it == pMap->end()) { + VLOG(2) << "GetBlob: tid=" << tid << ", miss tid\n"; + return nullptr; + } + std::string cur_input_shape_str = platform::get_cur_input_shape_str(); + sBlob = map_it->second; + + // Find KeyBlob for current input shape secondly + auto sBlob_it = sBlob->find(cur_input_shape_str); + if (sBlob_it == sBlob->end()) { + VLOG(2) << "GetBlob: tid=" << cur_input_shape_str + << ", miss input_shape_str\n"; + return nullptr; + } + pBlob = sBlob_it->second; // Find Blob via name auto key_it = pBlob->find(name); - if (key_it == pBlob->end()) return nullptr; + if (key_it == pBlob->end()) { + VLOG(2) << "GetBlob tid=" << tid << ", miss blob=" << name << "\n"; + return nullptr; + } + VLOG(2) << "GetBlob tid=" << tid << ", get blob=" << name << "\n"; // lock will be automatically released when out of scope return key_it->second; } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 91dc8a6887a5812b542407656432d2673e3fdb62..1aef2bb45dd7963ace59d2f00933435ca2b130c7 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -378,8 +378,15 @@ struct DefaultDeviceContextType { #endif #ifdef PADDLE_WITH_MKLDNN +// Following three maps are used to cache MKLDNN primitives. +// There relations are: +// - BlobMap = Map +// - ShapeBlob = Map +// - KeyBlob = Map +// Where: using KeyBlob = std::unordered_map>; -using BlobMap = std::unordered_map>; +using ShapeBlob = std::unordered_map>; +using BlobMap = std::unordered_map>; // default mkldnn session id constexpr size_t kMKLDNNSessionID_Default = 0; @@ -388,6 +395,8 @@ constexpr size_t kMKLDNNSessionID_CacheClearing = -1; void set_cur_mkldnn_session_id(size_t); size_t get_cur_mkldnn_session_id(void); +void set_cur_input_shape_str(std::string input_shape_str); +std::string get_cur_input_shape_str(void); class MKLDNNDeviceContext : public CPUDeviceContext { public: