提交 b97da301 编写于 作者: C cc 提交者: GitHub

Skip fusing conv and elementwise when the bias of elementwise has two output, test=develop (#3679)

上级 635d68a6
...@@ -91,6 +91,8 @@ void OutputOptModel(const std::string& save_optimized_model_dir) { ...@@ -91,6 +91,8 @@ void OutputOptModel(const std::string& save_optimized_model_dir) {
} }
std::vector<Place> vaild_places = { std::vector<Place> vaild_places = {
Place{TARGET(kARM), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kInt32)},
Place{TARGET(kARM), PRECISION(kInt64)},
}; };
config.set_valid_places(vaild_places); config.set_valid_places(vaild_places);
auto predictor = lite_api::CreatePaddlePredictor(config); auto predictor = lite_api::CreatePaddlePredictor(config);
......
...@@ -30,7 +30,8 @@ void ConvElementwiseFuser::BuildPattern() { ...@@ -30,7 +30,8 @@ void ConvElementwiseFuser::BuildPattern() {
auto* bias = VarNode("bias") auto* bias = VarNode("bias")
->assert_is_op_input("elementwise_add", "Y") ->assert_is_op_input("elementwise_add", "Y")
->AsInput() ->AsInput()
->assert_is_persistable_var(); ->assert_is_persistable_var()
->assert_only_one_output();
// create op nodes // create op nodes
auto* conv2d = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_); auto* conv2d = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_);
......
...@@ -364,6 +364,11 @@ PMNode *PMNode::assert_is_op() { ...@@ -364,6 +364,11 @@ PMNode *PMNode::assert_is_op() {
return this; return this;
} }
PMNode *PMNode::assert_only_one_output() {
asserts_.emplace_back([](const Node *x) { return x->outlinks.size() == 1; });
return this;
}
PMNode *PMNode::assert_is_op(const std::string &op_type) { PMNode *PMNode::assert_is_op(const std::string &op_type) {
asserts_.emplace_back([op_type](const Node *x) { asserts_.emplace_back([op_type](const Node *x) {
if (x && x->IsStmt()) { if (x && x->IsStmt()) {
......
...@@ -127,6 +127,7 @@ struct PMNode { ...@@ -127,6 +127,7 @@ struct PMNode {
PMNode* assert_is_var(); PMNode* assert_is_var();
PMNode* assert_var_not_persistable(); PMNode* assert_var_not_persistable();
PMNode* assert_is_persistable_var(); PMNode* assert_is_persistable_var();
PMNode* assert_only_one_output();
PMNode* assert_is_op_output(const std::string& op_type); PMNode* assert_is_op_output(const std::string& op_type);
PMNode* assert_is_op_input(const std::string& op_type); PMNode* assert_is_op_input(const std::string& op_type);
PMNode* assert_is_op_input(const std::string& op_type, PMNode* assert_is_op_input(const std::string& op_type,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册