From 7e964579d741572dcfb4759a6bcd779a47c29efd Mon Sep 17 00:00:00 2001
From: Wangzheee <634486483@qq.com>
Date: Wed, 30 Jun 2021 17:16:19 +0800
Subject: [PATCH] [pass_enhance] fc_lstm_fuse_pass; mul_lstm_fuse_pass (#33811)

---
 .../fluid/framework/ir/fc_lstm_fuse_pass.cc   | 149 +++++++++++++++++-
 paddle/fluid/framework/ir/fc_lstm_fuse_pass.h |   7 +-
 2 files changed, 152 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
index 6bd956ef0d..35704f1f33 100644
--- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
@@ -29,8 +29,149 @@ namespace ir {
 
 class Node;
 
-int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
-                bool with_fc_bias) {
+MulLstmFusePass::MulLstmFusePass() {
+  AddOpCompat(OpCompat("lstm"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("H0")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("C0")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("Weight")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .End()
+      .AddOutput("Hidden")
+      .IsTensor()
+      .End()
+      .AddOutput("Cell")
+      .IsTensor()
+      .End()
+      .AddOutput("BatchGate")
+      .IsTensor()
+      .End()
+      .AddOutput("BatchCellPreAct")
+      .IsTensor()
+      .End()
+      .AddAttr("use_peepholes")
+      .IsType<bool>()
+      .End()
+      .AddAttr("is_reverse")
+      .IsType<bool>()
+      .End()
+      .AddAttr("gate_activation")
+      .IsStringIn({"sigmoid", "tanh", "relu", "identity"})
+      .End()
+      .AddAttr("cell_activation")
+      .IsStringIn({"sigmoid", "tanh", "relu", "identity"})
+      .End()
+      .AddAttr("candidate_activation")
+      .IsStringIn({"sigmoid", "tanh", "relu", "identity"})
+      .End();
+  AddOpCompat(OpCompat("mul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("x_num_col_dims")
+      .IsNumEQ(1)
+      .End()
+      .AddAttr("y_num_col_dims")
+      .IsNumEQ(1)
+      .End();
+}
+
+FCLstmFusePass::FCLstmFusePass() {
+  AddOpCompat(OpCompat("lstm"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("H0")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("C0")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("Weight")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .End()
+      .AddOutput("Hidden")
+      .IsTensor()
+      .End()
+      .AddOutput("Cell")
+      .IsTensor()
+      .End()
+      .AddOutput("BatchGate")
+      .IsTensor()
+      .End()
+      .AddOutput("BatchCellPreAct")
+      .IsTensor()
+      .End()
+      .AddAttr("use_peepholes")
+      .IsType<bool>()
+      .End()
+      .AddAttr("is_reverse")
+      .IsType<bool>()
+      .End()
+      .AddAttr("gate_activation")
+      .IsStringIn({"sigmoid", "tanh", "relu", "identity"})
+      .End()
+      .AddAttr("cell_activation")
+      .IsStringIn({"sigmoid", "tanh", "relu", "identity"})
+      .End()
+      .AddAttr("candidate_activation")
+      .IsStringIn({"sigmoid", "tanh", "relu", "identity"})
+      .End();
+  AddOpCompat(OpCompat("mul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("x_num_col_dims")
+      .IsNumEQ(1)
+      .End()
+      .AddAttr("y_num_col_dims")
+      .IsNumEQ(1)
+      .End();
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumGE(-1)
+      .End();
+}
+
+int FCLstmFusePass::BuildFusion(Graph* graph, const std::string& name_scope,
+                                Scope* scope, bool with_fc_bias) const {
   GraphPatternDetector gpd;
   auto* pattern = gpd.mutable_pattern();
 
@@ -140,6 +281,10 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
 
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass in op compat failed.";
+      return;
+    }
     GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
index d37f53b15f..60b4953c2e 100644
--- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
+++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
@@ -31,16 +31,19 @@ class Graph;
 
 class FCLstmFusePass : public FusePassBase {
  public:
+  FCLstmFusePass();
   virtual ~FCLstmFusePass() {}
 
  protected:
   void ApplyImpl(ir::Graph* graph) const override;
-
+  int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
+                  bool with_fc_bias) const;
   const std::string name_scope_{"fc_lstm_fuse"};
 };
 
-class MulLstmFusePass : public FusePassBase {
+class MulLstmFusePass : public FCLstmFusePass {
  public:
+  MulLstmFusePass();
   virtual ~MulLstmFusePass() {}
 
  protected:
-- 
GitLab
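
Note (not part of the patch): the constructors added above register per-op input/output/attribute constraints, and the new IsCompat(subgraph, g) guard skips fusion for any matched subgraph that violates them. The standalone C++ sketch below illustrates that guard pattern only; every class, function, and variable name in it (OpCompatSketch, AddAttr, IsCompat, the activation whitelist) is illustrative and is not Paddle's actual API.

// Minimal, self-contained sketch of an "op compat" guard for a fuse pass.
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Records per-attribute predicates through a chained builder interface,
// loosely mirroring the AddOpCompat(...).AddAttr(...).End() chains above.
class OpCompatSketch {
 public:
  explicit OpCompatSketch(std::string op_type) : op_type_(std::move(op_type)) {}

  // Require that an attribute exists and satisfies the given predicate.
  OpCompatSketch& AddAttr(const std::string& name,
                          std::function<bool(const std::string&)> pred) {
    attr_checks_[name] = std::move(pred);
    return *this;  // allow chaining
  }

  // Validate a concrete op description (attribute name -> value as string).
  bool IsCompat(const std::map<std::string, std::string>& attrs) const {
    for (const auto& check : attr_checks_) {
      auto it = attrs.find(check.first);
      if (it == attrs.end() || !check.second(it->second)) {
        // A real pass would log a warning and skip fusing this subgraph.
        std::cerr << op_type_ << ": attribute '" << check.first
                  << "' failed compat check\n";
        return false;
      }
    }
    return true;
  }

 private:
  std::string op_type_;
  std::map<std::string, std::function<bool(const std::string&)>> attr_checks_;
};

int main() {
  // Mirrors the "activation must be one of ..." style of constraint.
  auto is_known_activation = [](const std::string& v) {
    const std::vector<std::string> allowed{"sigmoid", "tanh", "relu",
                                           "identity"};
    for (const auto& a : allowed) {
      if (v == a) return true;
    }
    return false;
  };

  OpCompatSketch lstm_compat("lstm");
  lstm_compat.AddAttr("gate_activation", is_known_activation)
      .AddAttr("cell_activation", is_known_activation);

  // An op whose attributes violate the constraint is reported incompatible,
  // so the (hypothetical) fusion handler would return early without fusing.
  std::map<std::string, std::string> bad_op{{"gate_activation", "softmax"},
                                            {"cell_activation", "tanh"}};
  std::map<std::string, std::string> good_op{{"gate_activation", "sigmoid"},
                                             {"cell_activation", "tanh"}};

  std::cout << "bad op compatible:  " << lstm_compat.IsCompat(bad_op) << "\n";
  std::cout << "good op compatible: " << lstm_compat.IsCompat(good_op) << "\n";
  return 0;
}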