From a5c4b463c962bed48fba89d459adf82f4899d6c3 Mon Sep 17 00:00:00 2001 From: luotao1 Date: Thu, 22 Nov 2018 18:37:33 +0800 Subject: [PATCH] add SetMKLDNNThreadId api --- paddle/fluid/inference/api/analysis_predictor.cc | 8 ++++++++ paddle/fluid/inference/api/analysis_predictor.h | 2 ++ paddle/fluid/inference/api/paddle_analysis_config.h | 2 +- paddle/fluid/inference/tests/api/tester_helper.h | 9 ++++++--- 4 files changed, 17 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 9162ccefd..4633a75e5 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -159,6 +159,14 @@ bool AnalysisPredictor::PrepareExecutor() { return true; } +void AnalysisPredictor::SetMKLDNNThreadId(int tid) { +#ifdef PADDLE_WITH_MKLDNN + platform::set_cur_thread_id(tid); +#else + LOG(ERROR) << "Please compile with MKLDNN first to use MKLDNN"; +#endif +} + bool AnalysisPredictor::Run(const std::vector &inputs, std::vector *output_data, int batch_size) { diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index cf81b7db7..9191970a3 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -69,6 +69,8 @@ class AnalysisPredictor : public PaddlePredictor { framework::Scope *scope() { return scope_.get(); } framework::ProgramDesc &program() { return *inference_program_; } + void SetMKLDNNThreadId(int tid); + protected: bool PrepareProgram(const std::shared_ptr &program); bool PrepareScope(const std::shared_ptr &parent_scope); diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index 2ac736df7..a09bd1cac 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -51,9 +51,9 @@ struct AnalysisConfig : public NativeConfig { int max_batch_size = 1); bool use_tensorrt() const { return use_tensorrt_; } + void EnableMKLDNN(); // NOTE this is just for internal development, please not use it. // NOT stable yet. - void EnableMKLDNN(); bool use_mkldnn() const { return use_mkldnn_; } friend class ::paddle::AnalysisPredictor; diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index fdadd5904..72703bc80 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ b/paddle/fluid/inference/tests/api/tester_helper.h @@ -216,13 +216,16 @@ void TestMultiThreadPrediction( size_t total_time{0}; for (int tid = 0; tid < num_threads; ++tid) { threads.emplace_back([&, tid]() { -#ifdef PADDLE_WITH_MKLDNN - platform::set_cur_thread_id(static_cast(tid) + 1); -#endif // Each thread should have local inputs and outputs. // The inputs of each thread are all the same. std::vector outputs_tid; auto &predictor = predictors[tid]; +#ifdef PADDLE_WITH_MKLDNN + if (use_analysis) { + static_cast(predictor.get()) + ->SetMKLDNNThreadId(static_cast(tid) + 1); + } +#endif // warmup run LOG(INFO) << "Running thread " << tid << ", warm up run..."; -- GitLab