diff --git a/paddle/fluid/framework/ir/mkldnn_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn_placement_pass.cc
index 1cf1315d3d3059261d84d0e8795a75df4deae005..9a9314161b0e8d14a525d253572915ed597c716e 100644
--- a/paddle/fluid/framework/ir/mkldnn_placement_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn_placement_pass.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/ir/mkldnn_placement_pass.h"
+#include <algorithm>
 
 namespace paddle {
 namespace framework {
@@ -21,9 +22,16 @@ namespace ir {
 std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
     std::unique_ptr<ir::Graph> graph) const {
   VLOG(3) << "Applies MKL-DNN placement strategy.";
+  const auto& op_types_list =
+      Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
   for (const Node* n : graph->Nodes()) {
     if (n->IsOp() && n->RuntimeHasAttr("use_mkldnn")) {
-      n->Op()->SetAttr("use_mkldnn", true);
+      if (op_types_list.empty()) {
+        n->Op()->SetAttr("use_mkldnn", true);
+      } else if (std::find(op_types_list.begin(), op_types_list.end(),
+                           n->Name()) != op_types_list.end()) {
+        n->Op()->SetAttr("use_mkldnn", true);
+      }
     }
   }
   return graph;
@@ -33,5 +41,5 @@ std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
 }  // namespace framework
 }  // namespace paddle
 
-REGISTER_PASS(mkldnn_placement_pass,
-              paddle::framework::ir::MKLDNNPlacementPass);
+REGISTER_PASS(mkldnn_placement_pass, paddle::framework::ir::MKLDNNPlacementPass)
+    .RequirePassAttr("mkldnn_enabled_op_types");
diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index 53cc7039f20aa83bd043f71af4afc14b10803552..83d411eecf6d706615243fd78cb7e4330d904fc1 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -116,6 +116,10 @@ struct Argument {
   DECL_ARGUMENT_FIELD(ir_analysis_passes, IrAnalysisPasses,
                       std::vector<std::string>);
 
+  // Pass a set of op types to enable their MKL-DNN kernels
+  DECL_ARGUMENT_FIELD(mkldnn_enabled_op_types, MKLDNNEnabledOpTypes,
+                      std::unordered_set<std::string>);
+
   DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
   DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
   DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index fce5e1cac92064a320179243380ea02b2c5d7838..51bca8039d4531536cd7a3c39ef8a27f1a5412a1 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -63,6 +63,11 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("graph_viz_path", new std::string(std::move(dot_file_path)));
       pass_num++;
     }
+    if (pass_name == "mkldnn_placement_pass") {
+      pass->Set("mkldnn_enabled_op_types",
+                new std::unordered_set<std::string>(
+                    argument->mkldnn_enabled_op_types()));
+    }
 
     if (pass_name == "tensorrt_subgraph_pass") {
       PADDLE_ENFORCE(argument->tensorrt_node_teller_valid());
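For clarity, the selection rule the pass now implements can be exercised in isolation. The sketch below is not the real pass: `FakeOp` and `ApplyPlacement` are hypothetical stand-ins for `ir::Node` and `MKLDNNPlacementPass::ApplyImpl`, but the branch logic mirrors the hunk above — an empty whitelist enables `use_mkldnn` on every eligible op, while a non-empty whitelist enables it only for the listed op types.

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for ir::Node: an op type name plus the flag the
// real pass would set via n->Op()->SetAttr("use_mkldnn", true).
struct FakeOp {
  std::string name;
  bool use_mkldnn;
};

// Mirrors the branch added to MKLDNNPlacementPass::ApplyImpl above.
void ApplyPlacement(std::vector<FakeOp>* ops,
                    const std::unordered_set<std::string>& op_types_list) {
  for (FakeOp& op : *ops) {
    if (op_types_list.empty()) {
      // No whitelist configured: keep the old behavior and enable
      // MKL-DNN for every op that supports it.
      op.use_mkldnn = true;
    } else if (std::find(op_types_list.begin(), op_types_list.end(),
                         op.name) != op_types_list.end()) {
      // Whitelist configured: enable MKL-DNN only for listed op types.
      op.use_mkldnn = true;
    }
  }
}

int main() {
  std::vector<FakeOp> ops = {
      {"conv3d", false}, {"relu", false}, {"pool2d", false}};
  ApplyPlacement(&ops, {"conv3d"});
  for (const FakeOp& op : ops) {
    std::cout << op.name << " use_mkldnn=" << op.use_mkldnn << "\n";
  }
  // Prints: conv3d use_mkldnn=1, relu use_mkldnn=0, pool2d use_mkldnn=0
  return 0;
}
```

Since `op_types_list` is a `std::unordered_set`, `op_types_list.count(n->Name())` would be a constant-time equivalent of the linear `std::find` call the pass uses.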
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 384d1dc27d6355ae682ac19dc5a6dfceb3cbe9ff..dcefdd92f5157dce7426f2f3e4a2bc053ce24775 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -49,6 +49,10 @@ contrib::AnalysisConfig::AnalysisConfig(const contrib::AnalysisConfig &other) {
   cpu_math_library_num_threads_ = other.cpu_math_library_num_threads_;
   // fields from this.
   enable_ir_optim = other.enable_ir_optim;
+  // For mkldnn
+  use_mkldnn_ = other.use_mkldnn_;
+  mkldnn_enabled_op_types_ = other.mkldnn_enabled_op_types_;
+
   use_feed_fetch_ops = other.use_feed_fetch_ops;
   use_tensorrt_ = other.use_tensorrt_;
   tensorrt_max_batchsize_ = other.tensorrt_max_batchsize_;
@@ -77,6 +81,10 @@ contrib::AnalysisConfig::AnalysisConfig(contrib::AnalysisConfig &&other) {
   cpu_math_library_num_threads_ = other.cpu_math_library_num_threads_;
   // fields from this.
   enable_ir_optim = other.enable_ir_optim;
+  // For mkldnn
+  use_mkldnn_ = other.use_mkldnn_;
+  mkldnn_enabled_op_types_ = other.mkldnn_enabled_op_types_;
+
   use_feed_fetch_ops = other.use_feed_fetch_ops;
   use_tensorrt_ = other.use_tensorrt_;
   tensorrt_max_batchsize_ = other.tensorrt_max_batchsize_;
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 84f7eca05703b9a2da51edcc49e2deda5fd74273..be51e7fc1f01c5fc4a48c7f32db15bb82a5ddc07 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -327,6 +327,10 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
     argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
   }
 
+  if (config_.use_mkldnn_) {
+    argument_.SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_);
+  }
+
   auto passes = config_.pass_builder()->AllPasses();
   if (!config_.enable_ir_optim) passes.clear();
   argument_.SetIrAnalysisPasses(passes);
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index a08e3d027e01d55ff4c433d6d36dc10b38a132a9..f05b9832da55f10b34eb2df914e443a478e5a4a4 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -16,6 +16,7 @@
 #include <cassert>
 #include <memory>
 #include <string>
+#include <unordered_set>
 #include <vector>
 
 // Here we include some header files with relative paths, for that in deploy,
@@ -53,6 +54,9 @@ struct AnalysisConfig : public NativeConfig {
 
   void EnableMKLDNN();
   bool use_mkldnn() const { return use_mkldnn_; }
+  void SetMKLDNNOp(std::unordered_set<std::string> op_list) {
+    mkldnn_enabled_op_types_ = op_list;
+  }
 
   // Specify the memory buffer of program and parameter
   void SetModelBuffer(const char* prog_buffer, size_t prog_buffer_size,
@@ -64,6 +68,7 @@ struct AnalysisConfig : public NativeConfig {
  protected:
   bool use_tensorrt_{false};
   bool use_mkldnn_{false};
+  std::unordered_set<std::string> mkldnn_enabled_op_types_;
   int tensorrt_workspace_size_;
   int tensorrt_max_batchsize_;
   std::unique_ptr<PassStrategy> pass_builder_;
diff --git a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
index e8abcfce05f0b527edf6b92f12beffd9a4c723c3..227e2ff45873fded45899146b97a7bee0c8ad763 100644
--- a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
@@ -194,6 +194,8 @@ void profile(bool use_mkldnn = false) {
 
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
+    std::unordered_set<std::string> op_list = {"conv3d"};
+    cfg.SetMKLDNNOp(op_list);
   }
 
   std::vector<PaddleTensor> outputs;
@@ -236,6 +238,8 @@ void compare(bool use_mkldnn = false) {
   SetConfig(&cfg);
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
+    std::unordered_set<std::string> op_list = {"conv3d"};
+    cfg.SetMKLDNNOp(op_list);
   }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
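End to end, the new knob is driven from the public C++ inference API. Below is a minimal usage sketch, not part of this patch: the `./model` path and the chosen op types are hypothetical, and it assumes the standard `paddle::CreatePaddlePredictor` entry point with a directory-based model inherited from `NativeConfig`.

```cpp
#include <memory>
#include <unordered_set>

#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::contrib::AnalysisConfig cfg;
  cfg.model_dir = "./model";  // hypothetical path; field from NativeConfig
  cfg.EnableMKLDNN();
  // Restrict MKL-DNN placement to these op types. Leaving the set empty
  // (the default) keeps the old behavior: use_mkldnn is set on every op
  // that supports it.
  cfg.SetMKLDNNOp({"conv2d", "pool2d"});

  auto predictor =
      paddle::CreatePaddlePredictor<paddle::contrib::AnalysisConfig>(cfg);
  // ... prepare PaddleTensor inputs and call predictor->Run(...) ...
  return 0;
}
```

Because `mkldnn_enabled_op_types_` is copied into the pass as a pass attribute (see `IRPassManager::CreatePasses` above), the restriction applies during IR analysis and needs no per-run plumbing.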