From c3b70aece93d61759d5266e9f0112d0804fdf057 Mon Sep 17 00:00:00 2001 From: Wojciech Uss Date: Fri, 19 Oct 2018 05:09:09 +0200 Subject: [PATCH] Add MKL-DNN placement pass (#13958) * add MKL-DNN placement pass This patch also refactors conv+bn (includes changes from PR https://github.com/PaddlePaddle/Paddle/pull/13926) updated to use the mkldnn-placement-pass. test=develop * remove redundant pass list * add comment on the default first pass * fix test for conv+relu mkldnn fuse --- paddle/fluid/framework/ir/CMakeLists.txt | 10 ++- .../fluid/framework/ir/conv_bn_fuse_pass.cc | 86 ++++++++++++++----- .../ir/conv_relu_mkldnn_fuse_pass.cc | 6 ++ .../ir/conv_relu_mkldnn_fuse_pass_tester.cc | 47 +++++++--- paddle/fluid/framework/ir/fuse_pass_base.cc | 62 +++++++++++++ paddle/fluid/framework/ir/fuse_pass_base.h | 32 +++---- .../framework/ir/mkldnn_placement_pass.cc | 37 ++++++++ .../framework/ir/mkldnn_placement_pass.h | 31 +++++++ paddle/fluid/inference/analysis/analyzer.cc | 21 ++++- paddle/fluid/inference/analysis/analyzer.h | 6 ++ .../fluid/inference/api/analysis_predictor.cc | 22 ++++- .../inference/api/paddle_inference_api.h | 7 ++ 12 files changed, 301 insertions(+), 66 deletions(-) create mode 100644 paddle/fluid/framework/ir/fuse_pass_base.cc create mode 100644 paddle/fluid/framework/ir/mkldnn_placement_pass.cc create mode 100644 paddle/fluid/framework/ir/mkldnn_placement_pass.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 796ce1f91c..abab290e7d 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -10,7 +10,7 @@ function(pass_library TARGET DEST) set(oneValueArgs "") set(multiValueArgs SRCS DEPS) cmake_parse_arguments(op_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass ${op_library_DEPS}) + cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${op_library_DEPS}) # add more DEST here, such as train, dist and collect USE_PASS into a file automatically. if (${DEST} STREQUAL "base" OR ${DEST} STREQUAL "inference") message(STATUS "add pass ${TARGET} ${DEST}") @@ -25,13 +25,11 @@ cc_library(graph_helper SRCS graph_helper.cc DEPS graph) cc_library(pass SRCS pass.cc DEPS graph node graph_helper) cc_library(graph_traits SRCS graph_traits.cc DEPS graph) cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits) +cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass) pass_library(graph_to_program_pass base) pass_library(graph_viz_pass base) pass_library(fc_fuse_pass inference) -if (WITH_MKLDNN) - pass_library(conv_relu_mkldnn_fuse_pass inference) -endif () pass_library(attention_lstm_fuse_pass inference) pass_library(infer_clean_graph_pass inference) pass_library(fc_lstm_fuse_pass inference) @@ -39,6 +37,10 @@ pass_library(embedding_fc_lstm_fuse_pass inference) pass_library(fc_gru_fuse_pass inference) pass_library(seq_concat_fc_fuse_pass inference) pass_library(conv_bn_fuse_pass inference) +if(WITH_MKLDNN) + pass_library(mkldnn_placement_pass base) + pass_library(conv_relu_mkldnn_fuse_pass inference) +endif() cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector ) diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc index 04459612a7..846a14e365 100644 --- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc @@ -126,12 +126,21 @@ std::unique_ptr ConvBNFusePass::ApplyImpl( // conv, batch_norm, // conv_weight, conv_out, // bn_scale, bn_bias, bn_mean, bn_variance, - // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance + // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean, + // bn_saved_variance GET_CONV_BN_NODES(conv_bn_pattern); + // check if fuse can be done and if MKL-DNN should be used + FuseOptions fuse_option = FindFuseOption(*conv, *batch_norm); + if (fuse_option == DO_NOT_FUSE) { + VLOG(3) << "do not perform conv+bn fuse"; + return; + } + // Create eltwise_y (conv bias) variable VarDesc eltwise_y_in_desc( patterns::PDNodeName(name_scope_, "eltwise_y_in")); + eltwise_y_in_desc.SetPersistable(true); auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc); auto* eltwise_y_in_tensor = scope->Var(eltwise_y_in_node->Name())->GetMutable(); @@ -151,27 +160,59 @@ std::unique_ptr ConvBNFusePass::ApplyImpl( *bn_mean, *bn_variance, eltwise_y_in_tensor, epsilon); - // Create an elementwise add node - OpDesc desc; - desc.SetInput("X", std::vector({conv_out->Name()})); - desc.SetInput("Y", std::vector({eltwise_y_in_node->Name()})); - desc.SetOutput("Out", std::vector({bn_out->Name()})); - desc.SetType("elementwise_add"); - desc.SetAttr("axis", 1); - bool a = boost::get(conv->Op()->GetAttr("use_mkldnn")); - desc.SetAttr("use_mkldnn", a); - auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied. - - GraphSafeRemoveNodes(graph.get(), {bn_scale, bn_bias, bn_mean, bn_variance, - batch_norm, bn_mean_out, bn_variance_out, - bn_saved_mean, bn_saved_variance}); - - PADDLE_ENFORCE(subgraph.count(conv_input)); - IR_NODE_LINK_TO(conv_out, eltwise_op); - IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op); - IR_NODE_LINK_TO(eltwise_op, bn_out); - - found_conv_bn_count++; + // with MKL-DNN fuse conv+bn into conv with bias + // without MKL-DNN fuse conv+bn into conv+elementwise_add + if (fuse_option == FUSE_MKLDNN) { + auto input_names = conv->Op()->InputNames(); + bool has_bias = std::find(input_names.begin(), input_names.end(), + "Bias") != input_names.end(); + if (has_bias && conv->Op()->Input("Bias").size() > 0) { + // reuse existing conv bias node + auto conv_bias_names = conv->Op()->Input("Bias"); + PADDLE_ENFORCE_EQ(conv_bias_names.size(), 1); + auto* conv_bias_var = scope->FindVar(conv_bias_names[0]); + auto* conv_bias_tensor = conv_bias_var->GetMutable(); + PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(), + eltwise_y_in_tensor->dims()); + + auto eigen_conv_bias = EigenVector::From(*conv_bias_tensor); + eigen_conv_bias += EigenVector::From(*eltwise_y_in_tensor); + } else { + // add new conv_bias node + conv->Op()->SetInput( + "Bias", std::vector({eltwise_y_in_node->Name()})); + IR_NODE_LINK_TO(eltwise_y_in_node, conv); + } + conv->Op()->SetOutput("Output", + std::vector({bn_out->Name()})); + + GraphSafeRemoveNodes( + graph.get(), + {conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, + bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance}); + + IR_NODE_LINK_TO(conv, bn_out); + found_conv_bn_count++; + } else { // fuse_option == FUSE_NATIVE + // create an elementwise add node. + OpDesc desc; + desc.SetInput("X", std::vector({conv_out->Name()})); + desc.SetInput("Y", std::vector({eltwise_y_in_node->Name()})); + desc.SetOutput("Out", std::vector({bn_out->Name()})); + desc.SetType("elementwise_add"); + desc.SetAttr("axis", 1); + auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied. + + GraphSafeRemoveNodes( + graph.get(), + {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out, + bn_variance_out, bn_saved_mean, bn_saved_variance}); + + IR_NODE_LINK_TO(conv_out, eltwise_op); + IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op); + IR_NODE_LINK_TO(eltwise_op, bn_out); + found_conv_bn_count++; + } }; gpd(graph.get(), handler); @@ -237,7 +278,6 @@ std::unique_ptr ConvEltwiseAddBNFusePass::ApplyImpl( {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out}); - PADDLE_ENFORCE(subgraph.count(conv_input)); IR_NODE_LINK_TO(eltwise, bn_out); found_conv_bn_count++; diff --git a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc index d7df6389cf..e359a3832e 100644 --- a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc @@ -46,6 +46,12 @@ std::unique_ptr ConvReLUFusePass::ApplyImpl( GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern); // Out GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern); // ReLU op + FuseOptions fuse_option = FindFuseOption(*conv, *relu); + if (fuse_option == DO_NOT_FUSE) { + VLOG(3) << "do not perform conv+relu fuse"; + return; + } + // Transform Conv node into ConvReLU node. OpDesc* desc = conv->Op(); desc->SetOutput("Output", std::vector({relu_out->Name()})); diff --git a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc index 9dd780ec89..8f4bab25ed 100644 --- a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc @@ -20,17 +20,19 @@ namespace paddle { namespace framework { namespace ir { -void SetOp(ProgramDesc* prog, const std::string& type, +void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, const std::vector& inputs, - const std::vector& outputs) { + const std::vector& outputs, bool use_mkldnn = false) { auto* op = prog->MutableBlock(0)->AppendOp(); op->SetType(type); if (type == "conv2d") { - op->SetAttr("use_mkldnn", true); + op->SetAttr("use_mkldnn", use_mkldnn); + op->SetAttr("name", name); op->SetInput("Input", {inputs[0]}); op->SetInput("Filter", {inputs[1]}); op->SetInput("Bias", {inputs[2]}); } else if (type == "relu") { + op->SetAttr("use_mkldnn", use_mkldnn); op->SetInput("X", inputs); } op->SetOutput("Out", outputs); @@ -43,7 +45,8 @@ void SetOp(ProgramDesc* prog, const std::string& type, ProgramDesc BuildProgramDesc() { ProgramDesc prog; for (auto& v : - std::vector({"a", "b", "c", "weights", "bias", "f", "g"})) { + std::vector({"a", "b", "c", "weights", "bias", "f", "g", + "h", "weights2", "bias2", "k", "l"})) { auto* var = prog.MutableBlock(0)->Var(v); var->SetType(proto::VarType::SELECTED_ROWS); if (v == "weights" || v == "bias") { @@ -51,14 +54,24 @@ ProgramDesc BuildProgramDesc() { } } - SetOp(&prog, "OP0", std::vector({"a"}), + SetOp(&prog, "OP0", "op0", std::vector({"a"}), std::vector({"b"})); - SetOp(&prog, "OP1", std::vector({"b"}), + SetOp(&prog, "OP1", "op1", std::vector({"b"}), std::vector({"c"})); - SetOp(&prog, "conv2d", std::vector({"c", "weights", "bias"}), - std::vector({"f"})); - SetOp(&prog, "relu", std::vector({"f"}), - std::vector({"g"})); + // conv+relu, both with MKL-DNN + SetOp(&prog, "conv2d", "conv1", + std::vector({"c", "weights", "bias"}), + std::vector({"f"}), true); + SetOp(&prog, "relu", "relu1", std::vector({"f"}), + std::vector({"g"}), true); + SetOp(&prog, "OP3", "op3", std::vector({"g"}), + std::vector({"h"})); + // conv+relu, only one with MKL-DNN + SetOp(&prog, "conv2d", "conv2", + std::vector({"h", "weights2", "bias2"}), + std::vector({"k"}), true); + SetOp(&prog, "relu", "relu2", std::vector({"k"}), + std::vector({"l"})); return prog; } @@ -88,10 +101,16 @@ TEST(ConvReLUFusePass, basic) { auto* op = node->Op(); ASSERT_TRUE(op->HasAttr("use_mkldnn")); EXPECT_TRUE(boost::get(op->GetAttr("use_mkldnn"))); - ASSERT_TRUE(op->HasAttr("fuse_relu")); - bool fuse_relu = boost::get(op->GetAttr("fuse_relu")); - if (fuse_relu) { - ++conv_relu_count; + // check if only "conv1" convolution is fused + auto op_name = boost::get(op->GetAttr("name")); + if (op_name == "conv1") { + ASSERT_TRUE(op->HasAttr("fuse_relu")); + bool fuse_relu = boost::get(op->GetAttr("fuse_relu")); + if (fuse_relu) { + ++conv_relu_count; + } + } else if (op_name == "conv2") { + ASSERT_FALSE(op->HasAttr("fuse_relu")); } } } diff --git a/paddle/fluid/framework/ir/fuse_pass_base.cc b/paddle/fluid/framework/ir/fuse_pass_base.cc new file mode 100644 index 0000000000..d70010089e --- /dev/null +++ b/paddle/fluid/framework/ir/fuse_pass_base.cc @@ -0,0 +1,62 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { + +void FusePassBase::Init(const std::string& repr, Graph* graph) const { + repr_ = repr; + graph_ = graph; +} + +Scope* FusePassBase::param_scope() const { + PADDLE_ENFORCE(graph_->Has(kParamScopeAttr)); + return graph_->Get(kParamScopeAttr); +} + +void FusePassBase::AddStatis(int count_of_fused) const { + PADDLE_ENFORCE(graph_); + PADDLE_ENFORCE(!repr_.empty()); + if (!graph_->Has(kFuseStatisAttr)) { + graph_->Set(kFuseStatisAttr, new std::unordered_map); + } + auto& info = + graph_->Get>(kFuseStatisAttr); + info[repr_] = count_of_fused; +} + +FuseOptions FusePassBase::FindFuseOption(const Node& node1, + const Node& node2) const { +#ifdef PADDLE_WITH_MKLDNN + bool node1_mkldnn = node1.Op()->HasAttr("use_mkldnn") && + boost::get(node1.Op()->GetAttr("use_mkldnn")); + bool node2_mkldnn = node2.Op()->HasAttr("use_mkldnn") && + boost::get(node2.Op()->GetAttr("use_mkldnn")); + if (node1_mkldnn && node2_mkldnn) + return FUSE_MKLDNN; + else if (!node1_mkldnn && !node2_mkldnn) + return FUSE_NATIVE; + else + return DO_NOT_FUSE; +#else + return FUSE_NATIVE; +#endif +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/fuse_pass_base.h b/paddle/fluid/framework/ir/fuse_pass_base.h index 877bbeb502..c53b2a6186 100644 --- a/paddle/fluid/framework/ir/fuse_pass_base.h +++ b/paddle/fluid/framework/ir/fuse_pass_base.h @@ -25,32 +25,24 @@ namespace ir { static const char kParamScopeAttr[] = "__param_scope__"; static const char kFuseStatisAttr[] = "__fuse_statis__"; +enum FuseOptions { + DO_NOT_FUSE, // fusing will not be done + FUSE_NATIVE, // fusing will be done without MKL-DNN + FUSE_MKLDNN // fusing will be done with MKL-DNN +}; + class FusePassBase : public Pass { public: - void Init(const std::string& repr, Graph* graph) const { - repr_ = repr; - graph_ = graph; - } - - Scope* param_scope() const { - PADDLE_ENFORCE(graph_->Has(kParamScopeAttr)); - return graph_->Get(kParamScopeAttr); - } - - void AddStatis(int count_of_fused) const { - PADDLE_ENFORCE(graph_); - PADDLE_ENFORCE(!repr_.empty()); - if (!graph_->Has(kFuseStatisAttr)) { - graph_->Set(kFuseStatisAttr, new std::unordered_map); - } - auto& info = - graph_->Get>(kFuseStatisAttr); - info[repr_] = count_of_fused; - } + void Init(const std::string& repr, Graph* graph) const; + Scope* param_scope() const; + void AddStatis(int count_of_fused) const; virtual ~FusePassBase() {} protected: + virtual FuseOptions FindFuseOption(const Node& node1, + const Node& node2) const; + mutable Graph* graph_; mutable std::string repr_; }; diff --git a/paddle/fluid/framework/ir/mkldnn_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn_placement_pass.cc new file mode 100644 index 0000000000..65be69b7f5 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn_placement_pass.cc @@ -0,0 +1,37 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/mkldnn_placement_pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +std::unique_ptr MKLDNNPlacementPass::ApplyImpl( + std::unique_ptr graph) const { + VLOG(3) << "Aplies MKL-DNN placement strategy."; + for (const Node* n : graph->Nodes()) { + if (n->IsOp() && n->Op()->HasAttr("use_mkldnn")) { + n->Op()->SetAttr("use_mkldnn", true); + } + } + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(mkldnn_placement_pass, + paddle::framework::ir::MKLDNNPlacementPass); diff --git a/paddle/fluid/framework/ir/mkldnn_placement_pass.h b/paddle/fluid/framework/ir/mkldnn_placement_pass.h new file mode 100644 index 0000000000..3d4dc9e2b6 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn_placement_pass.h @@ -0,0 +1,31 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class MKLDNNPlacementPass : public Pass { + protected: + std::unique_ptr ApplyImpl( + std::unique_ptr graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index d780592eb9..61d29d092e 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -101,7 +101,11 @@ Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); } void Analyzer::Run(Argument* argument) { std::vector passes; - for (auto& pass : all_ir_passes_) { + if (use_mkldnn_) { + VLOG(3) << "Adding MKL-DNN placement pass"; + passes.push_back("mkldnn_placement_pass"); + } + for (auto& pass : ir_passes_) { if (!disabled_ir_passes_.count(pass)) { passes.push_back(pass); passes.push_back("graph_viz_pass"); // add graphviz for debug. @@ -117,11 +121,26 @@ void Analyzer::Run(Argument* argument) { } } +Analyzer& Analyzer::IncludeAllIrPasses() { + ir_passes_ = all_ir_passes_; + return *this; +} + Analyzer& Analyzer::DisableIrPasses(const std::vector& passes) { disabled_ir_passes_.insert(passes.begin(), passes.end()); return *this; } +Analyzer& Analyzer::IncludeIrPasses(const std::vector& passes) { + ir_passes_ = passes; + return *this; +} + +Analyzer& Analyzer::SetUseMkldnn(bool use_mkldnn) { + use_mkldnn_ = use_mkldnn; + return *this; +} + } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/analyzer.h b/paddle/fluid/inference/analysis/analyzer.h index 765145cb7d..6f45c6bf7e 100644 --- a/paddle/fluid/inference/analysis/analyzer.h +++ b/paddle/fluid/inference/analysis/analyzer.h @@ -54,6 +54,9 @@ class Analyzer : public OrderedRegistry { void Run(Argument* argument); Analyzer& DisableIrPasses(const std::vector& passes); + Analyzer& IncludeIrPasses(const std::vector& passes); + Analyzer& IncludeAllIrPasses(); + Analyzer& SetUseMkldnn(bool use_mkldnn); DISABLE_COPY_AND_ASSIGN(Analyzer); @@ -81,6 +84,9 @@ class Analyzer : public OrderedRegistry { }}; std::unordered_set disabled_ir_passes_; + // Ir passes to run + std::vector ir_passes_; + bool use_mkldnn_; }; } // namespace analysis diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 3095dee0f0..f1a4a4df50 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -225,10 +225,24 @@ void AnalysisPredictor::OptimizeInferenceProgram() { argument_.origin_program_desc.reset( new ProgramDesc(*inference_program_->Proto())); - PADDLE_ENFORCE( - config_.ir_mode == contrib::AnalysisConfig::IrPassMode::kExclude, - "Only kExclude is supported yet."); - Analyzer().DisableIrPasses(config_.ir_passes).Run(&argument_); + + switch (config_.ir_mode) { + case contrib::AnalysisConfig::IrPassMode::kExclude: + Analyzer() + .IncludeAllIrPasses() + .SetUseMkldnn(config_._use_mkldnn) + .DisableIrPasses(config_.ir_passes) + .Run(&argument_); + break; + case contrib::AnalysisConfig::IrPassMode::kInclude: + Analyzer() + .SetUseMkldnn(config_._use_mkldnn) + .IncludeIrPasses(config_.ir_passes) + .Run(&argument_); + break; + default: + LOG(ERROR) << "Only kExclude and kInclude modes are supoorted yet."; + } CHECK(argument_.transformed_program_desc); VLOG(5) << "to prepare executor"; diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h index d2876dc27c..07ee6e72d1 100644 --- a/paddle/fluid/inference/api/paddle_inference_api.h +++ b/paddle/fluid/inference/api/paddle_inference_api.h @@ -259,10 +259,17 @@ struct AnalysisConfig : public NativeConfig { kExclude // Specify the disabled passes in `ir_passes`. }; + void SetIncludeMode() { + ir_mode = IrPassMode::kInclude; + // this pass has to be run at the beginning of all fuse passes + ir_passes = {"infer_clean_graph_pass"}; + } + // Determine whether to perform graph optimization. bool enable_ir_optim = true; // Manually determine the IR passes to run. IrPassMode ir_mode{IrPassMode::kExclude}; + // passes to be excluded/included std::vector ir_passes{"embedding_fc_lstm_fuse_pass"}; // NOT stable yet. -- GitLab