diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index f71a3d0f2e51f1a8d30fbc5f436edc97e80c57c1..a595a8ab4299298f625b8322a0adbed6d0b4fda3 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -69,6 +69,7 @@ cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_r cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass) cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector) cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto) +cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto) cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass) if (WITH_MKLDNN) cc_test(test_depthwise_conv_mkldnn_pass SRCS depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass) diff --git a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc index 7dd6f4880abfacf80df3f0563f5b6e839a7150da..96a60da518f9097f7eda27733fbd3355ea340a51 100644 --- a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc +++ b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc @@ -112,8 +112,7 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern, return concat_out_var; } -int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, - int num_inputs) { +int BuildFusion(Graph* graph, const std::string& name_scope, int num_inputs) { GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); BuildSeqPoolConcatPattern(pattern, name_scope, num_inputs); @@ -182,8 +181,8 @@ std::unique_ptr SeqPoolConcatFusePass::ApplyImpl( FusePassBase::Init(name_scope_, graph.get()); int fusion_count = 0; for (int i = MAX_CONCAT_INPUTS; i > 0; --i) { - fusion_count += BuildFusion( - graph.get(), name_scope_ + "/" + std::to_string(i), param_scope(), i); + fusion_count += + BuildFusion(graph.get(), name_scope_ + "/" + std::to_string(i), i); } AddStatis(fusion_count); diff --git a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..7d2739d84dea1a7ee92606a65c1aa6b2fdcb6c6a --- /dev/null +++ b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass_tester.cc @@ -0,0 +1,114 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h" +#include +#include "paddle/fluid/framework/op_proto_maker.h" + +namespace paddle { +namespace framework { +namespace ir { + +void SetOp(ProgramDesc* prog, const std::string& type, + const std::vector& inputs, + const std::vector& outputs) { + auto* op = prog->MutableBlock(0)->AppendOp(); + op->SetType(type); + if (type == "sequence_pool") { + op->SetInput("X", {inputs[0]}); + std::string pooltype = "SUM"; + op->SetAttr("pooltype", pooltype); + op->SetOutput("MaxIndex", {outputs[0]}); + op->SetOutput("Out", {outputs[1]}); + } else if (type == "concat") { + op->SetInput("X", inputs); + op->SetAttr("axis", 1); + op->SetOutput("Out", {outputs[0]}); + } + op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), + static_cast(OpRole::kForward)); +} + +/* + * Before fuse: + * a b c + * | | | + * op1 op2 op3 + * / \ / \ / \ + * d e f g h i + * \ | / + * concat + * | + * j + * After fuse: + * a b c + * \ | / + * fusion_seqpool_concat + * | + * j + * unused nodes: d, f, h + */ +ProgramDesc BuildProgramDesc() { + ProgramDesc prog; + for (auto& v : std::vector( + {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"})) { + auto* var = prog.MutableBlock(0)->Var(v); + var->SetType(proto::VarType::LOD_TENSOR); + } + + SetOp(&prog, "sequence_pool", std::vector({"a"}), + std::vector({"d", "e"})); + SetOp(&prog, "sequence_pool", std::vector({"b"}), + std::vector({"f", "g"})); + SetOp(&prog, "sequence_pool", std::vector({"c"}), + std::vector({"h", "i"})); + SetOp(&prog, "concat", std::vector({"e", "g", "i"}), + std::vector({"j"})); + + return prog; +} + +TEST(SeqPoolConcatFusePass, basic) { + auto prog = BuildProgramDesc(); + + std::unique_ptr graph(new ir::Graph(prog)); + + auto pass = PassRegistry::Instance().Get("seqpool_concat_fuse_pass"); + + int pre_nodes = graph->Nodes().size(); + + graph = pass->Apply(std::move(graph)); + + int after_nodes = graph->Nodes().size(); + + // Remove 7 Nodes: op1, op2, op3, e, g, i, concat_op + // Add 1 Node: fusion_seqpool_concat + EXPECT_EQ(pre_nodes - 6, after_nodes); + + // Assert new op in newly generated graph + int count = 0; + + for (auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op()->Type() == "fusion_seqpool_concat") { + ++count; + } + } + EXPECT_EQ(count, 1); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(seqpool_concat_fuse_pass);