From 1db3658429ef3d57e67e41261a175ff3bacfd701 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Wed, 30 Jun 2021 17:16:01 +0800 Subject: [PATCH] [pass_enhance] mul_gru_fuse_pass; fc_gru_fuse_pass (#33793) --- paddle/fluid/framework/ir/fc_gru_fuse_pass.cc | 145 +++++++++++++++++- paddle/fluid/framework/ir/fc_gru_fuse_pass.h | 10 +- 2 files changed, 144 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc index 921e1ea5139..e1260f62ddb 100644 --- a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc @@ -30,8 +30,137 @@ namespace ir { class Node; -static int BuildFusion(Graph* graph, const std::string& name_scope, - Scope* scope, bool with_fc_bias) { +MulGRUFusePass::MulGRUFusePass() { + AddOpCompat(OpCompat("gru")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("H0") + .IsTensor() + .IsOptional() + .End() + .AddInput("Weight") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("BatchGate") + .IsTensor() + .End() + .AddOutput("BatchResetHiddenPrev") + .IsTensor() + .End() + .AddOutput("BatchHidden") + .IsTensor() + .End() + .AddOutput("Hidden") + .IsTensor() + .End() + .AddAttr("activation") + .IsStringIn({"sigmoid", "tanh", "relu", "identity"}) + .End() + .AddAttr("gate_activation") + .IsStringIn({"sigmoid", "tanh", "relu", "identity"}) + .End() + .AddAttr("is_reverse") + .IsType() + .End() + .AddAttr("origin_mode") + .IsType() + .IsOptional() + .End(); + AddOpCompat(OpCompat("mul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumEQ(1) + .End() + .AddAttr("y_num_col_dims") + .IsNumEQ(1) + .End(); +} + +FCGRUFusePass::FCGRUFusePass() { + AddOpCompat(OpCompat("gru")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("H0") + .IsTensor() + .IsOptional() + .End() + .AddInput("Weight") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("BatchGate") + .IsTensor() + .End() + .AddOutput("BatchResetHiddenPrev") + .IsTensor() + .End() + .AddOutput("BatchHidden") + .IsTensor() + .End() + .AddOutput("Hidden") + .IsTensor() + .End() + .AddAttr("activation") + .IsStringIn({"sigmoid", "tanh", "relu", "identity"}) + .End() + .AddAttr("gate_activation") + .IsStringIn({"sigmoid", "tanh", "relu", "identity"}) + .End() + .AddAttr("is_reverse") + .IsType() + .End() + .AddAttr("origin_mode") + .IsType() + .IsOptional() + .End(); + AddOpCompat(OpCompat("mul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumEQ(1) + .End() + .AddAttr("y_num_col_dims") + .IsNumEQ(1) + .End(); + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumGE(-1) + .End(); +} + +int FCGRUFusePass::BuildFusion(Graph* graph, const std::string& name_scope, + Scope* scope, bool with_fc_bias) const { GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); @@ -133,6 +262,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, int fusion_count{0}; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } auto* x_n = subgraph.at(x); GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern); @@ -189,8 +322,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, void MulGRUFusePass::ApplyImpl(ir::Graph* graph) const { FusePassBase::Init(name_scope_, graph); - int fusion_count = - BuildFusion(graph, name_scope_, param_scope(), false /*with_fc_bias*/); + int fusion_count = MulGRUFusePass::BuildFusion( + graph, name_scope_, param_scope(), false /*with_fc_bias*/); AddStatis(fusion_count); } @@ -198,8 +331,8 @@ void MulGRUFusePass::ApplyImpl(ir::Graph* graph) const { void FCGRUFusePass::ApplyImpl(ir::Graph* graph) const { FusePassBase::Init(name_scope_, graph); - int fusion_count = - BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/); + int fusion_count = FCGRUFusePass::BuildFusion( + graph, name_scope_, param_scope(), true /*with_fc_bias*/); AddStatis(fusion_count); } diff --git a/paddle/fluid/framework/ir/fc_gru_fuse_pass.h b/paddle/fluid/framework/ir/fc_gru_fuse_pass.h index 73f00504d34..421f3ef46d7 100644 --- a/paddle/fluid/framework/ir/fc_gru_fuse_pass.h +++ b/paddle/fluid/framework/ir/fc_gru_fuse_pass.h @@ -18,7 +18,6 @@ #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/graph_pattern_detector.h" namespace paddle { namespace framework { @@ -26,21 +25,22 @@ namespace ir { // The MulGRUFusePass and MulGRUFusePass will fuse to the same FusionGRU op. -class Graph; - class FCGRUFusePass : public FusePassBase { public: + FCGRUFusePass(); virtual ~FCGRUFusePass() {} protected: void ApplyImpl(ir::Graph* graph) const override; - const std::string name_scope_{"fc_gru_fuse"}; + int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, + bool with_fc_bias) const; }; // Just FC without bias -class MulGRUFusePass : public FusePassBase { +class MulGRUFusePass : public FCGRUFusePass { public: + MulGRUFusePass(); virtual ~MulGRUFusePass() {} protected: -- GitLab