diff --git a/paddle/fluid/inference/tests/book/test_inference_nlp.cc b/paddle/fluid/inference/tests/book/test_inference_nlp.cc index 5cc1db12bb71e428d493e7c6f718b1c6ed431858..e2a3e9d46ef9f303d191d59253ffbe9f4826184b 100644 --- a/paddle/fluid/inference/tests/book/test_inference_nlp.cc +++ b/paddle/fluid/inference/tests/book/test_inference_nlp.cc @@ -20,9 +20,6 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/inference/tests/test_helper.h" #include "paddle/fluid/platform/cpu_helper.h" -#ifdef PADDLE_WITH_MKLML -#include -#endif DEFINE_string(model_path, "", "Directory of the inference model."); DEFINE_string(data_file, "", "File of input index data."); @@ -30,6 +27,7 @@ DEFINE_int32(repeat, 100, "Running the inference program repeat times"); DEFINE_bool(prepare_vars, true, "Prepare variables before executor"); DEFINE_int32(num_threads, 1, "Number of threads should be used"); DECLARE_bool(use_mkldnn); +DECLARE_int32(paddle_num_threads); inline double GetCurrentMs() { struct timeval time; @@ -160,12 +158,7 @@ TEST(inference, nlp) { std::unique_ptr scope( new paddle::framework::Scope()); -#ifdef PADDLE_WITH_MKLML - // only use 1 thread number per std::thread - omp_set_dynamic(0); - omp_set_num_threads(1); - paddle::platform::SetNumThreads(1); -#endif + paddle::platform::SetNumThreads(FLAGS_paddle_num_threads); double start_ms = 0, stop_ms = 0; if (FLAGS_num_threads > 1) { diff --git a/paddle/fluid/platform/cpu_helper.cc b/paddle/fluid/platform/cpu_helper.cc index 234a04b5c2eb5ee643e8a4e723b28331cd8e6ee0..6841652b7571f53451ae1d39411f734a27f6c5d0 100644 --- a/paddle/fluid/platform/cpu_helper.cc +++ b/paddle/fluid/platform/cpu_helper.cc @@ -24,6 +24,9 @@ limitations under the License. */ #include #endif +DEFINE_int32(paddle_num_threads, 1, + "Number of threads for each paddle instance."); + namespace paddle { namespace platform { diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index 0b776528414735e8a7c1e3763e7ccb662bb9f285..79a2ca96ed7fdcefe83a6251781502af1d3cb6f2 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -23,6 +23,8 @@ limitations under the License. */ #include "paddle/fluid/platform/place.h" #include "paddle/fluid/string/piece.h" +DECLARE_int32(paddle_num_threads); + namespace paddle { namespace framework { @@ -115,7 +117,7 @@ void InitDevices(bool init_p2p, const std::vector devices) { places.emplace_back(platform::CPUPlace()); platform::DeviceContextPool::Init(places); #ifndef PADDLE_WITH_MKLDNN - platform::SetNumThreads(1); + platform::SetNumThreads(FLAGS_paddle_num_threads); #endif }