diff --git a/lite/api/benchmark.cc b/lite/api/benchmark.cc index 63d498c41fe5eb265a65a7fe4e849ced8153530e..65f074d7160434bc1c140d1d8be86566b777073f 100644 --- a/lite/api/benchmark.cc +++ b/lite/api/benchmark.cc @@ -91,6 +91,8 @@ void OutputOptModel(const std::string& save_optimized_model_dir) { } std::vector vaild_places = { Place{TARGET(kARM), PRECISION(kFloat)}, + Place{TARGET(kARM), PRECISION(kInt32)}, + Place{TARGET(kARM), PRECISION(kInt64)}, }; config.set_valid_places(vaild_places); auto predictor = lite_api::CreatePaddlePredictor(config); diff --git a/lite/core/mir/fusion/conv_elementwise_fuser.cc b/lite/core/mir/fusion/conv_elementwise_fuser.cc index 22ec1fa0d22378adf3776c6bb391f50fde376b7a..f94da2f1b1fc0a0d4ca17718f9407a4a56c544fe 100644 --- a/lite/core/mir/fusion/conv_elementwise_fuser.cc +++ b/lite/core/mir/fusion/conv_elementwise_fuser.cc @@ -30,7 +30,8 @@ void ConvElementwiseFuser::BuildPattern() { auto* bias = VarNode("bias") ->assert_is_op_input("elementwise_add", "Y") ->AsInput() - ->assert_is_persistable_var(); + ->assert_is_persistable_var() + ->assert_only_one_output(); // create op nodes auto* conv2d = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_); diff --git a/lite/core/mir/pattern_matcher.cc b/lite/core/mir/pattern_matcher.cc index aaebf852b2ec519515e59655a57600f59ec6a2c3..6e3c71d44329e258bba1efecad669e76ad66c83a 100644 --- a/lite/core/mir/pattern_matcher.cc +++ b/lite/core/mir/pattern_matcher.cc @@ -364,6 +364,11 @@ PMNode *PMNode::assert_is_op() { return this; } +PMNode *PMNode::assert_only_one_output() { + asserts_.emplace_back([](const Node *x) { return x->outlinks.size() == 1; }); + return this; +} + PMNode *PMNode::assert_is_op(const std::string &op_type) { asserts_.emplace_back([op_type](const Node *x) { if (x && x->IsStmt()) { diff --git a/lite/core/mir/pattern_matcher.h b/lite/core/mir/pattern_matcher.h index 0cbfbd986ce743985fde64b8e71b9b0e2b135b9e..f73edd55ffe949fb5dcf6b97159fe5ab88516196 100644 --- a/lite/core/mir/pattern_matcher.h +++ b/lite/core/mir/pattern_matcher.h @@ -127,6 +127,7 @@ struct PMNode { PMNode* assert_is_var(); PMNode* assert_var_not_persistable(); PMNode* assert_is_persistable_var(); + PMNode* assert_only_one_output(); PMNode* assert_is_op_output(const std::string& op_type); PMNode* assert_is_op_input(const std::string& op_type); PMNode* assert_is_op_input(const std::string& op_type,