From 5632019f0f9160423f67104e8f333f8f1a05f238 Mon Sep 17 00:00:00 2001 From: Wojciech Uss Date: Wed, 17 Oct 2018 16:49:08 +0200 Subject: [PATCH] add MKL-DNN placement pass This patch also refactors conv+bn (includes changes from PR https://github.com/PaddlePaddle/Paddle/pull/13926) updated to use the mkldnn-placement-pass. test=develop --- paddle/fluid/inference/api/analysis_predictor.cc | 11 +++++++---- paddle/fluid/inference/api/paddle_inference_api.h | 4 +++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index f1a4a4df506..531d4110dc2 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -226,18 +226,21 @@ void AnalysisPredictor::OptimizeInferenceProgram() { argument_.origin_program_desc.reset( new ProgramDesc(*inference_program_->Proto())); + bool use_mkldnn = config_._use_mkldnn; switch (config_.ir_mode) { case contrib::AnalysisConfig::IrPassMode::kExclude: Analyzer() .IncludeAllIrPasses() - .SetUseMkldnn(config_._use_mkldnn) - .DisableIrPasses(config_.ir_passes) + .SetUseMkldnn(use_mkldnn) + .DisableIrPasses(use_mkldnn ? config_.ir_mkldnn_passes + : config_.ir_passes) .Run(&argument_); break; case contrib::AnalysisConfig::IrPassMode::kInclude: Analyzer() - .SetUseMkldnn(config_._use_mkldnn) - .IncludeIrPasses(config_.ir_passes) + .SetUseMkldnn(use_mkldnn) + .IncludeIrPasses(use_mkldnn ? config_.ir_mkldnn_passes + : config_.ir_passes) .Run(&argument_); break; default: diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h index 07ee6e72d10..3416371fdbe 100644 --- a/paddle/fluid/inference/api/paddle_inference_api.h +++ b/paddle/fluid/inference/api/paddle_inference_api.h @@ -261,8 +261,8 @@ struct AnalysisConfig : public NativeConfig { void SetIncludeMode() { ir_mode = IrPassMode::kInclude; - // this pass has to be run at the beginning of all fuse passes ir_passes = {"infer_clean_graph_pass"}; + ir_mkldnn_passes = {"infer_clean_graph_pass"}; } // Determine whether to perform graph optimization. @@ -271,6 +271,8 @@ struct AnalysisConfig : public NativeConfig { IrPassMode ir_mode{IrPassMode::kExclude}; // passes to be excluded/included std::vector ir_passes{"embedding_fc_lstm_fuse_pass"}; + // passes to be excluded/included when MKL-DNN is enabled + std::vector ir_mkldnn_passes{"embedding_fc_lstm_fuse_pass"}; // NOT stable yet. bool use_feed_fetch_ops{true}; -- GitLab