From b8992673c6e146a548fabf0e856ab7eda24f3b51 Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Wed, 18 Dec 2019 13:30:08 +0800 Subject: [PATCH] Add set_cpu_math_library_math_threads for CxxConfig (#2592) * add set_cpu_math_library_math_threads for lite x86 platform, test=develop * update the #if defined and add a condition LITE_WITH_X86, test=develop * add if not defined LITE_ON_MODEL_OPTIMIZE_TOOL, test=develop --- lite/api/CMakeLists.txt | 7 +++++-- lite/api/cxx_api_impl.cc | 17 +++++++++++++++++ lite/api/paddle_api.h | 8 ++++++++ lite/api/test_step_rnn_lite_x86.cc | 1 + 4 files changed, 31 insertions(+), 2 deletions(-) diff --git a/lite/api/CMakeLists.txt b/lite/api/CMakeLists.txt index 5177783dae..70239e94e7 100644 --- a/lite/api/CMakeLists.txt +++ b/lite/api/CMakeLists.txt @@ -16,8 +16,11 @@ if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_CUDA OR LITE_WITH_X86 OR ARM_TARGE add_dependencies(paddle_full_api_shared op_list_h kernel_list_h framework_proto) target_link_libraries(paddle_full_api_shared framework_proto) if(LITE_WITH_X86) - add_dependencies(paddle_full_api_shared xxhash) - target_link_libraries(paddle_full_api_shared xxhash) + add_dependencies(paddle_full_api_shared xxhash) + target_link_libraries(paddle_full_api_shared xxhash) + if (NOT LITE_ON_MODEL_OPTIMIZE_TOOL) + add_dependencies(paddle_full_api_shared dynload_mklml) + endif() endif() if(LITE_WITH_CUDA) target_link_libraries(paddle_full_api_shared ${math_cuda} "-Wl,--whole-archive" ${cuda_kernels} "-Wl,--no-whole-archive") diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc index 6fa400db6d..3e6e10103e 100644 --- a/lite/api/cxx_api_impl.cc +++ b/lite/api/cxx_api_impl.cc @@ -20,6 +20,12 @@ #include "lite/core/device_info.h" #include "lite/core/version.h" +#if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \ + !(defined LITE_ON_MODEL_OPTIMIZE_TOOL) +#include +#include "lite/backends/x86/mklml.h" +#endif + namespace paddle { namespace lite { @@ -33,6 +39,17 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) { mode_ = config.power_mode(); threads_ = config.threads(); + +#if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \ + !(defined LITE_ON_MODEL_OPTIMIZE_TOOL) + int num_threads = config.cpu_math_library_num_threads(); + int real_num_threads = num_threads > 1 ? num_threads : 1; + paddle::lite::x86::MKL_Set_Num_Threads(real_num_threads); + omp_set_num_threads(real_num_threads); + VLOG(3) << "set_cpu_math_library_math_threads() is set successfully and the " + "number of threads is:" + << num_threads; +#endif } std::unique_ptr CxxPaddleApiImpl::GetInput(int i) { diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h index c578769bd5..339117cd50 100644 --- a/lite/api/paddle_api.h +++ b/lite/api/paddle_api.h @@ -133,6 +133,7 @@ class LITE_API CxxConfig : public ConfigBase { std::string model_file_; std::string param_file_; bool model_from_memory_{false}; + int cpu_math_library_math_threads_ = 1; public: void set_valid_places(const std::vector& x) { valid_places_ = x; } @@ -151,6 +152,13 @@ class LITE_API CxxConfig : public ConfigBase { std::string model_file() const { return model_file_; } std::string param_file() const { return param_file_; } bool model_from_memory() const { return model_from_memory_; } + + void set_cpu_math_library_math_threads(int threads) { + cpu_math_library_math_threads_ = threads; + } + int cpu_math_library_num_threads() const { + return cpu_math_library_math_threads_; + } }; /// MobileConfig is the config for the light weight predictor, it will skip diff --git a/lite/api/test_step_rnn_lite_x86.cc b/lite/api/test_step_rnn_lite_x86.cc index 5314c5ed75..4d0aefbc06 100644 --- a/lite/api/test_step_rnn_lite_x86.cc +++ b/lite/api/test_step_rnn_lite_x86.cc @@ -30,6 +30,7 @@ TEST(Step_rnn, test_step_rnn_lite_x86) { std::string model_dir = FLAGS_model_dir; lite_api::CxxConfig config; config.set_model_dir(model_dir); + config.set_cpu_math_library_math_threads(10); config.set_valid_places({lite_api::Place{TARGET(kX86), PRECISION(kInt64)}, lite_api::Place{TARGET(kX86), PRECISION(kFloat)}, lite_api::Place{TARGET(kHost), PRECISION(kFloat)}}); -- GitLab