From ce38bb5341b6904e060a3bf0b934f811ae2bab4b Mon Sep 17 00:00:00 2001
From: Leo Zhao <48052473+LeoZhao-Intel@users.noreply.github.com>
Date: Mon, 8 Jul 2019 14:01:45 +0800
Subject: [PATCH] use static variable to do cache instead of thread local in
 thread frequent switching case (#18428)

---
 .../fluid/framework/transfer_scope_cache.cc   | 49 +++++++++++++++++++
 .../tests/api/analyzer_bert_tester.cc         | 36 +++++++++++---
 2 files changed, 79 insertions(+), 6 deletions(-)

diff --git a/paddle/fluid/framework/transfer_scope_cache.cc b/paddle/fluid/framework/transfer_scope_cache.cc
index 2b138280fb5..e1326f88961 100644
--- a/paddle/fluid/framework/transfer_scope_cache.cc
+++ b/paddle/fluid/framework/transfer_scope_cache.cc
@@ -17,12 +17,61 @@ namespace paddle {
 namespace framework {
 
+#ifdef PADDLE_WITH_MKLDNN
+using transfer_data_cache_map = std::unordered_map<size_t, Scope*>;
+using transfer_scope_cache_map = std::unordered_set<Scope*>;
+static std::unordered_map<size_t, transfer_data_cache_map*>
+    static_transfer_data_caches;
+static std::unordered_map<size_t, transfer_scope_cache_map*>
+    static_transfer_scope_caches;
+#endif
+
 std::unordered_map<size_t, Scope*>& global_transfer_data_cache() {
+#ifdef PADDLE_WITH_MKLDNN
+  size_t sid = platform::get_cur_mkldnn_session_id();
+
+  // If there is an explicit mkldnn session id setting from the user.
+  if (sid != platform::kMKLDNNSessionID_Default) {
+    sid = std::hash<std::thread::id>()(std::this_thread::get_id());
+
+    static std::mutex acquire_barrier;
+    std::lock_guard<std::mutex> block_until_finish_this_job(acquire_barrier);
+
+    auto map_it = static_transfer_data_caches.find(sid);
+    if (map_it == static_transfer_data_caches.end()) {
+      auto* x = new transfer_data_cache_map;
+      static_transfer_data_caches[sid] = x;
+      return *x;
+    } else {
+      return *static_transfer_data_caches[sid];
+    }
+  }
+#endif
   thread_local auto* x = new std::unordered_map<size_t, Scope*>;
   return *x;
 }
 
 std::unordered_set<Scope*>& global_transfer_scope_cache() {
+#ifdef PADDLE_WITH_MKLDNN
+  size_t sid = platform::get_cur_mkldnn_session_id();
+
+  // If there is an explicit mkldnn session id setting from the user.
+  if (sid != platform::kMKLDNNSessionID_Default) {
+    sid = std::hash<std::thread::id>()(std::this_thread::get_id());
+
+    static std::mutex acquire_barrier;
+    std::lock_guard<std::mutex> block_until_finish_this_job(acquire_barrier);
+
+    auto map_it = static_transfer_scope_caches.find(sid);
+    if (map_it == static_transfer_scope_caches.end()) {
+      auto* x = new transfer_scope_cache_map;
+      static_transfer_scope_caches[sid] = x;
+      return *x;
+    } else {
+      return *static_transfer_scope_caches[sid];
+    }
+  }
+#endif
   thread_local auto* x = new std::unordered_set<Scope*>;
   return *x;
 }
 
diff --git a/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc b/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc
index f679e122182..406c028a9fb 100644
--- a/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc
@@ -230,7 +230,7 @@ TEST(Analyzer_bert, compare_determine) {
                        inputs);
 }
 
-TEST(Analyzer_bert, transfer_scope_cache) {
+void verify_transfer_scope_cache(bool is_static = false) {
   AnalysisConfig config;
   SetConfig(&config);
 
@@ -251,6 +251,11 @@
     threads.emplace_back([&, i]() {
       std::getline(fin, line);
       ParseLine(line, &input);
+#ifdef PADDLE_WITH_MKLDNN
+      // Use the static transfer scope cache instead of the thread_local one.
+      // TODO(intel): explicit session id setting will be deprecated.
+      if (is_static) platform::set_cur_mkldnn_session_id(1);
+#endif
       predictor->Run(input, &output, FLAGS_batch_size);
       global_transfer_scope_cache.insert(
           &paddle::framework::global_transfer_scope_cache());
@@ -261,12 +266,31 @@
     threads.clear();
     std::vector<PaddleTensor>().swap(input);
   }
-  // Since paddle::framework::global_transfer_scope_cache() and
-  // paddle::framework::global_transfer_data_cache() are thread_local,
-  // their pointer should be different among different thread id.
-  PADDLE_ENFORCE(global_transfer_scope_cache.size(), threads_num);
-  PADDLE_ENFORCE(global_transfer_data_cache.size(), threads_num);
+#ifdef PADDLE_WITH_MKLDNN
+  if (is_static) {
+    // The static caches are used instead of thread_local ones, so only a
+    // single cache instance should be observed across all threads.
+    PADDLE_ENFORCE(global_transfer_scope_cache.size(), 1);
+    PADDLE_ENFORCE(global_transfer_data_cache.size(), 1);
+  } else {
+#endif
+    // Since paddle::framework::global_transfer_scope_cache() and
+    // paddle::framework::global_transfer_data_cache() are thread_local,
+    // their pointer should be different among different thread id.
+    PADDLE_ENFORCE(global_transfer_scope_cache.size(), threads_num);
+    PADDLE_ENFORCE(global_transfer_data_cache.size(), threads_num);
+#ifdef PADDLE_WITH_MKLDNN
+  }
+#endif
 }
 
+TEST(Analyzer_bert, threadlocal_transfer_scope_cache) {
+  verify_transfer_scope_cache();
+}
+#ifdef PADDLE_WITH_MKLDNN
+TEST(Analyzer_bert, static_transfer_scope_cache) {
+  verify_transfer_scope_cache(true);
+}
+#endif
 }  // namespace inference
 }  // namespace paddle
--
GitLab
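
For reference, here is a minimal, standalone sketch of the caching pattern the patch adopts; it is not PaddlePaddle code. When a caller opts in through an explicit session id, cache entries live in a process-wide static map guarded by a mutex and keyed by a hash of the current thread id, so they persist beyond any single short-lived thread; otherwise the original thread_local path is kept. All identifiers below (GetCache, Cache, kDefaultSession, g_session_id) are hypothetical stand-ins for the real platform/framework APIs.

    // Standalone illustration of the session-keyed static cache pattern.
    // All names are hypothetical, not PaddlePaddle APIs.
    #include <functional>
    #include <iostream>
    #include <mutex>
    #include <string>
    #include <thread>
    #include <unordered_map>

    // Stand-ins for platform::kMKLDNNSessionID_Default and
    // platform::set_cur_mkldnn_session_id().
    constexpr size_t kDefaultSession = static_cast<size_t>(-1);
    thread_local size_t g_session_id = kDefaultSession;

    using Cache = std::unordered_map<std::string, int>;

    Cache& GetCache() {
      if (g_session_id != kDefaultSession) {
        // Explicit session id: keep caches in one process-wide static map keyed
        // by a hash of the thread id, guarded by a mutex. An entry is reused
        // when a thread id recurs, instead of a fresh thread_local cache being
        // allocated (and never freed) for every short-lived thread.
        size_t key = std::hash<std::thread::id>()(std::this_thread::get_id());
        static std::mutex barrier;
        static std::unordered_map<size_t, Cache*> caches;
        std::lock_guard<std::mutex> guard(barrier);
        auto it = caches.find(key);
        if (it == caches.end()) {
          auto* c = new Cache;
          caches[key] = c;
          return *c;
        }
        return *it->second;
      }
      // Default: one cache per thread, as before the patch.
      thread_local auto* c = new Cache;
      return *c;
    }

    int main() {
      std::thread worker([] {
        g_session_id = 1;  // opt in to the static, session-keyed cache
        GetCache()["transfer_scope"] = 42;
        std::cout << "entries: " << GetCache().size() << "\n";  // prints 1
      });
      worker.join();
      return 0;
    }

This mirrors the new static_transfer_scope_cache test above: it sets an explicit session id before Run() so the static path is taken, and it then expects a single shared cache instance rather than one per worker thread.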