From 9788e5ab8795c9f5cf19f4fa7f7dd37a20512d12 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 27 Jul 2018 10:36:06 +0800 Subject: [PATCH] add flags to control num_threads --- .../fluid/inference/tests/book/test_inference_nlp.cc | 11 ++--------- paddle/fluid/platform/cpu_helper.cc | 3 +++ paddle/fluid/platform/init.cc | 4 +++- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/inference/tests/book/test_inference_nlp.cc b/paddle/fluid/inference/tests/book/test_inference_nlp.cc index 5cc1db12bb..e2a3e9d46e 100644 --- a/paddle/fluid/inference/tests/book/test_inference_nlp.cc +++ b/paddle/fluid/inference/tests/book/test_inference_nlp.cc @@ -20,9 +20,6 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/inference/tests/test_helper.h" #include "paddle/fluid/platform/cpu_helper.h" -#ifdef PADDLE_WITH_MKLML -#include -#endif DEFINE_string(model_path, "", "Directory of the inference model."); DEFINE_string(data_file, "", "File of input index data."); @@ -30,6 +27,7 @@ DEFINE_int32(repeat, 100, "Running the inference program repeat times"); DEFINE_bool(prepare_vars, true, "Prepare variables before executor"); DEFINE_int32(num_threads, 1, "Number of threads should be used"); DECLARE_bool(use_mkldnn); +DECLARE_int32(paddle_num_threads); inline double GetCurrentMs() { struct timeval time; @@ -160,12 +158,7 @@ TEST(inference, nlp) { std::unique_ptr scope( new paddle::framework::Scope()); -#ifdef PADDLE_WITH_MKLML - // only use 1 thread number per std::thread - omp_set_dynamic(0); - omp_set_num_threads(1); - paddle::platform::SetNumThreads(1); -#endif + paddle::platform::SetNumThreads(FLAGS_paddle_num_threads); double start_ms = 0, stop_ms = 0; if (FLAGS_num_threads > 1) { diff --git a/paddle/fluid/platform/cpu_helper.cc b/paddle/fluid/platform/cpu_helper.cc index 234a04b5c2..6841652b75 100644 --- a/paddle/fluid/platform/cpu_helper.cc +++ b/paddle/fluid/platform/cpu_helper.cc @@ -24,6 +24,9 @@ limitations under the License. */ #include #endif +DEFINE_int32(paddle_num_threads, 1, + "Number of threads for each paddle instance."); + namespace paddle { namespace platform { diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index 0b77652841..79a2ca96ed 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -23,6 +23,8 @@ limitations under the License. */ #include "paddle/fluid/platform/place.h" #include "paddle/fluid/string/piece.h" +DECLARE_int32(paddle_num_threads); + namespace paddle { namespace framework { @@ -115,7 +117,7 @@ void InitDevices(bool init_p2p, const std::vector devices) { places.emplace_back(platform::CPUPlace()); platform::DeviceContextPool::Init(places); #ifndef PADDLE_WITH_MKLDNN - platform::SetNumThreads(1); + platform::SetNumThreads(FLAGS_paddle_num_threads); #endif } -- GitLab