Commit ce38bb53 authored by Leo Zhao, committed by Tao Luo

use a static variable for the cache instead of thread_local in the frequent thread-switching case (#18428)

Parent 160ddc98
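For context on the change below: when the user sets an explicit MKL-DNN session id, the transfer caches are now looked up in function-level static maps guarded by a mutex, instead of being allocated per thread via thread_local, so frequently created and destroyed threads do not each leave behind their own cache. The snippet that follows is only a minimal standalone sketch of that pattern under assumed names; SessionCache, get_cache, and session_id are hypothetical and are not part of the Paddle code in this commit.

// Minimal sketch of the static, id-keyed cache pattern (hypothetical names,
// not the Paddle API). A thread_local cache allocates one map per thread;
// a static map keyed by an id survives frequent thread switching.
#include <mutex>
#include <unordered_map>

using SessionCache = std::unordered_map<size_t, int>;  // stand-in payload

SessionCache& get_cache(size_t session_id) {
  static std::unordered_map<size_t, SessionCache*> caches;
  static std::mutex mtx;
  std::lock_guard<std::mutex> guard(mtx);
  auto it = caches.find(session_id);
  if (it == caches.end()) {
    auto* c = new SessionCache;  // never freed, mirroring the diff below
    caches[session_id] = c;
    return *c;
  }
  return *it->second;
}

int main() {
  // Two lookups with the same id return the same cache instance.
  return &get_cache(1) == &get_cache(1) ? 0 : 1;
}

In the actual diff, the key for the static maps is derived from the current thread, and the mutex only guards the non-default session-id path, so the common thread_local path stays lock-free.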
@@ -17,12 +17,61 @@
namespace paddle {
namespace framework {
#ifdef PADDLE_WITH_MKLDNN
using transfer_data_cache_map = std::unordered_map<size_t, Scope*>;
using transfer_scope_cache_map = std::unordered_set<Scope*>;
static std::unordered_map<size_t, transfer_data_cache_map*>
    static_transfer_data_caches;
static std::unordered_map<size_t, transfer_scope_cache_map*>
    static_transfer_scope_caches;
#endif
std::unordered_map<size_t, Scope*>& global_transfer_data_cache() {
#ifdef PADDLE_WITH_MKLDNN
  size_t sid = platform::get_cur_mkldnn_session_id();
  // if there is specific mkldnn tid setting from user.
  if (sid != platform::kMKLDNNSessionID_Default) {
    sid = std::hash<std::thread::id>()(std::this_thread::get_id());
    static std::mutex acquire_barrier;
    std::lock_guard<std::mutex> block_until_finish_this_job(acquire_barrier);
    auto map_it = static_transfer_data_caches.find(sid);
    if (map_it == static_transfer_data_caches.end()) {
      auto* x = new transfer_data_cache_map;
      static_transfer_data_caches[sid] = x;
      return *x;
    } else {
      return *static_transfer_data_caches[sid];
    }
  }
#endif
  thread_local auto* x = new std::unordered_map<size_t, Scope*>;
  return *x;
}
std::unordered_set<Scope*>& global_transfer_scope_cache() {
#ifdef PADDLE_WITH_MKLDNN
  size_t sid = platform::get_cur_mkldnn_session_id();
  // if there is specific mkldnn session id setting from user.
  if (sid != platform::kMKLDNNSessionID_Default) {
    sid = std::hash<std::thread::id>()(std::this_thread::get_id());
    static std::mutex acquire_barrier;
    std::lock_guard<std::mutex> block_until_finish_this_job(acquire_barrier);
    auto map_it = static_transfer_scope_caches.find(sid);
    if (map_it == static_transfer_scope_caches.end()) {
      auto* x = new transfer_scope_cache_map;
      static_transfer_scope_caches[sid] = x;
      return *x;
    } else {
      return *static_transfer_scope_caches[sid];
    }
  }
#endif
  thread_local auto* x = new std::unordered_set<Scope*>;
  return *x;
}
...
@@ -230,7 +230,7 @@ TEST(Analyzer_bert, compare_determine) {
      inputs);
}

-TEST(Analyzer_bert, transfer_scope_cache) {
+void verify_transfer_scope_cache(bool is_static = false) {
  AnalysisConfig config;
  SetConfig(&config);
@@ -251,6 +251,11 @@ TEST(Analyzer_bert, transfer_scope_cache) {
    threads.emplace_back([&, i]() {
      std::getline(fin, line);
      ParseLine(line, &input);
#ifdef PADDLE_WITH_MKLDNN
      // Use static method to handle transfer_scope_cache()
      // TODO(intel) explicit session id setting will be deprecated.
      if (is_static) platform::set_cur_mkldnn_session_id(1);
#endif
      predictor->Run(input, &output, FLAGS_batch_size);
      global_transfer_scope_cache.insert(
          &paddle::framework::global_transfer_scope_cache());
@@ -261,12 +266,31 @@ TEST(Analyzer_bert, transfer_scope_cache) {
    threads.clear();
    std::vector<PaddleTensor>().swap(input);
  }
#ifdef PADDLE_WITH_MKLDNN
  if (is_static) {
    // Use static method to do transfer_scope_cache() instead of thread_local
    // so paddle::framework::global_transfer_data_cache() should be 1
    PADDLE_ENFORCE(global_transfer_scope_cache.size(), 1);
    PADDLE_ENFORCE(global_transfer_data_cache.size(), 1);
  } else {
#endif
    // Since paddle::framework::global_transfer_scope_cache() and
    // paddle::framework::global_transfer_data_cache() are thread_local,
    // their pointer should be different among different thread id.
    PADDLE_ENFORCE(global_transfer_scope_cache.size(), threads_num);
    PADDLE_ENFORCE(global_transfer_data_cache.size(), threads_num);
#ifdef PADDLE_WITH_MKLDNN
  }
#endif
}
TEST(Analyzer_bert, threadlocal_transfer_scope_cache) {
  verify_transfer_scope_cache();
}
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_bert, static_transfer_scope_cache) {
  verify_transfer_scope_cache(true);
}
#endif
} // namespace inference
} // namespace paddle
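The new test above runs the same worker loop on two paths: the default thread_local path, and the static path selected by calling platform::set_cur_mkldnn_session_id(1) inside every worker thread. It then counts how many distinct cache instances the threads observed; per the comments in the diff, the thread_local path yields one instance per thread, while the static path is expected to yield a single shared instance. The helper below is a hedged, standalone sketch of that counting idea only; DistinctCacheInstances and cache_getter are made-up names and are not part of the test file.

// Standalone sketch of the invariant checked above (hypothetical helper, not
// the actual test code): run N threads, record which cache object each one
// sees, and count the distinct instances.
#include <mutex>
#include <set>
#include <thread>
#include <vector>

int DistinctCacheInstances(int threads_num, void* (*cache_getter)()) {
  std::set<void*> seen;
  std::mutex mtx;
  std::vector<std::thread> workers;
  for (int i = 0; i < threads_num; ++i) {
    workers.emplace_back([&] {
      void* cache = cache_getter();  // e.g. the address returned by a cache getter
      std::lock_guard<std::mutex> guard(mtx);
      seen.insert(cache);
    });
  }
  for (auto& t : workers) t.join();
  return static_cast<int>(seen.size());
}

int main() {
  // With one shared cache object every thread reports the same instance,
  // which is what the static-cache test case above expects.
  int n = DistinctCacheInstances(4, [] {
    static int shared_cache;  // stand-in for a static, shared cache
    return static_cast<void*>(&shared_cache);
  });
  return n == 1 ? 0 : 1;
}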