diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 33ec9b4f47a9073cdf0a8fc2b4c17cd37561dcb6..43f6329083c16dd605ebc74a4aa31549516c2a8d 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -170,8 +170,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { void AppendOpFusePasses() { // 1. infernce pass if enabled. - AppendPassWithCheck(strategy_.inference_ && strategy_.del_dropout_, - "delete_dropout_op_x_pass"); + AppendPassWithCheck( + strategy_.enable_inference_pass_ && strategy_.delete_dropout_, + "delete_dropout_op_x_pass"); + AppendPassWithCheck( + strategy_.enable_inference_pass_ && strategy_.use_mkldnn_, + "mkldnn_placement_pass"); // 2. trainning pass AppendPassWithCheck(strategy_.fuse_relu_depthwise_conv_, diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 0ef89ae1eccc8afee30cf4003d3f9ee34024c966..513df4f19742d2cb98724c65d6276089866330c4 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -148,8 +148,13 @@ struct BuildStrategy { bool allow_cuda_graph_capture_{false}; // Inference pass - bool inference_{false}; // switch for infernce pass - bool del_dropout_{false}; + bool enable_inference_pass_{false}; // switch for infernce pass + bool delete_dropout_{true}; // delte dropout op +#ifdef PADDLE_WITH_MKLDNN + bool use_mkldnn_{true}; // use mkdnn to do inference +#else + bool use_mkldnn_{false}; // use mkdnn to do inference +#endif // FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode, // num_trainers is 1, so the current fields of build_strategy doesn't tell if diff --git a/paddle/fluid/jit/engine/pe_engine.cc b/paddle/fluid/jit/engine/pe_engine.cc index 35d7f87df74f6f29979228bf9461e5dccdbee039..2d35a8792ef70427fe8b109a5db964124a09c57a 100644 --- a/paddle/fluid/jit/engine/pe_engine.cc +++ b/paddle/fluid/jit/engine/pe_engine.cc @@ -74,8 +74,7 @@ PEEngine::PEEngine(const std::shared_ptr &info, void PEEngine::CreateGraphAndPE() { framework::details::BuildStrategy build_strategy; - build_strategy.inference_ = true; - build_strategy.del_dropout_ = true; + build_strategy.enable_inference_pass_ = true; // use pe to inference auto execution_strategy = GetExecutionStrategy(place_); auto &program_desc = info_->ProgramDesc();