diff --git a/lite/api/CMakeLists.txt b/lite/api/CMakeLists.txt
index 5177783daee261241beb126ec624543f6dc75dc3..70239e94e7a3064fb383246623d05a2079dda1fa 100644
--- a/lite/api/CMakeLists.txt
+++ b/lite/api/CMakeLists.txt
@@ -16,8 +16,11 @@ if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_CUDA OR LITE_WITH_X86 OR ARM_TARGE
   add_dependencies(paddle_full_api_shared op_list_h kernel_list_h framework_proto)
   target_link_libraries(paddle_full_api_shared framework_proto)
   if(LITE_WITH_X86)
-    add_dependencies(paddle_full_api_shared xxhash)
-    target_link_libraries(paddle_full_api_shared xxhash)
+    add_dependencies(paddle_full_api_shared xxhash)
+    target_link_libraries(paddle_full_api_shared xxhash)
+    if (NOT LITE_ON_MODEL_OPTIMIZE_TOOL)
+      add_dependencies(paddle_full_api_shared dynload_mklml)
+    endif()
   endif()
   if(LITE_WITH_CUDA)
     target_link_libraries(paddle_full_api_shared ${math_cuda} "-Wl,--whole-archive" ${cuda_kernels} "-Wl,--no-whole-archive")
diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc
index 6fa400db6da9f029c38b496cd70d593a876628c9..3e6e10103e9f3af51923459a5921f9781431f352 100644
--- a/lite/api/cxx_api_impl.cc
+++ b/lite/api/cxx_api_impl.cc
@@ -20,6 +20,12 @@
 #include "lite/core/device_info.h"
 #include "lite/core/version.h"
 
+#if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \
+    !(defined LITE_ON_MODEL_OPTIMIZE_TOOL)
+#include <omp.h>
+#include "lite/backends/x86/mklml.h"
+#endif
+
 namespace paddle {
 namespace lite {
 
@@ -33,6 +39,17 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
 
   mode_ = config.power_mode();
   threads_ = config.threads();
+
+#if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \
+    !(defined LITE_ON_MODEL_OPTIMIZE_TOOL)
+  int num_threads = config.cpu_math_library_num_threads();
+  int real_num_threads = num_threads > 1 ? num_threads : 1;
+  paddle::lite::x86::MKL_Set_Num_Threads(real_num_threads);
+  omp_set_num_threads(real_num_threads);
+  VLOG(3) << "set_cpu_math_library_math_threads() is set successfully and the "
+             "number of threads is:"
+          << num_threads;
+#endif
 }
 
 std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
diff --git a/lite/api/paddle_api.h b/lite/api/paddle_api.h
index c578769bd5159d27ad43e4e93de33f601223004b..339117cd503247a91694d1a9ca63b930af5658de 100644
--- a/lite/api/paddle_api.h
+++ b/lite/api/paddle_api.h
@@ -133,6 +133,7 @@ class LITE_API CxxConfig : public ConfigBase {
   std::string model_file_;
   std::string param_file_;
   bool model_from_memory_{false};
+  int cpu_math_library_math_threads_ = 1;
 
  public:
   void set_valid_places(const std::vector<Place>& x) { valid_places_ = x; }
@@ -151,6 +152,13 @@ class LITE_API CxxConfig : public ConfigBase {
   std::string model_file() const { return model_file_; }
   std::string param_file() const { return param_file_; }
   bool model_from_memory() const { return model_from_memory_; }
+
+  void set_cpu_math_library_math_threads(int threads) {
+    cpu_math_library_math_threads_ = threads;
+  }
+  int cpu_math_library_num_threads() const {
+    return cpu_math_library_math_threads_;
+  }
 };
 
 /// MobileConfig is the config for the light weight predictor, it will skip
diff --git a/lite/api/test_step_rnn_lite_x86.cc b/lite/api/test_step_rnn_lite_x86.cc
index 5314c5ed75d862635a1b87cdad33bf3c58dcd6cc..4d0aefbc06a9d0678d8b401629b7cc4355967f6c 100644
--- a/lite/api/test_step_rnn_lite_x86.cc
+++ b/lite/api/test_step_rnn_lite_x86.cc
@@ -30,6 +30,7 @@ TEST(Step_rnn, test_step_rnn_lite_x86) {
   std::string model_dir = FLAGS_model_dir;
   lite_api::CxxConfig config;
   config.set_model_dir(model_dir);
+  config.set_cpu_math_library_math_threads(10);
   config.set_valid_places({lite_api::Place{TARGET(kX86), PRECISION(kInt64)},
                           lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
                           lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});