diff --git a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc index 6f619181f4e10573e292b917269d97ad98a61a6b..84a4ff2de173d86184fcef53b8e55fe17958fb8c 100644 --- a/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/repeated_fc_relu_fuse_pass.cc @@ -94,7 +94,6 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern, return false; } auto* relu_op = x->inputs[0]; - // std::cout << "xxxx" << std::endl; bool before_is_fc = relu_op->IsOp() && relu_op->inputs.size() == 1 && relu_op->inputs[0]->IsVar() && VarLinksFromOp(relu_op->inputs[0], "fc") && @@ -105,31 +104,18 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern, } auto* fc_op = relu_op->inputs[0]->inputs[0]; bool is_fc = fc_op->IsOp() && fc_op->inputs.size() == 3; - // std::cout << "*****" << fc_op->inputs.size() << std::endl; if (!is_fc) { return false; } - for (size_t kkk = 0; kkk < 3; ++kkk) { - // std::cout << "++++++" << kkk << std::endl; - if (!fc_op->inputs[kkk]->inputs.empty()) { + for (auto* fc_i : fc_op->inputs) { + if (!fc_i->inputs.empty()) { if (at_top) { return true; } else { - bool res = VarLinksFromOp(fc_op->inputs[kkk], "relu"); - // std::cout << fc_op->inputs[kkk]->Name() << "++++++-----" << kkk << - // ":" - // << res << std::endl; - return res; + return VarLinksFromOp(fc_i, "relu"); } } } - // for (auto* fc_i : fc_op->inputs) { - // if (!fc_i->inputs.empty()) { - // std::cout << "++++++" << fc_op->inputs.size()< bool { for (int i = 0; i < repeated_times; ++i) { - // std::cout << "----" << i << std::endl; if (!var_before_is_fc_act(x, act_type, i == repeated_times - 1)) { return false; } @@ -180,17 +165,9 @@ PDNode* BuildRepeatedFCReluPattern(PDPattern* pattern, x, std::max(1, num_fc - i - 1), "relu"); } } else { - bool part1 = - var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") && - x->inputs.size() > 0; - if (x->Name() == "fc_0.tmp_1" && x->IsVar() && part1) { - // std::cout << "testes" << std::endl; - } - bool part2 = var_before_is_fc_act_repeated_n_times(x, i, "relu"); - if (x->Name() == "fc_0.tmp_1") { - // std::cout << "========" << part1 << "," << part2 << std::endl; - } - return part1 && part2; + return var_next_is_fc_act_repeated_n_times(x, num_fc - i, "relu") && + x->inputs.size() > 0 && + var_before_is_fc_act_repeated_n_times(x, i, "relu"); } }, name_scope + "/fc_in_" + std::to_string(i)); @@ -394,7 +371,7 @@ std::unique_ptr RepeatedFCReluFusePass::ApplyImpl( int fusion_count = 0; for (int i = MAX_NUM_FC; i > 1; --i) { fusion_count += - BuildFusion(graph.get(), name_scope_ + "/" + std::to_string(3), 3); + BuildFusion(graph.get(), name_scope_ + "/" + std::to_string(i), i); } AddStatis(fusion_count); diff --git a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc index fb4c5c0a00d0b693ac70691438e65ed9e49bd1a1..caebfe16d6fc8f344b1f337195fda79c485cfc9a 100644 --- a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc @@ -190,8 +190,10 @@ void analysis_fuse_statis(bool use_zerocopy) { ASSERT_EQ(fuse_statis.at("fc_fuse"), 10); ASSERT_TRUE(fuse_statis.count("seqpool_concat_fuse")); EXPECT_EQ(fuse_statis.at("seqpool_concat_fuse"), 2); + ASSERT_TRUE(fuse_statis.count("repeated_fc_relu")); + EXPECT_EQ(fuse_statis.at("repeated_fc_relu"), 2); LOG(INFO) << "num_ops: " << num_ops; - EXPECT_EQ(num_ops, 195); + EXPECT_EQ(num_ops, 185); } // Check the fuse status