提交 116979a4 编写于 作者: L luotao1

refine api name

test=develop
上级 e66b4c6b
...@@ -46,7 +46,7 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) { ...@@ -46,7 +46,7 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) {
prog_file = other.prog_file; prog_file = other.prog_file;
param_file = other.param_file; param_file = other.param_file;
specify_input_name = other.specify_input_name; specify_input_name = other.specify_input_name;
cpu_num_threads_ = other.cpu_num_threads_; cpu_math_library_num_threads_ = other.cpu_math_library_num_threads_;
// fields from this. // fields from this.
enable_ir_optim = other.enable_ir_optim; enable_ir_optim = other.enable_ir_optim;
use_feed_fetch_ops = other.use_feed_fetch_ops; use_feed_fetch_ops = other.use_feed_fetch_ops;
...@@ -73,6 +73,7 @@ contrib::AnalysisConfig::AnalysisConfig(contrib::AnalysisConfig &&other) { ...@@ -73,6 +73,7 @@ contrib::AnalysisConfig::AnalysisConfig(contrib::AnalysisConfig &&other) {
prog_file = other.prog_file; prog_file = other.prog_file;
param_file = other.param_file; param_file = other.param_file;
specify_input_name = other.specify_input_name; specify_input_name = other.specify_input_name;
cpu_math_library_num_threads_ = other.cpu_math_library_num_threads_;
// fields from this. // fields from this.
enable_ir_optim = other.enable_ir_optim; enable_ir_optim = other.enable_ir_optim;
use_feed_fetch_ops = other.use_feed_fetch_ops; use_feed_fetch_ops = other.use_feed_fetch_ops;
......
...@@ -66,7 +66,7 @@ bool AnalysisPredictor::Init( ...@@ -66,7 +66,7 @@ bool AnalysisPredictor::Init(
#endif #endif
// no matter with or without MKLDNN // no matter with or without MKLDNN
paddle::platform::SetNumThreads(config_.GetCPUNumThreads()); paddle::platform::SetNumThreads(config_.cpu_math_library_num_threads());
if (!PrepareScope(parent_scope)) { if (!PrepareScope(parent_scope)) {
return false; return false;
...@@ -159,7 +159,7 @@ bool AnalysisPredictor::PrepareExecutor() { ...@@ -159,7 +159,7 @@ bool AnalysisPredictor::PrepareExecutor() {
return true; return true;
} }
void AnalysisPredictor::SetMKLDNNThreadId(int tid) { void AnalysisPredictor::SetMkldnnThreadID(int tid) {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
platform::set_cur_thread_id(tid); platform::set_cur_thread_id(tid);
#else #else
......
...@@ -69,7 +69,7 @@ class AnalysisPredictor : public PaddlePredictor { ...@@ -69,7 +69,7 @@ class AnalysisPredictor : public PaddlePredictor {
framework::Scope *scope() { return scope_.get(); } framework::Scope *scope() { return scope_.get(); }
framework::ProgramDesc &program() { return *inference_program_; } framework::ProgramDesc &program() { return *inference_program_; }
void SetMKLDNNThreadId(int tid); void SetMkldnnThreadID(int tid);
protected: protected:
bool PrepareProgram(const std::shared_ptr<framework::ProgramDesc> &program); bool PrepareProgram(const std::shared_ptr<framework::ProgramDesc> &program);
......
...@@ -75,7 +75,7 @@ bool NativePaddlePredictor::Init( ...@@ -75,7 +75,7 @@ bool NativePaddlePredictor::Init(
#endif #endif
// no matter with or without MKLDNN // no matter with or without MKLDNN
paddle::platform::SetNumThreads(config_.GetCPUNumThreads()); paddle::platform::SetNumThreads(config_.cpu_math_library_num_threads());
if (config_.use_gpu) { if (config_.use_gpu) {
place_ = paddle::platform::CUDAPlace(config_.device); place_ = paddle::platform::CUDAPlace(config_.device);
......
...@@ -187,14 +187,18 @@ struct NativeConfig : public PaddlePredictor::Config { ...@@ -187,14 +187,18 @@ struct NativeConfig : public PaddlePredictor::Config {
// `feeds` and `fetches` of the phase `save_inference_model`. // `feeds` and `fetches` of the phase `save_inference_model`.
bool specify_input_name{false}; bool specify_input_name{false};
// Set and get the number of cpu threads. // Set and get the number of cpu math library threads.
void SetCPUNumThreads(int cpu_num_threads) { void SetCpuMathLibraryNumThreads(int cpu_math_library_num_threads) {
cpu_num_threads_ = cpu_num_threads; cpu_math_library_num_threads_ = cpu_math_library_num_threads;
}
int cpu_math_library_num_threads() const {
return cpu_math_library_num_threads_;
} }
int GetCPUNumThreads() const { return cpu_num_threads_; }
protected: protected:
int cpu_num_threads_{1}; // number of cpu threads for each instance. // number of cpu math library (such as MKL, OpenBlas) threads for each
// instance.
int cpu_math_library_num_threads_{1};
}; };
// A factory to help create different predictors. // A factory to help create different predictors.
......
...@@ -27,7 +27,7 @@ void SetConfig(AnalysisConfig *cfg) { ...@@ -27,7 +27,7 @@ void SetConfig(AnalysisConfig *cfg) {
cfg->device = 0; cfg->device = 0;
cfg->enable_ir_optim = true; cfg->enable_ir_optim = true;
cfg->specify_input_name = true; cfg->specify_input_name = true;
cfg->SetCPUNumThreads(FLAGS_paddle_num_threads); cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
} }
void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) { void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
......
...@@ -54,7 +54,7 @@ std::ostream &operator<<(std::ostream &os, const NativeConfig &config) { ...@@ -54,7 +54,7 @@ std::ostream &operator<<(std::ostream &os, const NativeConfig &config) {
os << GenSpaces(num_spaces) os << GenSpaces(num_spaces)
<< "specify_input_name: " << config.specify_input_name << "\n"; << "specify_input_name: " << config.specify_input_name << "\n";
os << GenSpaces(num_spaces) os << GenSpaces(num_spaces)
<< "cpu_num_threads: " << config.GetCPUNumThreads() << "\n"; << "cpu_num_threads: " << config.cpu_math_library_num_threads() << "\n";
num_spaces--; num_spaces--;
os << GenSpaces(num_spaces) << "}\n"; os << GenSpaces(num_spaces) << "}\n";
return os; return os;
......
...@@ -221,7 +221,7 @@ void TestMultiThreadPrediction( ...@@ -221,7 +221,7 @@ void TestMultiThreadPrediction(
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (use_analysis) { if (use_analysis) {
static_cast<AnalysisPredictor *>(predictor.get()) static_cast<AnalysisPredictor *>(predictor.get())
->SetMKLDNNThreadId(static_cast<int>(tid) + 1); ->SetMkldnnThreadID(static_cast<int>(tid) + 1);
} }
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册