From 4d774953c6cb584f084129746b4d2aea0e59237a Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 6 Sep 2018 11:53:25 +0800 Subject: [PATCH] enable fc gru fuse pass --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + paddle/fluid/framework/ir/fc_gru_fuse_pass.cc | 18 ++++++------- .../fluid/framework/ir/fc_lstm_fuse_pass.cc | 11 ++++---- paddle/fluid/inference/analysis/analyzer.h | 4 +++ .../inference/analysis/analyzer_lac_tester.cc | 25 +++++++++++++++++++ paddle/fluid/inference/api/CMakeLists.txt | 1 + 6 files changed, 44 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index f5235f70ad7..6c7f972589b 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -24,6 +24,7 @@ pass_library(fc_fuse_pass) pass_library(attention_lstm_fuse_pass) pass_library(infer_clean_graph_pass) pass_library(fc_lstm_fuse_pass) +pass_library(fc_gru_fuse_pass) pass_library(seq_concat_fc_fuse_pass) set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library") diff --git a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc index 1e7b49620c0..4a08beee7d0 100644 --- a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc @@ -20,12 +20,8 @@ namespace paddle { namespace framework { namespace ir { -std::string GenNodeName(const std::string& prefix, const std::string& name) { - return prefix + "/" + name; -} - -void BuildPattern(PDPattern* pattern, const std::string& name_scope, - bool with_fc_bias) { +static void BuildPattern(PDPattern* pattern, const std::string& name_scope, + bool with_fc_bias) { PDNode* x = pattern->NewNode(name_scope, "x") ->assert_is_op_input("mul") ->assert_var_not_persistable(); @@ -35,8 +31,8 @@ void BuildPattern(PDPattern* pattern, const std::string& name_scope, VLOG(3) << "\n" << pattern->DotString(); } -int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, - bool with_fc_bias) { +static int BuildFusion(Graph* graph, const std::string& name_scope, + Scope* scope, bool with_fc_bias) { GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); @@ -108,7 +104,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, auto* op = graph->CreateOpNode(&op_desc); PADDLE_ENFORCE(graph->Has(kParamScopeAttr)); - auto* scope = graph->Get(kParamScopeAttr); + // auto* scope = graph->Get(kParamScopeAttr); IR_NODE_LINK_TO(x_n, op); IR_NODE_LINK_TO(weight_x_n, op); @@ -189,5 +185,5 @@ std::unique_ptr FCGRUFusePass::ApplyImpl( } // namespace framework } // namespace paddle -REGISTER_PASS(mul_lstm_fuse_pass, paddle::framework::ir::MulGRUFusePass); -REGISTER_PASS(fc_lstm_fuse_pass, paddle::framework::ir::FCGRUFusePass); +REGISTER_PASS(mul_gru_fuse_pass, paddle::framework::ir::MulGRUFusePass); +REGISTER_PASS(fc_gru_fuse_pass, paddle::framework::ir::FCGRUFusePass); diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc index 0d69dfa79aa..5fa3fcb9dc9 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc @@ -19,12 +19,13 @@ namespace paddle { namespace framework { namespace ir { -std::string GenNodeName(const std::string& prefix, const std::string& name) { +static std::string GenNodeName(const std::string& prefix, + const std::string& name) { return prefix + "/" + name; } -void BuildPattern(PDPattern* pattern, const std::string& name_scope, - bool with_fc_bias) { +static void BuildPattern(PDPattern* pattern, const std::string& name_scope, + bool with_fc_bias) { PDNode* x = pattern->NewNode(name_scope, "x") ->assert_is_op_input("mul") ->assert_var_not_persistable(); @@ -34,8 +35,8 @@ void BuildPattern(PDPattern* pattern, const std::string& name_scope, // LOG(INFO) << "\n" << pattern->DotString(); } -int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, - bool with_fc_bias) { +static int BuildFusion(Graph* graph, const std::string& name_scope, + Scope* scope, bool with_fc_bias) { GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); diff --git a/paddle/fluid/inference/analysis/analyzer.h b/paddle/fluid/inference/analysis/analyzer.h index 3fdd2b9ec75..7800fc90b18 100644 --- a/paddle/fluid/inference/analysis/analyzer.h +++ b/paddle/fluid/inference/analysis/analyzer.h @@ -36,6 +36,8 @@ limitations under the License. */ */ #include +#include +#include #include "paddle/fluid/inference/analysis/flags.h" #include "paddle/fluid/inference/analysis/pass.h" #include "paddle/fluid/inference/analysis/pass_manager.h" @@ -66,6 +68,8 @@ class Analyzer : public OrderedRegistry { "attention_lstm_fuse_pass", // "fc_lstm_fuse_pass", // "mul_lstm_fuse_pass", // + "fc_gru_fuse_pass", // + "mul_gru_fuse_pass", // "seq_concat_fc_fuse_pass", // "fc_fuse_pass", // }}; diff --git a/paddle/fluid/inference/analysis/analyzer_lac_tester.cc b/paddle/fluid/inference/analysis/analyzer_lac_tester.cc index 5efee950309..a6e8351c4f9 100644 --- a/paddle/fluid/inference/analysis/analyzer_lac_tester.cc +++ b/paddle/fluid/inference/analysis/analyzer_lac_tester.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/inference/analysis/analyzer.h" #include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/inference/analysis/ut_helper.h" #include "paddle/fluid/inference/api/analysis_predictor.h" @@ -237,6 +238,30 @@ void TestLACPrediction(const std::string &model_path, for (size_t i = 0; i < size; ++i) { EXPECT_EQ(pdata_ref[i], pdata[i]); } + + AnalysisPredictor *analysis_predictor = + dynamic_cast(predictor.get()); + auto &fuse_statis = analysis_predictor->analysis_argument() + .Get>( + framework::ir::kFuseStatisAttr); + for (auto &item : fuse_statis) { + LOG(INFO) << "fused " << item.first << " " << item.second; + } + int num_ops = 0; + for (auto &node : + analysis_predictor->analysis_argument().main_dfg->nodes.nodes()) { + if (node->IsFunction()) { + ++num_ops; + } + } + LOG(INFO) << "has num ops: " << num_ops; + ASSERT_TRUE(fuse_statis.count("fc_fuse")); + ASSERT_TRUE(fuse_statis.count("fc_gru_fuse")); + LOG(INFO) << "fc fuse num:" << fuse_statis.at("fc_fuse"); + LOG(INFO) << "fc gru fuse num:" << fuse_statis.at("fc_gru_fuse"); + + // ASSERT_TRUE(fuse_statis.count("fc_gru_fuse")); + // LOG(INFO) << fuse_statis.at("fc_gru_fuse"); } } diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index e976b9397d6..330ea044959 100644 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -50,6 +50,7 @@ cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_a pass fc_fuse_pass fc_lstm_fuse_pass + fc_gru_fuse_pass seq_concat_fc_fuse_pass graph_viz_pass infer_clean_graph_pass -- GitLab