提交 01f0f168 编写于 作者: T tensor-tang

enable mkldnn in infer api

上级 33c21291
@@ -77,6 +77,9 @@ bool AnalysisPredictor::Init(
  OptimizeInferenceProgram();
  ctx_ = executor_->Prepare(*inference_program_, 0);
if (config_.use_mkldnn) {
executor_->EnableMKLDNN(*inference_program_);
}
  VLOG(5) << "to create variables";
  PADDLE_ENFORCE(scope_.get());
......
@@ -106,6 +106,9 @@ bool NativePaddlePredictor::Init(
  }
  ctx_ = executor_->Prepare(*inference_program_, 0);
if (config_.use_mkldnn) {
executor_->EnableMKLDNN(*inference_program_);
}
  executor_->CreateVariables(*inference_program_,
                             sub_scope_ ? sub_scope_ : scope_.get(), 0);
......
@@ -45,7 +45,7 @@ class PaddleBuf {
  PaddleBuf(void* data, size_t length)
      : data_(data), length_(length), memory_owned_{false} {}
  // Own memory.
  explicit PaddleBuf(size_t length)
      : data_(new char[length]), length_(length), memory_owned_(true) {}
  // Resize to `length` bytes.
  void Resize(size_t length);
@@ -121,6 +121,8 @@ struct NativeConfig : public PaddlePredictor::Config {
  bool use_gpu{false};
  int device{0};
  float fraction_of_gpu_memory{-1.f};  // Negative to notify initialization.
// MKLDNN related fields.
bool use_mkldnn{false};
  // Specify the variable's name of each input.
  bool specify_input_name{false};
......
@@ -66,12 +66,13 @@ Record ProcessALine(const std::string &line) {
 * Use the native and analysis fluid engine to inference the demo.
 * ocr, mobilenet and se_resnext50
 */
void TestVisualPrediction(bool use_mkldnn) {
  std::unique_ptr<PaddlePredictor> predictor;
  AnalysisConfig cfg;
  cfg.param_file = FLAGS_infer_model + "/__params__";
  cfg.prog_file = FLAGS_infer_model + "/__model__";
  cfg.use_gpu = false;
cfg.use_mkldnn = use_mkldnn;
  cfg.device = 0;
  cfg.enable_ir_optim = true;
  cfg.ir_passes.push_back("fc_gru_fuse_pass");
@@ -163,7 +164,10 @@ void TestVisualPrediction(bool use_mkldnn) {
  }
}
TEST(Analyzer_vis, analysis) { TestVisualPrediction(/*use_mkldnn*/ false); }
TEST(Analyzer_vis, analysis_mkldnn) {
TestVisualPrediction(/*use_mkldnn*/ true);
}
}  // namespace analysis
}  // namespace inference
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册