From b97da301fbbad869470e3fc4a474005a6651fdeb Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Fri, 22 May 2020 11:03:57 +0800 Subject: [PATCH] Skip fusing conv and elementwise when the bias of elementwise has two output, test=develop (#3679) --- lite/api/benchmark.cc | 2 ++ lite/core/mir/fusion/conv_elementwise_fuser.cc | 3 ++- lite/core/mir/pattern_matcher.cc | 5 +++++ lite/core/mir/pattern_matcher.h | 1 + 4 files changed, 10 insertions(+), 1 deletion(-) diff --git a/lite/api/benchmark.cc b/lite/api/benchmark.cc index 63d498c41f..65f074d716 100644 --- a/lite/api/benchmark.cc +++ b/lite/api/benchmark.cc @@ -91,6 +91,8 @@ void OutputOptModel(const std::string& save_optimized_model_dir) { } std::vector vaild_places = { Place{TARGET(kARM), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kInt32)}, + Place{TARGET(kARM), PRECISION(kInt64)}, }; config.set_valid_places(vaild_places); auto predictor = lite_api::CreatePaddlePredictor(config); diff --git a/lite/core/mir/fusion/conv_elementwise_fuser.cc b/lite/core/mir/fusion/conv_elementwise_fuser.cc index 22ec1fa0d2..f94da2f1b1 100644 --- a/lite/core/mir/fusion/conv_elementwise_fuser.cc +++ b/lite/core/mir/fusion/conv_elementwise_fuser.cc @@ -30,7 +30,8 @@ void ConvElementwiseFuser::BuildPattern() { auto* bias = VarNode("bias") ->assert_is_op_input("elementwise_add", "Y") ->AsInput() - ->assert_is_persistable_var(); + ->assert_is_persistable_var() + ->assert_only_one_output(); // create op nodes auto* conv2d = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_); diff --git a/lite/core/mir/pattern_matcher.cc b/lite/core/mir/pattern_matcher.cc index aaebf852b2..6e3c71d443 100644 --- a/lite/core/mir/pattern_matcher.cc +++ b/lite/core/mir/pattern_matcher.cc @@ -364,6 +364,11 @@ PMNode *PMNode::assert_is_op() { return this; } +PMNode *PMNode::assert_only_one_output() { + asserts_.emplace_back([](const Node *x) { return x->outlinks.size() == 1; }); + return this; +} + PMNode *PMNode::assert_is_op(const std::string &op_type) { asserts_.emplace_back([op_type](const Node *x) { if (x && x->IsStmt()) { diff --git a/lite/core/mir/pattern_matcher.h b/lite/core/mir/pattern_matcher.h index 0cbfbd986c..f73edd55ff 100644 --- a/lite/core/mir/pattern_matcher.h +++ b/lite/core/mir/pattern_matcher.h @@ -127,6 +127,7 @@ struct PMNode { PMNode* assert_is_var(); PMNode* assert_var_not_persistable(); PMNode* assert_is_persistable_var(); + PMNode* assert_only_one_output(); PMNode* assert_is_op_output(const std::string& op_type); PMNode* assert_is_op_input(const std::string& op_type); PMNode* assert_is_op_input(const std::string& op_type, -- GitLab