Commit 9b2f1acf authored by liu zhengxi, committed by GitHub

Add set_cpu_math_library_math_threads for CxxConfig (#2592)

* add set_cpu_math_library_math_threads for lite x86 platform, test=develop

* update the #if defined and add a condition LITE_WITH_X86, test=develop

* add if not defined LITE_ON_MODEL_OPTIMIZE_TOOL, test=develop
Parent 0314aa1e
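For context, a minimal usage sketch of the API this commit adds. This is hedged, not code from the commit: the model path, thread count, and registration headers are illustrative, and predictor creation follows the usual Paddle-Lite full-API pattern.

#include <memory>

#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"  // assumption: typical full-API registration headers
#include "lite/api/paddle_use_ops.h"

int main() {
  paddle::lite_api::CxxConfig config;
  config.set_model_dir("/path/to/model");  // illustrative path
  // New in this commit: cap the MKL/OpenMP thread pool used by x86 math kernels.
  config.set_cpu_math_library_math_threads(4);
  config.set_valid_places(
      {paddle::lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
       paddle::lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
  auto predictor =
      paddle::lite_api::CreatePaddlePredictor<paddle::lite_api::CxxConfig>(
          config);
  return 0;
}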
@@ -16,8 +16,11 @@ if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_CUDA OR LITE_WITH_X86 OR ARM_TARGE
     add_dependencies(paddle_full_api_shared op_list_h kernel_list_h framework_proto)
     target_link_libraries(paddle_full_api_shared framework_proto)
     if(LITE_WITH_X86)
         add_dependencies(paddle_full_api_shared xxhash)
         target_link_libraries(paddle_full_api_shared xxhash)
+        if (NOT LITE_ON_MODEL_OPTIMIZE_TOOL)
+            add_dependencies(paddle_full_api_shared dynload_mklml)
+        endif()
     endif()
     if(LITE_WITH_CUDA)
         target_link_libraries(paddle_full_api_shared ${math_cuda} "-Wl,--whole-archive" ${cuda_kernels} "-Wl,--no-whole-archive")
...
@@ -20,6 +20,12 @@
 #include "lite/core/device_info.h"
 #include "lite/core/version.h"
+
+#if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \
+    !(defined LITE_ON_MODEL_OPTIMIZE_TOOL)
+#include <omp.h>
+#include "lite/backends/x86/mklml.h"
+#endif

 namespace paddle {
 namespace lite {
@@ -33,6 +39,17 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
   mode_ = config.power_mode();
   threads_ = config.threads();
+#if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \
+    !(defined LITE_ON_MODEL_OPTIMIZE_TOOL)
+  int num_threads = config.cpu_math_library_num_threads();
+  int real_num_threads = num_threads > 1 ? num_threads : 1;
+  paddle::lite::x86::MKL_Set_Num_Threads(real_num_threads);
+  omp_set_num_threads(real_num_threads);
+  VLOG(3) << "set_cpu_math_library_math_threads() is set successfully and the "
+             "number of threads is:"
+          << num_threads;
+#endif
 }

 std::unique_ptr<lite_api::Tensor> CxxPaddleApiImpl::GetInput(int i) {
...
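The added Init() logic above forwards the configured count to both MKL and OpenMP, clamping it to at least one thread. A standalone, hedged way to observe what such a call does, using plain OpenMP with no Paddle-Lite involved (compile with -fopenmp; the requested count of 10 is illustrative):

#include <omp.h>

#include <cstdio>

int main() {
  int requested = 10;
  int real = requested > 1 ? requested : 1;  // same clamp as the diff
  omp_set_num_threads(real);
  // omp_get_max_threads() reports the team size the next parallel region
  // will use, which is exactly what omp_set_num_threads just fixed.
  std::printf("max threads: %d\n", omp_get_max_threads());
  return 0;
}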
@@ -133,6 +133,7 @@ class LITE_API CxxConfig : public ConfigBase {
   std::string model_file_;
   std::string param_file_;
   bool model_from_memory_{false};
+  int cpu_math_library_math_threads_ = 1;

  public:
   void set_valid_places(const std::vector<Place>& x) { valid_places_ = x; }
@@ -151,6 +152,13 @@ class LITE_API CxxConfig : public ConfigBase {
   std::string model_file() const { return model_file_; }
   std::string param_file() const { return param_file_; }
   bool model_from_memory() const { return model_from_memory_; }
+
+  void set_cpu_math_library_math_threads(int threads) {
+    cpu_math_library_math_threads_ = threads;
+  }
+  int cpu_math_library_num_threads() const {
+    return cpu_math_library_math_threads_;
+  }
 };

 /// MobileConfig is the config for the light weight predictor, it will skip
...
@@ -30,6 +30,7 @@ TEST(Step_rnn, test_step_rnn_lite_x86) {
   std::string model_dir = FLAGS_model_dir;
   lite_api::CxxConfig config;
   config.set_model_dir(model_dir);
+  config.set_cpu_math_library_math_threads(10);
   config.set_valid_places({lite_api::Place{TARGET(kX86), PRECISION(kInt64)},
                            lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
                            lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
...