提交 6c01485f 编写于 作者: L liu zhengxi 提交者: GitHub

[X86] Alter the api name to set_x86_math_library_math_threads (#2720)

* alter the api name from cpu to x86, test=develop

* correct the step_rnn model test, test=develop
上级 2134b7ab
...@@ -42,11 +42,11 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) { ...@@ -42,11 +42,11 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
#if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \ #if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \
!(defined LITE_ON_MODEL_OPTIMIZE_TOOL) !(defined LITE_ON_MODEL_OPTIMIZE_TOOL)
int num_threads = config.cpu_math_library_num_threads(); int num_threads = config.x86_math_library_num_threads();
int real_num_threads = num_threads > 1 ? num_threads : 1; int real_num_threads = num_threads > 1 ? num_threads : 1;
paddle::lite::x86::MKL_Set_Num_Threads(real_num_threads); paddle::lite::x86::MKL_Set_Num_Threads(real_num_threads);
omp_set_num_threads(real_num_threads); omp_set_num_threads(real_num_threads);
VLOG(3) << "set_cpu_math_library_math_threads() is set successfully and the " VLOG(3) << "set_x86_math_library_math_threads() is set successfully and the "
"number of threads is:" "number of threads is:"
<< num_threads; << num_threads;
#endif #endif
......
...@@ -133,7 +133,9 @@ class LITE_API CxxConfig : public ConfigBase { ...@@ -133,7 +133,9 @@ class LITE_API CxxConfig : public ConfigBase {
std::string model_file_; std::string model_file_;
std::string param_file_; std::string param_file_;
bool model_from_memory_{false}; bool model_from_memory_{false};
int cpu_math_library_math_threads_ = 1; #ifdef LITE_WITH_X86
int x86_math_library_math_threads_ = 1;
#endif
public: public:
void set_valid_places(const std::vector<Place>& x) { valid_places_ = x; } void set_valid_places(const std::vector<Place>& x) { valid_places_ = x; }
...@@ -153,12 +155,14 @@ class LITE_API CxxConfig : public ConfigBase { ...@@ -153,12 +155,14 @@ class LITE_API CxxConfig : public ConfigBase {
std::string param_file() const { return param_file_; } std::string param_file() const { return param_file_; }
bool model_from_memory() const { return model_from_memory_; } bool model_from_memory() const { return model_from_memory_; }
void set_cpu_math_library_num_threads(int threads) { #ifdef LITE_WITH_X86
cpu_math_library_math_threads_ = threads; void set_x86_math_library_num_threads(int threads) {
x86_math_library_math_threads_ = threads;
} }
int cpu_math_library_num_threads() const { int x86_math_library_num_threads() const {
return cpu_math_library_math_threads_; return x86_math_library_math_threads_;
} }
#endif
}; };
/// MobileConfig is the config for the light weight predictor, it will skip /// MobileConfig is the config for the light weight predictor, it will skip
......
...@@ -30,7 +30,9 @@ TEST(Step_rnn, test_step_rnn_lite_x86) { ...@@ -30,7 +30,9 @@ TEST(Step_rnn, test_step_rnn_lite_x86) {
std::string model_dir = FLAGS_model_dir; std::string model_dir = FLAGS_model_dir;
lite_api::CxxConfig config; lite_api::CxxConfig config;
config.set_model_dir(model_dir); config.set_model_dir(model_dir);
config.set_cpu_math_library_num_threads(1); #ifdef LITE_WITH_X86
config.set_x86_math_library_num_threads(1);
#endif
config.set_valid_places({lite_api::Place{TARGET(kX86), PRECISION(kInt64)}, config.set_valid_places({lite_api::Place{TARGET(kX86), PRECISION(kInt64)},
lite_api::Place{TARGET(kX86), PRECISION(kFloat)}, lite_api::Place{TARGET(kX86), PRECISION(kFloat)},
lite_api::Place{TARGET(kHost), PRECISION(kFloat)}}); lite_api::Place{TARGET(kHost), PRECISION(kFloat)}});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册