diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
index 6bd956ef0d53c989106157b54770d10156a2cefc..35704f1f3309e1a91b18d7a2c30ee7dda3b57e51 100644
--- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
@@ -29,8 +29,149 @@ namespace ir {
 
 class Node;
 
-int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
-                bool with_fc_bias) {
+MulLstmFusePass::MulLstmFusePass() {
+  AddOpCompat(OpCompat("lstm"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("H0")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("C0")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("Weight")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .End()
+      .AddOutput("Hidden")
+      .IsTensor()
+      .End()
+      .AddOutput("Cell")
+      .IsTensor()
+      .End()
+      .AddOutput("BatchGate")
+      .IsTensor()
+      .End()
+      .AddOutput("BatchCellPreAct")
+      .IsTensor()
+      .End()
+      .AddAttr("use_peepholes")
+      .IsType<bool>()
+      .End()
+      .AddAttr("is_reverse")
+      .IsType<bool>()
+      .End()
+      .AddAttr("gate_activation")
+      .IsStringIn({"sigmoid", "tanh", "relu", "identity"})
+      .End()
+      .AddAttr("cell_activation")
+      .IsStringIn({"sigmoid", "tanh", "relu", "identity"})
+      .End()
+      .AddAttr("candidate_activation")
+      .IsStringIn({"sigmoid", "tanh", "relu", "identity"})
+      .End();
+  AddOpCompat(OpCompat("mul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("x_num_col_dims")
+      .IsNumEQ(1)
+      .End()
+      .AddAttr("y_num_col_dims")
+      .IsNumEQ(1)
+      .End();
+}
+
+FCLstmFusePass::FCLstmFusePass() {
+  AddOpCompat(OpCompat("lstm"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("H0")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("C0")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddInput("Weight")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .End()
+      .AddOutput("Hidden")
+      .IsTensor()
+      .End()
+      .AddOutput("Cell")
+      .IsTensor()
+      .End()
+      .AddOutput("BatchGate")
+      .IsTensor()
+      .End()
+      .AddOutput("BatchCellPreAct")
+      .IsTensor()
+      .End()
+      .AddAttr("use_peepholes")
+      .IsType<bool>()
+      .End()
+      .AddAttr("is_reverse")
+      .IsType<bool>()
+      .End()
+      .AddAttr("gate_activation")
+      .IsStringIn({"sigmoid", "tanh", "relu", "identity"})
+      .End()
+      .AddAttr("cell_activation")
+      .IsStringIn({"sigmoid", "tanh", "relu", "identity"})
+      .End()
+      .AddAttr("candidate_activation")
+      .IsStringIn({"sigmoid", "tanh", "relu", "identity"})
+      .End();
+  AddOpCompat(OpCompat("mul"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("x_num_col_dims")
+      .IsNumEQ(1)
+      .End()
+      .AddAttr("y_num_col_dims")
+      .IsNumEQ(1)
+      .End();
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumGE(-1)
+      .End();
+}
+
+int FCLstmFusePass::BuildFusion(Graph* graph, const std::string& name_scope,
+                                Scope* scope, bool with_fc_bias) const {
   GraphPatternDetector gpd;
   auto* pattern = gpd.mutable_pattern();
 
@@ -140,6 +281,10 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
 
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass in op compat failed.";
+      return;
+    }
     GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
index d37f53b15f06b72e67c234baec3a314f0f462735..60b4953c2ec0a8c225d74a604d74433f344b2424 100644
--- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
+++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h
@@ -31,16 +31,19 @@ class Graph;
 
 class FCLstmFusePass : public FusePassBase {
  public:
+  FCLstmFusePass();
   virtual ~FCLstmFusePass() {}
 
  protected:
   void ApplyImpl(ir::Graph* graph) const override;
-
+  int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
+                  bool with_fc_bias) const;
   const std::string name_scope_{"fc_lstm_fuse"};
 };
 
-class MulLstmFusePass : public FusePassBase {
+class MulLstmFusePass : public FCLstmFusePass {
  public:
+  MulLstmFusePass();
   virtual ~MulLstmFusePass() {}
 
  protected: