diff --git a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc index 96a60da518f9097f7eda27733fbd3355ea340a51..96a3b7ee058647156258b946c1301138c185fa31 100644 --- a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc +++ b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc @@ -76,6 +76,7 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern, std::vector seqpool_ops_input_var(num_inputs); std::vector seqpool_ops_output_var(num_inputs); + std::vector seqpool_ops_output_unused_var(num_inputs); std::vector seqpool_ops(num_inputs); for (int i = 0; i < num_inputs; ++i) { @@ -88,6 +89,15 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern, }, name_scope + "/sequence_pool_out_" + std::to_string(i)); + seqpool_ops_output_unused_var[i] = pattern->NewNode( + [=](Node* x) { + return x && x->IsVar() && x->inputs.size() == 1 && + x->outputs.size() == 0 && + is_seqpool_op_with_pootype_of_nth_input_of_concat(x->inputs[0], + "SUM", i); + }, + name_scope + "/sequence_pool_unused_out_" + std::to_string(i)); + seqpool_ops[i] = pattern->NewNode( [=](Node* x) { return x && x->IsOp() && @@ -97,16 +107,23 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern, seqpool_ops_input_var[i] = pattern->NewNode( [=](Node* x) { - return x && x->IsVar() && x->outputs.size() >= 1 && - is_seqpool_op_with_pootype_of_nth_input_of_concat( - x->outputs[0], "SUM", i); + bool basic = x && x->IsVar() && x->outputs.size() >= 1; + bool next_is_fine = false; + for (auto* o : x->outputs) { + if (is_seqpool_op_with_pootype_of_nth_input_of_concat(o, "SUM", + i)) { + next_is_fine = true; + break; + } + } + return basic && next_is_fine; }, name_scope + "/sequence_pool_in_" + std::to_string(i)); // Links seqpool_ops[i] ->LinksFrom({seqpool_ops_input_var[i]}) - .LinksTo({seqpool_ops_output_var[i]}); + .LinksTo({seqpool_ops_output_var[i], seqpool_ops_output_unused_var[i]}); } concat_op->LinksFrom(seqpool_ops_output_var).LinksTo({concat_out_var}); return concat_out_var; diff --git a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc index 7d2739d84dea1a7ee92606a65c1aa6b2fdcb6c6a..456a03192cc4e4a9d0dbe2dcb649b6c1b4d9cd5a 100644 --- a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc @@ -35,11 +35,35 @@ void SetOp(ProgramDesc* prog, const std::string& type, op->SetInput("X", inputs); op->SetAttr("axis", 1); op->SetOutput("Out", {outputs[0]}); + } else { + op->SetInput("X", inputs); + op->SetOutput("Out", outputs); } op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), static_cast(OpRole::kForward)); } +int CountOpType(const ir::Graph* graph, + const std::string& op_type = "fusion_seqpool_concat") { + int count = 0; + for (auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op()->Type() == op_type) { + ++count; + } + } + return count; +} + +std::unique_ptr GetNumNodesOfBeforeAfter( + std::unique_ptr graph, int* before, int* after, + const std::string& pass_type = "seqpool_concat_fuse_pass") { + auto pass = PassRegistry::Instance().Get(pass_type); + *before = graph->Nodes().size(); + graph = pass->Apply(std::move(graph)); + *after = graph->Nodes().size(); + return graph; +} + /* * Before fuse: * a b c @@ -51,15 +75,16 @@ void SetOp(ProgramDesc* prog, const std::string& type, * concat * | * j + * Type of op1, op2 and op3 are sequence_pool, with "SUM" pooltype attr + * * After fuse: * a b c * \ | / * fusion_seqpool_concat * | * j - * unused nodes: d, f, h */ -ProgramDesc BuildProgramDesc() { +TEST(SeqPoolConcatFusePass, basic) { ProgramDesc prog; for (auto& v : std::vector( {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"})) { @@ -76,35 +101,94 @@ ProgramDesc BuildProgramDesc() { SetOp(&prog, "concat", std::vector({"e", "g", "i"}), std::vector({"j"})); - return prog; -} - -TEST(SeqPoolConcatFusePass, basic) { - auto prog = BuildProgramDesc(); - std::unique_ptr graph(new ir::Graph(prog)); + int before, after; + graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after); + // Remove 10 Nodes: op1, op2, op3, d, e, f, g, h, i, concat_op + // Add 1 Node: fusion_seqpool_concat + EXPECT_EQ(after, before - 9); + EXPECT_EQ(CountOpType(graph.get()), 1); +} - auto pass = PassRegistry::Instance().Get("seqpool_concat_fuse_pass"); - - int pre_nodes = graph->Nodes().size(); - - graph = pass->Apply(std::move(graph)); +/* + * Before fuse: + * a b + * | / \ + * op1 op2 op3 + * / \ / \ \ + * c d e f g + * \ / + * concat + * | + * h + * Type of op1 and op2 are sequence_pool, with "SUM" pooltype attr + * + * After fuse: + * a b + * \ / \ + * fusion_seqpool_concat op3 + * | | + * h g + */ +TEST(SeqPoolConcatFusePass, advanced) { + ProgramDesc prog; + for (auto& v : + std::vector({"a", "b", "c", "d", "e", "f", "g", "h"})) { + auto* var = prog.MutableBlock(0)->Var(v); + var->SetType(proto::VarType::LOD_TENSOR); + } - int after_nodes = graph->Nodes().size(); + SetOp(&prog, "sequence_pool", std::vector({"a"}), + std::vector({"c", "d"})); + SetOp(&prog, "sequence_pool", std::vector({"b"}), + std::vector({"e", "f"})); + SetOp(&prog, "op3", std::vector({"b"}), + std::vector({"g"})); + SetOp(&prog, "concat", std::vector({"d", "f"}), + std::vector({"h"})); - // Remove 7 Nodes: op1, op2, op3, e, g, i, concat_op + std::unique_ptr graph(new ir::Graph(prog)); + int before, after; + graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after); + // Remove 7 Nodes: op1, op2, c, d, e, f concat_op // Add 1 Node: fusion_seqpool_concat - EXPECT_EQ(pre_nodes - 6, after_nodes); + EXPECT_EQ(after, before - 6); + EXPECT_EQ(CountOpType(graph.get()), 1); +} - // Assert new op in newly generated graph - int count = 0; +ProgramDesc BuildProgramDesc(int num_inputs_of_concat) { + ProgramDesc prog; + auto new_var = [&](const std::string& name) { + auto* var = prog.MutableBlock(0)->Var(name); + var->SetType(proto::VarType::LOD_TENSOR); + }; + std::vector concat_inputs; + for (int i = 0; i < num_inputs_of_concat; ++i) { + std::string prefix = "seqpool_op_" + i; + new_var(prefix + "in"); + new_var(prefix + "out"); + new_var(prefix + "out_unused"); + SetOp(&prog, "sequence_pool", std::vector({prefix + "in"}), + std::vector({prefix + "out", prefix + "out_unused"})); + concat_inputs.push_back(prefix + "out"); + } + SetOp(&prog, "concat", concat_inputs, + std::vector({"concat_out"})); + return prog; +} - for (auto* node : graph->Nodes()) { - if (node->IsOp() && node->Op()->Type() == "fusion_seqpool_concat") { - ++count; - } +// test more inputs of concat +TEST(SeqPoolConcatFusePass, more_inputs) { + for (int num : {1, 2, 10}) { + ProgramDesc prog = BuildProgramDesc(num); + std::unique_ptr graph(new ir::Graph(prog)); + int before, after; + graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after); + // Remove Nodes: n * (seqpool_op, out, out_unused), and concat_op + // Add Node: fusion_seqpool_concat op + EXPECT_EQ(after, before - num * 3); + EXPECT_EQ(CountOpType(graph.get()), 1); } - EXPECT_EQ(count, 1); } } // namespace ir