From 72d2a1801e92cf441752a9701114c9584ccfcb10 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 7 Jan 2019 07:36:48 +0000 Subject: [PATCH] add seqpool concat fuse pass test=develop --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../framework/ir/seqpool_concat_fuse_pass.cc | 194 ++++++++++++++++++ .../framework/ir/seqpool_concat_fuse_pass.h | 38 ++++ .../fluid/inference/api/paddle_pass_builder.h | 1 + .../tests/api/analyzer_seq_pool1_tester.cc | 6 +- 5 files changed, 239 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 6e6db3d3ef..f71a3d0f2e 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -42,6 +42,7 @@ pass_library(seq_concat_fc_fuse_pass inference) pass_library(multi_batch_merge_pass base) pass_library(conv_bn_fuse_pass inference) pass_library(seqconv_eltadd_relu_fuse_pass inference) +pass_library(seqpool_concat_fuse_pass inference) pass_library(is_test_pass base) pass_library(conv_elementwise_add_act_fuse_pass inference) pass_library(conv_elementwise_add2_act_fuse_pass inference) diff --git a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc new file mode 100644 index 0000000000..20b8220033 --- /dev/null +++ b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc @@ -0,0 +1,194 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h" +#include +#include +#include "paddle/fluid/framework/lod_tensor.h" + +#define MAX_CONCAT_INPUTS 200 + +namespace paddle { +namespace framework { +namespace ir { + +PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern, + const std::string& name_scope, + int num_inputs) { + auto is_concat_op_with_inputs = [](Node* x, int num) -> bool { + return x && x->IsOp() && x->Op()->Type() == "concat" && + x->Op()->Input("X").size() == static_cast(num); + }; + + auto is_nth_input_var_of_concat = [=](Node* x, int idx) -> bool { + return x && x->IsVar() && VarLinksToOp(x, "concat") && + x->outputs.size() == 1 && IsNthInput(x, x->outputs[0], "X", idx) && + is_concat_op_with_inputs(x->outputs[0], num_inputs); + }; + + auto is_seqpool_op_with_pootype_of_nth_input_of_concat = [=]( + Node* x, const std::string& type, int idx) -> bool { + bool ok = x && x->IsOp() && x->Op()->Type() == "sequence_pool" && + x->Op()->HasAttr("pooltype") && + boost::get(x->Op()->GetAttr("pooltype")) == type && + x->outputs.size() == 2; // seqpool should only have 2 outputs + if (ok) { + // only one output of seqpool_op is nth_input_var of concat + // the other one should be unused empty var + if (is_nth_input_var_of_concat(x->outputs[0], idx)) { + ok = ok && x->outputs[1]->IsVar() && x->outputs[1]->outputs.size() == 0; + } else { + ok = ok && is_nth_input_var_of_concat(x->outputs[1], idx) && + x->outputs[0]->IsVar() && x->outputs[0]->outputs.size() == 0; + } + } + return ok; + }; + + auto* concat_op = pattern->NewNode( + [=](Node* x) { return is_concat_op_with_inputs(x, num_inputs); }, + name_scope + "/concat_op"); + concat_op->assert_op_attr("axis", 1); + + auto* concat_out_var = pattern->NewNode( + [=](Node* x) { + return x && x->IsVar() && VarLinksFromOp(x, "concat") && + x->inputs.size() == 1 && + is_concat_op_with_inputs(x->inputs[0], num_inputs); + }, + name_scope + "/concat_out_var"); + concat_out_var->assert_is_only_output_of_op("concat"); + + std::vector seqpool_ops_input_var(num_inputs); + std::vector seqpool_ops_output_var(num_inputs); + std::vector seqpool_ops(num_inputs); + + for (int i = 0; i < num_inputs; ++i) { + seqpool_ops_output_var[i] = pattern->NewNode( + [=](Node* x) { + return x && x->IsVar() && is_nth_input_var_of_concat(x, i) && + x->inputs.size() == 1 && + is_seqpool_op_with_pootype_of_nth_input_of_concat(x->inputs[0], + "SUM", i); + }, + name_scope + "/sequence_pool_out_" + std::to_string(i)); + + seqpool_ops[i] = pattern->NewNode( + [=](Node* x) { + return x && x->IsOp() && + is_seqpool_op_with_pootype_of_nth_input_of_concat(x, "SUM", i); + }, + name_scope + "/sequence_pool_op_" + std::to_string(i)); + + seqpool_ops_input_var[i] = pattern->NewNode( + [=](Node* x) { + return x && x->IsVar() && x->outputs.size() >= 1 && + is_seqpool_op_with_pootype_of_nth_input_of_concat( + x->outputs[0], "SUM", i); + }, + name_scope + "/sequence_pool_in_" + std::to_string(i)); + + // Links + seqpool_ops[i] + ->LinksFrom({seqpool_ops_input_var[i]}) + .LinksTo({seqpool_ops_output_var[i]}); + } + concat_op->LinksFrom(seqpool_ops_output_var).LinksTo({concat_out_var}); + return concat_out_var; +} + +int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, + int num_inputs) { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + BuildSeqPoolConcatPattern(pattern, name_scope, num_inputs); + + auto retrieve_node = [](const std::string& name, + const GraphPatternDetector::subgraph_t& subgraph, + const PDPattern& pat) -> Node* { + PADDLE_ENFORCE(subgraph.count(pat.RetrieveNode(name)), + "pattern has no Node called %s", name.c_str()); + Node* p = subgraph.at(pat.RetrieveNode(name)); + PADDLE_ENFORCE_NOT_NULL(p, "subgraph has no node %s", name.c_str()); + return p; + }; + + int fusion_count{0}; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "handle SeqPool Concat fuse"; + std::vector input_names(num_inputs); + std::vector input_vars(num_inputs); + auto& fused_pattern = gpd.pattern(); + for (int i = 0; i < num_inputs; ++i) { + input_vars[i] = + retrieve_node(name_scope + "/sequence_pool_in_" + std::to_string(i), + subgraph, fused_pattern); + input_names[i] = input_vars[i]->Name(); + } + auto* concat_op = + retrieve_node(name_scope + "/concat_op", subgraph, fused_pattern); + auto* concat_out_var = + retrieve_node(name_scope + "/concat_out_var", subgraph, fused_pattern); + auto* seqpool_op0 = retrieve_node(name_scope + "/sequence_pool_op_0", + subgraph, fused_pattern); + + // Create New OpDesc + OpDesc op_desc; + op_desc.SetType("fusion_seqpool_concat"); + op_desc.SetInput("X", input_names); + op_desc.SetAttr("pooltype", seqpool_op0->Op()->GetAttr("pooltype")); + op_desc.SetAttr("axis", concat_op->Op()->GetAttr("axis")); + op_desc.SetOutput("Out", {concat_out_var->Name()}); + auto* op = graph->CreateOpNode(&op_desc); + for (size_t i = 0; i < input_vars.size(); ++i) { + IR_NODE_LINK_TO(input_vars[i], op); + } + IR_NODE_LINK_TO(op, concat_out_var); + + std::unordered_set marked_nodes; + for (auto& item : subgraph) { + marked_nodes.insert(item.second); + } + for (size_t i = 0; i < input_vars.size(); ++i) { + marked_nodes.erase(input_vars[i]); + } + marked_nodes.erase(concat_out_var); + GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + + gpd(graph, handler); + return fusion_count; +} + +std::unique_ptr SeqPoolConcatFusePass::ApplyImpl( + std::unique_ptr graph) const { + FusePassBase::Init(name_scope_, graph.get()); + int fusion_count = 0; + for (int i = MAX_CONCAT_INPUTS; i > 0; --i) { + fusion_count += BuildFusion( + graph.get(), name_scope_ + "/" + std::to_string(i), param_scope(), i); + } + AddStatis(fusion_count); + + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(seqpool_concat_fuse_pass, + paddle::framework::ir::SeqPoolConcatFusePass); diff --git a/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h new file mode 100644 index 0000000000..59730fde55 --- /dev/null +++ b/paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +class SeqPoolConcatFusePass : public FusePassBase { + public: + virtual ~SeqPoolConcatFusePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; + + const std::string name_scope_{"seqpool_concat_fuse"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h index 9337ae55b7..1e5712e163 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.h +++ b/paddle/fluid/inference/api/paddle_pass_builder.h @@ -89,6 +89,7 @@ class CpuPassStrategy : public PassStrategy { passes_.assign({ "infer_clean_graph_pass", // "attention_lstm_fuse_pass", // + "seqpool_concat_fuse_pass", // "seqconv_eltadd_relu_fuse_pass", // // "embedding_fc_lstm_fuse_pass", // "fc_lstm_fuse_pass", // diff --git a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc index a1742f6068..083bdf15e9 100644 --- a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc @@ -177,8 +177,12 @@ TEST(Analyzer_seq_pool1, fuse_statis) { auto predictor = CreatePaddlePredictor(cfg); auto fuse_statis = GetFuseStatis( static_cast(predictor.get()), &num_ops); + + ASSERT_TRUE(fuse_statis.count("seqpool_concat_fuse")); + EXPECT_EQ(fuse_statis.at("seqpool_concat_fuse"), 2); + LOG(INFO) << "num_ops: " << num_ops; - EXPECT_EQ(num_ops, 349); + EXPECT_EQ(num_ops, 195); } } // namespace analysis -- GitLab