diff --git a/paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc b/paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc index 17c670a68cc9cbcfd74ff3541fa1f3bc07200062..ce9ad6ff125011cbb03311ceb1521ff9c80d375f 100644 --- a/paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc @@ -172,5 +172,61 @@ TEST(Analyzer_MM_DNN, compare_determine) { input_slots_all); } +#ifdef PADDLE_WITH_MKLDNN +void TestMkldnnCacheClear(int mkldnn_input_shape_cache_capacity) { + AnalysisConfig config; + SetConfig(&config); + config.EnableMKLDNN(); + // TODO(luotao): explicit following settings will be deprecated after enhance + // config.EnableMKLDNN() interface. + if (mkldnn_input_shape_cache_capacity > 0) { + platform::set_cur_mkldnn_session_id( + platform::kMKLDNNSessionID_CacheClearing); + platform::set_cur_input_shape_cache_capacity( + mkldnn_input_shape_cache_capacity); + } + + std::vector<PaddleTensor> input, output; + auto predictor = CreatePaddlePredictor(config); + + int sample_num = 10; + DataRecord data(FLAGS_infer_data, FLAGS_batch_size); + + auto &pool = platform::DeviceContextPool::Instance(); + auto *dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext *>( + pool.Get(platform::CPUPlace())); + for (int i = 0; i < sample_num; i++) { + PrepareInputs(&input, &data, FLAGS_batch_size); + if (mkldnn_input_shape_cache_capacity > 0) { + std::stringstream ss; + for (size_t i = 0; i < input.size(); i++) { + for (size_t j = 0; j < input[i].shape.size(); ++j) { + ss << input[i].shape[j] << "-"; + } + } + // TODO(luotao): explicit following settings will be deprecated after + // enhance config.EnableMKLDNN() interface.
+ platform::set_cur_input_shape_str(ss.str()); + } + predictor->Run(input, &output, 1); + } + if (mkldnn_input_shape_cache_capacity > 0) { + PADDLE_ENFORCE_EQ(dev_ctx->GetShapeBlobSize(), + mkldnn_input_shape_cache_capacity); + } else { + PADDLE_ENFORCE_EQ(dev_ctx->GetShapeBlobSize(), 1UL); + } + dev_ctx->ResetBlobMap(); +} + +TEST(Analyzer_MM_DNN, mkldnn_cache_clear) { + // 0 means do not use cache clear strategy. + TestMkldnnCacheClear(0); + // 4 means use cache clear strategy, and the + // mkldnn_input_shape_cache_capacity is 4. + TestMkldnnCacheClear(4); +} +#endif + } // namespace inference } // namespace paddle diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 0dabe2ed3d92c141f03b07a26139a7997b01e478..87b82ec5e390aab0b7c63223a8cb35d26c495fed 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -407,6 +407,9 @@ thread_local size_t cur_mkldnn_session_id = kMKLDNNSessionID_Default; // - For fixed-shape, it's a null string in default. // - For dynamic-shape, it's user specific. thread_local std::string cur_input_shape_str = ""; +// the cache capacity of different input shapes for MKLDNN. +// Default 1 means fixed input shape, not dynamic shape. 
+thread_local int cur_input_shape_cache_capacity = 1; } // namespace void set_cur_mkldnn_session_id(size_t sid) { cur_mkldnn_session_id = sid; } @@ -414,37 +417,58 @@ size_t get_cur_mkldnn_session_id(void) { return cur_mkldnn_session_id; } void set_cur_input_shape_str(std::string input_shape_str) { cur_input_shape_str = input_shape_str; } -std::string get_cur_input_shape_str(void) { return cur_input_shape_str; } +void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity) { + cur_input_shape_cache_capacity = input_shape_cache_capacity; +} void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); } +size_t MKLDNNDeviceContext::GetShapeBlobSize() const { + std::lock_guard<std::mutex> lock(*p_mutex_); + BlobMap* pMap = p_blobmap_.get(); + auto map_it = pMap->find(cur_mkldnn_session_id); + if (map_it == pMap->end()) { + LOG(FATAL) << "MKLDNNDeviceContext don't find cur_mkldnn_session_id : " + << cur_mkldnn_session_id; + } + return map_it->second->size(); +} + void MKLDNNDeviceContext::SetBlob(const std::string& name, std::shared_ptr<void> data) const { BlobMap* pMap = p_blobmap_.get(); std::shared_ptr<ShapeBlob> sBlob = nullptr; std::shared_ptr<KeyBlob> pBlob = nullptr; - int tid = platform::get_cur_mkldnn_session_id(); + int sid = platform::get_cur_mkldnn_session_id(); std::lock_guard<std::mutex> lock(*p_mutex_); - // Find ShapeBlob for current thread - auto map_it = pMap->find(tid); + // Find ShapeBlob for current mkldnn session id.
+ auto map_it = pMap->find(sid); if (map_it == pMap->end()) { // 1st time to set blob in current thread sBlob = std::shared_ptr<ShapeBlob>(new ShapeBlob()); - (*pMap)[tid] = sBlob; - VLOG(2) << "SetBlob: tid=" << tid << ", add new tid\n"; + (*pMap)[sid] = sBlob; + VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n"; } else { sBlob = map_it->second; } // Find KeyBlob for current input shape - std::string cur_input_shape_str = platform::get_cur_input_shape_str(); auto key_it = sBlob->find(cur_input_shape_str); if (key_it == sBlob->end()) { + // In cache clearing mode, cur_input_shape_cache_capacity defines + // max pblob capacity + if ((sid == kMKLDNNSessionID_CacheClearing) && + (sBlob->size() >= + static_cast<size_t>(cur_input_shape_cache_capacity))) { + VLOG(2) << "sid=" << sid + << ", remove all blobs of shape: " << sBlob->begin()->first; + sBlob->erase(sBlob->begin()->first); + } pBlob = std::shared_ptr<KeyBlob>(new KeyBlob()); (*sBlob)[cur_input_shape_str] = pBlob; } else { @@ -458,7 +482,7 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, } else { blob_it->second = data; // set data to existing blob } - VLOG(2) << "SetBlob: tid=" << tid << ", add blob=" << name << "\n"; + VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n"; // lock will be automatically released when out of scope return; } @@ -469,23 +493,22 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob( std::shared_ptr<ShapeBlob> sBlob = nullptr; std::shared_ptr<KeyBlob> pBlob = nullptr; - int tid = platform::get_cur_mkldnn_session_id(); + int sid = platform::get_cur_mkldnn_session_id(); std::lock_guard<std::mutex> lock(*p_mutex_); - // Find ShapeBlob for current thread firstly - auto map_it = pMap->find(tid); + // Find ShapeBlob for current mkldnn session id firstly + auto map_it = pMap->find(sid); if (map_it == pMap->end()) { - VLOG(2) << "GetBlob: tid=" << tid << ", miss tid\n"; + VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n"; return nullptr; } - std::string cur_input_shape_str = platform::get_cur_input_shape_str(); sBlob = 
map_it->second; // Find KeyBlob for current input shape secondly auto sBlob_it = sBlob->find(cur_input_shape_str); if (sBlob_it == sBlob->end()) { - VLOG(2) << "GetBlob: tid=" << cur_input_shape_str + VLOG(2) << "GetBlob: sid=" << cur_input_shape_str << ", miss input_shape_str\n"; return nullptr; } @@ -495,11 +518,11 @@ std::shared_ptr<void> MKLDNNDeviceContext::GetBlob( auto key_it = pBlob->find(name); if (key_it == pBlob->end()) { - VLOG(2) << "GetBlob tid=" << tid << ", miss blob=" << name << "\n"; + VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n"; return nullptr; } - VLOG(2) << "GetBlob tid=" << tid << ", get blob=" << name << "\n"; + VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n"; // lock will be automatically released when out of scope return key_it->second; } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 1aef2bb45dd7963ace59d2f00933435ca2b130c7..a17a0bdfb9aea384371acf631b778f0ec8183a87 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -396,7 +396,7 @@ constexpr size_t kMKLDNNSessionID_CacheClearing = -1; void set_cur_mkldnn_session_id(size_t); size_t get_cur_mkldnn_session_id(void); void set_cur_input_shape_str(std::string input_shape_str); -std::string get_cur_input_shape_str(void); +void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity); class MKLDNNDeviceContext : public CPUDeviceContext { public: @@ -408,6 +408,9 @@ class MKLDNNDeviceContext : public CPUDeviceContext { // Remove all entries from the blob map void ResetBlobMap() const; + // Get the ShapeBlob size in cur_mkldnn_session_id. + size_t GetShapeBlobSize() const; + // Set data to blob (i.e. name/data pair). Create blob if not existing void SetBlob(const std::string& name, std::shared_ptr<void> data) const;