提交 01f0f168 编写于 作者: T tensor-tang

enable mkldnn in infer api

上级 33c21291
......@@ -77,6 +77,9 @@ bool AnalysisPredictor::Init(
OptimizeInferenceProgram();
ctx_ = executor_->Prepare(*inference_program_, 0);
if (config_.use_mkldnn) {
executor_->EnableMKLDNN(*inference_program_);
}
VLOG(5) << "to create variables";
PADDLE_ENFORCE(scope_.get());
......
......@@ -106,6 +106,9 @@ bool NativePaddlePredictor::Init(
}
ctx_ = executor_->Prepare(*inference_program_, 0);
if (config_.use_mkldnn) {
executor_->EnableMKLDNN(*inference_program_);
}
executor_->CreateVariables(*inference_program_,
sub_scope_ ? sub_scope_ : scope_.get(), 0);
......
......@@ -45,7 +45,7 @@ class PaddleBuf {
PaddleBuf(void* data, size_t length)
: data_(data), length_(length), memory_owned_{false} {}
// Own memory.
PaddleBuf(size_t length)
explicit PaddleBuf(size_t length)
: data_(new char[length]), length_(length), memory_owned_(true) {}
// Resize to `length` bytes.
void Resize(size_t length);
......@@ -121,6 +121,8 @@ struct NativeConfig : public PaddlePredictor::Config {
bool use_gpu{false};
int device{0};
float fraction_of_gpu_memory{-1.f}; // Negative to notify initialization.
// MKLDNN related fields.
bool use_mkldnn{false};
// Specify the variable's name of each input.
bool specify_input_name{false};
......
......@@ -66,12 +66,13 @@ Record ProcessALine(const std::string &line) {
* Use the native and analysis fluid engine to inference the demo.
* ocr, mobilenet and se_resnext50
*/
void TestVisualPrediction() {
void TestVisualPrediction(bool use_mkldnn) {
std::unique_ptr<PaddlePredictor> predictor;
AnalysisConfig cfg;
cfg.param_file = FLAGS_infer_model + "/__params__";
cfg.prog_file = FLAGS_infer_model + "/__model__";
cfg.use_gpu = false;
cfg.use_mkldnn = use_mkldnn;
cfg.device = 0;
cfg.enable_ir_optim = true;
cfg.ir_passes.push_back("fc_gru_fuse_pass");
......@@ -163,7 +164,10 @@ void TestVisualPrediction() {
}
}
TEST(Analyzer_vis, analysis) { TestVisualPrediction(); }
TEST(Analyzer_vis, analysis) { TestVisualPrediction(/*use_mkldnn*/ false); }
TEST(Analyzer_vis, analysis_mkldnn) {
TestVisualPrediction(/*use_mkldnn*/ true);
}
} // namespace analysis
} // namespace inference
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册