提交 b97da301 编写于 作者: C cc 提交者: GitHub

Skip fusing conv and elementwise when the bias of elementwise has two output, test=develop (#3679)

上级 635d68a6
......@@ -91,6 +91,8 @@ void OutputOptModel(const std::string& save_optimized_model_dir) {
}
std::vector<Place> vaild_places = {
Place{TARGET(kARM), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kInt32)},
Place{TARGET(kARM), PRECISION(kInt64)},
};
config.set_valid_places(vaild_places);
auto predictor = lite_api::CreatePaddlePredictor(config);
......
......@@ -30,7 +30,8 @@ void ConvElementwiseFuser::BuildPattern() {
auto* bias = VarNode("bias")
->assert_is_op_input("elementwise_add", "Y")
->AsInput()
->assert_is_persistable_var();
->assert_is_persistable_var()
->assert_only_one_output();
// create op nodes
auto* conv2d = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_);
......
......@@ -364,6 +364,11 @@ PMNode *PMNode::assert_is_op() {
return this;
}
PMNode *PMNode::assert_only_one_output() {
asserts_.emplace_back([](const Node *x) { return x->outlinks.size() == 1; });
return this;
}
PMNode *PMNode::assert_is_op(const std::string &op_type) {
asserts_.emplace_back([op_type](const Node *x) {
if (x && x->IsStmt()) {
......
......@@ -127,6 +127,7 @@ struct PMNode {
PMNode* assert_is_var();
PMNode* assert_var_not_persistable();
PMNode* assert_is_persistable_var();
PMNode* assert_only_one_output();
PMNode* assert_is_op_output(const std::string& op_type);
PMNode* assert_is_op_input(const std::string& op_type);
PMNode* assert_is_op_input(const std::string& op_type,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册