From 478a4e850e6e4287495b4e3cf1ff5e8252ff557c Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Mon, 10 Sep 2018 13:26:07 +0800 Subject: [PATCH] refactor ir pattern (#13304) --- paddle/fluid/framework/ir/fc_fuse_pass.cc | 33 ++-- paddle/fluid/framework/ir/fc_gru_fuse_pass.cc | 106 +++++------- .../fluid/framework/ir/fc_lstm_fuse_pass.cc | 152 +++++++----------- .../framework/ir/graph_pattern_detector.cc | 123 +++++++------- .../framework/ir/graph_pattern_detector.h | 134 ++++++++++++++- .../framework/ir/seq_concat_fc_fuse_pass.cc | 6 + .../inference/analysis/analyzer_tester.cc | 2 + 7 files changed, 316 insertions(+), 240 deletions(-) diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc index 5a4ebd6f3..ca704c7f5 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc @@ -29,39 +29,27 @@ std::unique_ptr FCFusePass::ApplyImpl( std::unordered_set nodes2delete; GraphPatternDetector gpd; - // BuildFCPattern(gpd.mutable_pattern()); auto* x = gpd.mutable_pattern() ->NewNode("fc_fuse/x") ->AsInput() ->assert_is_op_input("mul", "X"); - patterns::FC(gpd.mutable_pattern(), "fc_fuse", x, true /*with bias*/); - -#define GET_NODE(id) \ - PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetrieveNode("fc_fuse/" #id)), \ - "pattern has no Node called %s", #id); \ - auto* id = subgraph.at(gpd.pattern().RetrieveNode("fc_fuse/" #id)); \ - PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", "fc_fuse/" #id); + patterns::FC fc_pattern(gpd.mutable_pattern(), "fc_fuse"); + fc_pattern(x, true /*with bias*/); int found_fc_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { VLOG(4) << "handle FC fuse"; - // Currently, there is no FC op available, so I will just simulate the - // scenerio. - // FC's fusion is simple, just op fuse, no need to process the - // parameters. - GET_NODE(x); // x - GET_NODE(w); // Y - GET_NODE(fc_bias); // bias - GET_NODE(fc_out); // Out - GET_NODE(mul); // MUL op - GET_NODE(elementwise_add); // ELEMENT_ADD op - GET_NODE(mul_out); // tmp -#undef GET_NODE + GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern); // Create an FC Node. OpDesc desc; - std::string fc_x_in = x->Name(); + std::string fc_x_in = subgraph.at(x)->Name(); std::string fc_Y_in = w->Name(); std::string fc_bias_in = fc_bias->Name(); std::string fc_out_out = fc_out->Name(); @@ -73,7 +61,8 @@ std::unique_ptr FCFusePass::ApplyImpl( auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied. GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out}); - IR_NODE_LINK_TO(x, fc_node); + PADDLE_ENFORCE(subgraph.count(x)); + IR_NODE_LINK_TO(subgraph.at(x), fc_node); IR_NODE_LINK_TO(w, fc_node); IR_NODE_LINK_TO(fc_bias, fc_node); IR_NODE_LINK_TO(fc_node, fc_out); diff --git a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc index 90d8d5c04..a902b0b50 100644 --- a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc @@ -20,52 +20,43 @@ namespace paddle { namespace framework { namespace ir { -static void BuildPattern(PDPattern* pattern, const std::string& name_scope, - bool with_fc_bias) { - PDNode* x = pattern->NewNode(name_scope, "x") - ->assert_is_op_input("mul") - ->assert_var_not_persistable(); - auto* fc_out = patterns::FC(pattern, name_scope, x, with_fc_bias); - fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse. - patterns::GRU(pattern, name_scope, fc_out); - VLOG(3) << "fc_gru pattern \n" << pattern->DotString(); -} - static int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, bool with_fc_bias) { GraphPatternDetector gpd; auto* pattern = gpd.mutable_pattern(); - BuildPattern(pattern, name_scope, with_fc_bias); + // Create pattern. + patterns::FC fc_pattern(pattern, name_scope); + patterns::GRU gru_pattern(pattern, name_scope); + + PDNode* x = + pattern->NewNode(patterns::UniqueKey("x"))->assert_var_not_persistable(); + + auto* fc_out = fc_pattern(x, with_fc_bias); + fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse. + gru_pattern(fc_out); // Create New OpDesc - auto gru_creater = [&](int gru, int x, int weight_x, int weight_h, int bias, - int hidden, int fc_bias) { -#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x); - GET_NODE(x); - GET_NODE(weight_x); - GET_NODE(weight_h); - GET_NODE(bias); - GET_NODE(hidden); - GET_NODE(gru); + auto gru_creater = [&](Node* gru, Node* x, Node* weight_x, Node* weight_h, + Node* bias, Node* hidden, Node* fc_bias) { OpDesc op_desc; op_desc.SetType("fusion_gru"); #define NEW_NAME(x) name_scope + "/at." #x ".new" -#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__##_n->Name()}); +#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()}); SET_IN(X, x); SET_IN(WeightX, weight_x); SET_IN(WeightH, weight_h); if (with_fc_bias) { - op_desc.SetInput("Bias", {NEW_NAME(bias) + bias_n->Name()}); + op_desc.SetInput("Bias", {NEW_NAME(bias) + bias->Name()}); } else { SET_IN(Bias, bias); } #undef SET_IN op_desc.SetInput("H0", {}); - op_desc.SetOutput("Hidden", {hidden_n->Name()}); - op_desc.SetAttr("is_reverse", gru_n->Op()->GetAttr("is_reverse")); + op_desc.SetOutput("Hidden", {hidden->Name()}); + op_desc.SetAttr("is_reverse", gru->Op()->GetAttr("is_reverse")); // TODO(TJ): This should be a option for infer op_desc.SetAttr("use_seq", true); @@ -82,14 +73,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, PADDLE_ENFORCE(scope); if (with_fc_bias) { // Fusion GRU bias = fcbias + grubias - auto* fusion_bias_var = scope->Var(NEW_NAME(bias) + bias_n->Name()); + auto* fusion_bias_var = scope->Var(NEW_NAME(bias) + bias->Name()); auto* out_bias_tensor = fusion_bias_var->GetMutable(); PADDLE_ENFORCE(fusion_bias_var); - GET_NODE(fc_bias); - PADDLE_ENFORCE(fc_bias_n); - auto* gru_bias_var = scope->FindVar(bias_n->Name()); - auto* fc_bias_var = scope->FindVar(fc_bias_n->Name()); + auto* gru_bias_var = scope->FindVar(bias->Name()); + auto* fc_bias_var = scope->FindVar(fc_bias->Name()); PADDLE_ENFORCE(gru_bias_var); PADDLE_ENFORCE(fc_bias_var); const auto& gru_bias_tenosr = gru_bias_var->Get(); @@ -113,11 +102,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, #undef NEW_NAME #undef NEW_IMTERMEDIATE_OUT - IR_NODE_LINK_TO(x_n, op); - IR_NODE_LINK_TO(weight_x_n, op); - IR_NODE_LINK_TO(weight_h_n, op); - IR_NODE_LINK_TO(bias_n, op); // actually should link to new bias if have - IR_NODE_LINK_TO(op, hidden_n); + IR_NODE_LINK_TO(x, op); + IR_NODE_LINK_TO(weight_x, op); + IR_NODE_LINK_TO(weight_h, op); + IR_NODE_LINK_TO(bias, op); // actually should link to new bias if have + IR_NODE_LINK_TO(op, hidden); // h0? return op; }; @@ -125,42 +114,35 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, int fusion_count{0}; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { -#define GET_NODE(name__) \ - std::string name__##key = name_scope + "/" + #name__; \ - auto* name__##n = pattern->RetrieveNode(name__##key); \ - PADDLE_ENFORCE(name__##n); \ - PADDLE_ENFORCE(subgraph.count(name__##n)); \ - Node* name__##_n = subgraph.at(name__##n); \ - int name__ __attribute__((unused)) = name__##_n->id(); - - GET_NODE(x); - GET_NODE(w); // fc weight - GET_NODE(mul); - GET_NODE(fc_out); - GET_NODE(Weight); - GET_NODE(gru); - GET_NODE(Bias); - GET_NODE(Hidden); + auto* x_n = subgraph.at(x); + GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, gru_pattern); + GET_IR_NODE_FROM_SUBGRAPH(gru, gru, gru_pattern); + GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, gru_pattern); + GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, gru_pattern); // nodes need be removed - GET_NODE(BatchGate); - GET_NODE(BatchResetHiddenPrev); - GET_NODE(BatchHidden); + GET_IR_NODE_FROM_SUBGRAPH(BatchGate, BatchGate, gru_pattern); + GET_IR_NODE_FROM_SUBGRAPH(BatchResetHiddenPrev, BatchGate, gru_pattern); + GET_IR_NODE_FROM_SUBGRAPH(BatchHidden, BatchGate, gru_pattern); if (with_fc_bias) { - GET_NODE(mul_out); - GET_NODE(fc_bias); - GET_NODE(elementwise_add); - gru_creater(gru, x, w, Weight, Bias, Hidden, fc_bias); + GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern); + + gru_creater(gru, x_n, w, Weight, Bias, Hidden, fc_bias); // Remove unneeded nodes. std::unordered_set marked_nodes( - {mul_n, gru_n, elementwise_add_n, fc_bias_n, fc_out_n, mul_out_n, - BatchGate_n, BatchResetHiddenPrev_n, BatchHidden_n}); + {mul, gru, elementwise_add, fc_bias, fc_out, mul_out, BatchGate, + BatchResetHiddenPrev, BatchHidden}); GraphSafeRemoveNodes(graph, marked_nodes); } else { - gru_creater(gru, x, w, Weight, Bias, Hidden, -1); + gru_creater(gru, x_n, w, Weight, Bias, Hidden, nullptr); // Remove unneeded nodes. std::unordered_set marked_nodes( - {mul_n, gru_n, BatchGate_n, BatchResetHiddenPrev_n, BatchHidden_n}); + {mul, gru, BatchGate, BatchResetHiddenPrev, BatchHidden}); GraphSafeRemoveNodes(graph, marked_nodes); } #undef GET_NODE diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc index 3e0961369..f7fda8735 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc @@ -20,45 +20,29 @@ namespace paddle { namespace framework { namespace ir { -static std::string GenNodeName(const std::string& prefix, - const std::string& name) { - return prefix + "/" + name; -} +int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, + bool with_fc_bias) { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); -static void BuildPattern(PDPattern* pattern, const std::string& name_scope, - bool with_fc_bias) { - PDNode* x = pattern->NewNode(name_scope, "x") + // Build pattern + PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "x")) ->assert_is_op_input("mul") ->assert_var_not_persistable(); - auto* fc_out = patterns::FC(pattern, name_scope, x, with_fc_bias); - fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse. - patterns::LSTM(pattern, name_scope, fc_out); - // LOG(INFO) << "\n" << pattern->DotString(); -} - -static int BuildFusion(Graph* graph, const std::string& name_scope, - Scope* scope, bool with_fc_bias) { - GraphPatternDetector gpd; - auto* pattern = gpd.mutable_pattern(); + patterns::FC fc_pattern(pattern, name_scope); - BuildPattern(pattern, name_scope, with_fc_bias); + // fc_out is a tmp var, will be removed after fuse, so marked as intermediate. + auto* fc_out = fc_pattern(x, with_fc_bias)->AsIntermediate(); + patterns::LSTM lstm_pattern(pattern, name_scope); + lstm_pattern(fc_out); // Create New OpDesc - auto lstm_creator = [&](int lstm, int input, int weight_x, int weight_h, - int bias, int hidden, int cell, int xx, int fc_bias) { -#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x); - GET_NODE(input); - GET_NODE(weight_x); - GET_NODE(weight_h); - GET_NODE(bias); - GET_NODE(hidden); - GET_NODE(cell); - GET_NODE(xx); - GET_NODE(lstm); - + auto lstm_creator = [&](Node* lstm, Node* input, Node* weight_x, + Node* weight_h, Node* bias, Node* hidden, Node* cell, + Node* xx, Node* fc_bias) { OpDesc op_desc; op_desc.SetType("fusion_lstm"); -#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__##_n->Name()}); +#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()}); SET_IN(X, input); SET_IN(WeightX, weight_x); SET_IN(WeightH, weight_h); @@ -71,13 +55,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, auto* bias_var = scope->Var(new_bias_var); PADDLE_ENFORCE(bias_var); auto* bias_tensor = bias_var->GetMutable(); - auto* lstm_bias_var = scope->FindVar(bias_n->Name()); + auto* lstm_bias_var = scope->FindVar(bias->Name()); PADDLE_ENFORCE(lstm_bias_var); const auto& lstm_bias_tensor = lstm_bias_var->Get(); bias_tensor->Resize(lstm_bias_tensor.dims()); - GET_NODE(fc_bias); - auto* fc_bias_var = scope->FindVar(fc_bias_n->Name()); + auto* fc_bias_var = scope->FindVar(fc_bias->Name()); const auto& fc_bias_tensor = fc_bias_var->Get(); auto* data = bias_tensor->mutable_data(platform::CPUPlace()); @@ -88,31 +71,36 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, } op_desc.SetInput("Bias", {new_bias_var}); } -#undef GET_NODE // Create temp variables. - scope->Var(name_scope + "/BatchedInput.new") - ->GetMutable(); - scope->Var(name_scope + "/BatchCellPreAct.new") - ->GetMutable(); - scope->Var(name_scope + "/BatchedGate.new") - ->GetMutable(); + const std::string BatchedInput = patterns::UniqueKey("BatchedInput"); + const std::string BatchedCellPreAct = + patterns::UniqueKey("BatchedCellPreAct"); + const std::string BatchedGate = patterns::UniqueKey("BatchedGate"); + + scope->Var(BatchedInput)->GetMutable(); + scope->Var(BatchedCellPreAct)->GetMutable(); + scope->Var(BatchedGate)->GetMutable(); op_desc.SetInput("H0", {}); op_desc.SetInput("C0", {}); - op_desc.SetOutput("Hidden", {hidden_n->Name()}); - op_desc.SetOutput("Cell", {cell_n->Name()}); - op_desc.SetOutput("XX", {xx_n->Name()}); - op_desc.SetOutput("BatchedGate", {name_scope + "/BatchedGate.new"}); - op_desc.SetOutput("BatchCellPreAct", {name_scope + "/BatchCellPreAct.new"}); - op_desc.SetOutput("BatchedInput", {name_scope + "/BatchedInput.new"}); - op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse")); - op_desc.SetAttr("use_peepholes", lstm_n->Op()->GetAttr("use_peepholes")); + op_desc.SetOutput("Hidden", {hidden->Name()}); + op_desc.SetOutput("Cell", {cell->Name()}); + op_desc.SetOutput("XX", {xx->Name()}); + op_desc.SetOutput("BatchedGate", {BatchedGate}); + op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct}); + op_desc.SetOutput("BatchedInput", {BatchedInput}); + op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse")); + op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes")); // TODO(TJ): get from attr op_desc.SetAttr("use_seq", true); -#define TMP_NAME(x) "at.new.tmp." #x -#define OP_SET_OUT(x) op_desc.SetOutput(#x, {TMP_NAME(x)}) + PADDLE_ENFORCE(graph->Has(kParamScopeAttr)); + auto* scope = graph->Get(kParamScopeAttr); +#define OP_SET_OUT(x) \ + const std::string x = patterns::UniqueKey(#x); \ + op_desc.SetOutput(#x, {x}); \ + scope->Var(x)->GetMutable() OP_SET_OUT(BatchedCell); OP_SET_OUT(BatchedHidden); OP_SET_OUT(ReorderedH0); @@ -120,22 +108,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, #undef OP_SET_OUT auto* op = graph->CreateOpNode(&op_desc); - PADDLE_ENFORCE(graph->Has(kParamScopeAttr)); - auto* scope = graph->Get(kParamScopeAttr); - -#define TMP_NEW(x) scope->Var(TMP_NAME(x))->GetMutable() - TMP_NEW(BatchedCell); - TMP_NEW(BatchedHidden); - TMP_NEW(ReorderedH0); - TMP_NEW(ReorderedC0); -#undef TMP_NEW -#undef TMP_NAME - - IR_NODE_LINK_TO(input_n, op); - IR_NODE_LINK_TO(weight_x_n, op); - IR_NODE_LINK_TO(weight_h_n, op); - IR_NODE_LINK_TO(bias_n, op); - IR_NODE_LINK_TO(op, hidden_n); + IR_NODE_LINK_TO(input, op); + IR_NODE_LINK_TO(weight_x, op); + IR_NODE_LINK_TO(weight_h, op); + IR_NODE_LINK_TO(bias, op); + IR_NODE_LINK_TO(op, hidden); return op; }; @@ -143,39 +120,32 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { -#define GET_NODE(name__) \ - std::string name__##key = name_scope + "/" + #name__; \ - auto* name__##n = pattern->RetrieveNode(name__##key); \ - PADDLE_ENFORCE(name__##n); \ - PADDLE_ENFORCE(subgraph.count(name__##n)); \ - Node* name__##_n = subgraph.at(name__##n); \ - int name__ __attribute__((unused)) = name__##_n->id(); - - GET_NODE(x); - GET_NODE(w); - GET_NODE(mul); - GET_NODE(fc_out); - GET_NODE(Weight); - GET_NODE(lstm); - GET_NODE(Bias); - GET_NODE(Hidden); - GET_NODE(Cell); + GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, lstm_pattern); + GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern); if (with_fc_bias) { - GET_NODE(fc_bias); - GET_NODE(elementwise_add); - lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, fc_bias); + GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern); + lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out, + fc_bias); // Remove unneeded nodes. std::unordered_set marked_nodes( - {mul_n, lstm_n, elementwise_add_n}); + {mul, lstm, elementwise_add}); GraphSafeRemoveNodes(graph, marked_nodes); } else { - lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, -1); + GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern); + lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out, + nullptr); // Remove unneeded nodes. - std::unordered_set marked_nodes({mul_n, lstm_n}); + std::unordered_set marked_nodes({mul, lstm}); GraphSafeRemoveNodes(graph, marked_nodes); } -#undef GET_NODE ++fusion_count; }; diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 5ca750951..fc7feca56 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -21,6 +21,7 @@ #include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/string/printf.h" namespace paddle { namespace framework { @@ -106,8 +107,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) { for (auto& pdnode : pattern_.nodes()) { if (!pdnodes2nodes_.count(pdnode.get())) { VLOG(4) << pdnode->name() << " can't find matched Node, early stop"; - - return false; + // return false; } } for (auto& item : pdnodes2nodes_) { @@ -517,87 +517,89 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) { return false; } -PDNode* patterns::FC(PDPattern* pattern, const std::string& name_scope, - PDNode* x, bool with_bias) { - // mul op - auto* mul_op = pattern->NewNode(name_scope, "mul")->assert_is_op("mul"); - auto* mul_weight_var = pattern->NewNode(name_scope, "w") - ->AsInput() - ->assert_is_persistable_var() - ->assert_is_op_input("mul", "Y"); - - PDNode* fc_out{nullptr}; - if (with_bias) { - PDNode* elementwise_add_op{nullptr}; - PDNode *mul_out_var{nullptr}, *bias{nullptr}; - elementwise_add_op = pattern->NewNode(name_scope, "elementwise_add") - ->assert_is_op("elementwise_add"); - // intermediate variable, will be removed in the IR after fuse. - mul_out_var = pattern->NewNode(name_scope, "mul_out") - ->AsIntermediate() - ->assert_is_only_output_of_op("mul") - ->assert_is_op_input("elementwise_add"); - // bias - bias = pattern->NewNode(name_scope, "fc_bias") - ->AsInput() - ->assert_is_op_input("elementwise_add"); - // output - fc_out = pattern->NewNode(name_scope, "fc_out") - ->AsOutput() - ->assert_is_op_output("elementwise_add"); - mul_op->LinksFrom({x, mul_weight_var}).LinksTo({mul_out_var}); - elementwise_add_op->LinksFrom({mul_out_var, bias}).LinksTo({fc_out}); - } else { - fc_out = pattern->NewNode(name_scope, "fc_out") - ->AsOutput() - ->assert_is_op_output("mul"); - mul_op->LinksFrom({mul_weight_var, x}).LinksTo({fc_out}); +PDNode* patterns::FC::operator()(paddle::framework::ir::PDNode* x, + bool with_bias) { + // Create shared nodes. + x->assert_is_op_input("mul", "X"); + auto* mul = pattern->NewNode(mul_repr())->assert_is_op("mul"); + + auto* mul_w_var = pattern->NewNode(w_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("mul", "Y"); + + auto* mul_out_var = + pattern->NewNode(mul_out_repr())->assert_is_op_output("mul"); + + if (!with_bias) { // not with bias + // Add links. + mul->LinksFrom({x, mul_w_var}).LinksTo({mul_out_var}); + return mul_out_var; + + } else { // with bias + mul_out_var->AsIntermediate()->assert_is_op_input("elementwise_add"); + // Create operators. + auto* elementwise_add = pattern->NewNode(elementwise_add_repr()) + ->assert_is_op("elementwise_add"); + // Create variables. + auto* bias = pattern->NewNode(bias_repr()) + ->assert_is_op_input("elementwise_add") + ->AsInput(); + + auto* fc_out = pattern->NewNode(Out_repr()) + ->AsOutput() + ->assert_is_op_output("elementwise_add"); + + mul->LinksFrom({mul_w_var, x}).LinksTo({mul_out_var}); + elementwise_add->LinksFrom({mul_out_var, bias}).LinksTo({fc_out}); + return fc_out; } - return fc_out; } -#define NEW_NODE(op__, arg__, io__) \ - auto* arg__ = pattern->NewNode(name_scope, #arg__) \ - ->assert_is_op_##io__(#op__, #arg__); - -PDNode* patterns::LSTM(PDPattern* pattern, const std::string& name_scope, - PDNode* x) { +PDNode* patterns::LSTM::operator()(PDNode* x) { x->assert_is_op_input("lstm", "Input"); - auto* lstm_op = pattern->NewNode(name_scope, "lstm")->assert_is_op("lstm"); + auto* lstm_op = pattern->NewNode(lstm_repr())->assert_is_op("lstm"); +#define NEW_NODE(arg__, io__) \ + auto* arg__ = \ + pattern->NewNode(arg__##_repr())->assert_is_op_##io__("lstm", #arg__); // Currently, the H0 and C0 are optional // TODO(Superjomn) upgrade the fuse framework to support optional. // NEW_NODE(H0, input); // NEW_NODE(C0, input); - NEW_NODE(lstm, Weight, input); - NEW_NODE(lstm, Bias, input); + NEW_NODE(Weight, input); + NEW_NODE(Bias, input); - NEW_NODE(lstm, Hidden, output); - NEW_NODE(lstm, Cell, output); - NEW_NODE(lstm, BatchGate, output); - NEW_NODE(lstm, BatchCellPreAct, output); + NEW_NODE(Hidden, output); + NEW_NODE(Cell, output); + NEW_NODE(BatchGate, output); + NEW_NODE(BatchCellPreAct, output); +#undef NEW_NODE lstm_op->LinksFrom({x, Weight, Bias}); lstm_op->LinksTo({Hidden, Cell, BatchGate, BatchCellPreAct}); return Hidden; } -PDNode* patterns::GRU(PDPattern* pattern, const std::string& name_scope, - PDNode* x) { +PDNode* patterns::GRU::operator()(PDNode* x) { x->assert_is_op_input("gru", "Input"); - auto* gru_op = pattern->NewNode(name_scope, "gru")->assert_is_op("gru"); + auto* gru_op = pattern->NewNode(gru_repr())->assert_is_op("gru"); +#define NEW_NODE(arg__, io__) \ + auto* arg__ = \ + pattern->NewNode(arg__##_repr())->assert_is_op_##io__("gru", #arg__); - NEW_NODE(gru, Weight, input); + NEW_NODE(Weight, input); // TODO(Superjomn): upgrade the fuse framework to support optional. // H0 and bias are optional - NEW_NODE(gru, Bias, input); // also optional + NEW_NODE(Bias, input); // also optional // NEW_NODE(H0, input); - NEW_NODE(gru, Hidden, output); + NEW_NODE(Hidden, output); // below are intermediate - NEW_NODE(gru, BatchGate, output); - NEW_NODE(gru, BatchResetHiddenPrev, output); - NEW_NODE(gru, BatchHidden, output); + NEW_NODE(BatchGate, output); + NEW_NODE(BatchResetHiddenPrev, output); + NEW_NODE(BatchHidden, output); +#undef NEW_NODE BatchGate->AsIntermediate(); BatchResetHiddenPrev->AsIntermediate(); @@ -607,7 +609,6 @@ PDNode* patterns::GRU(PDPattern* pattern, const std::string& name_scope, gru_op->LinksTo({Hidden, BatchGate, BatchResetHiddenPrev, BatchHidden}); return Hidden; } -#undef NEW_NODE } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 71e4c36d9..57482a07b 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -286,22 +286,148 @@ void GraphSafeRemoveNodes(Graph* graph, const std::unordered_set& nodes); // Some pre-defined patterns those can be reused in multiple passes. +// The related Fluid Layer or Op should be one pattern here for better reusage +// accross different fusion. namespace patterns { +struct KeyCounter { + static KeyCounter& Instance() { + static KeyCounter x; + return x; + } + + int IncCounter(const std::string& key) { return dic_[key]++; } + + private: + std::unordered_map dic_; +}; + +// Generate a unique PDNode's name with name_scope and id. +// The format is {name_scope}/{repr}/{id}/{name} +static std::string PDNodeName(const std::string& name_scope, + const std::string& repr, size_t id, + const std::string& name) { + return string::Sprintf("%s/%s/%d/%s", name_scope, repr, id, name); +} +// Generate a unique PDNode's name. +// The format is {name_scope}/{repr}/{id} +static std::string PDNodeName(const std::string& name_scope, + const std::string& repr) { + return string::Sprintf("%s/%s/%d", name_scope, repr, + KeyCounter::Instance().IncCounter(repr)); +} +// Generate a unique key. It can be used for a universally unique temporary +// name. +// The format is {repr}/{id} +static std::string UniqueKey(const std::string& repr) { + return string::Sprintf("%s/%d", repr, + KeyCounter::Instance().IncCounter(repr)); +} + +// Declare a PDNode in a pattern, will create two methods: +// std::string xxx_repr(); return this PDNode's string id. +// PDNode* xxx_n(); return the corresponding PDNode. +#define PATTERN_DECL_NODE(name__) \ + std::string name__##_repr() const { \ + return PDNodeName(name_scope_, repr_, id_, #name__); \ + } \ + PDNode* name__##_n() const { return pattern->RetrieveNode(name__##_repr()); } + +// Get an ir::Node* from the matched subgraph. +// var: variable. +// arg: the argument declared by PATTERN_DECL_NODE in a pattern definition. +// pat: the pattern object. +#define GET_IR_NODE_FROM_SUBGRAPH(var, arg, pat) \ + PADDLE_ENFORCE(subgraph.count(pat.arg##_n()), \ + "Node not found for PDNode %s", pat.arg##_repr()); \ + Node* var = subgraph.at(pat.arg##_n()); \ + PADDLE_ENFORCE(var, "node %s not exists in the sub-graph", #arg) + +// The base class of all the patterns. +struct PatternBase { + PatternBase(PDPattern* pattern, const std::string& name_scope, + const std::string& repr) + : pattern(pattern), + name_scope_(name_scope), + repr_(repr), + id_(KeyCounter::Instance().IncCounter(repr)) {} + + PDPattern* pattern; + + protected: + std::string name_scope_; + std::string repr_; + size_t id_; +}; + // FC with bias // op: mul + elementwise_add // named nodes: // mul, elementwise_add // w, mul_out, bias, fc_out -PDNode* FC(PDPattern* pattern, const std::string& name_scope, PDNode* x, - bool with_bias); +struct FC : public PatternBase { + FC(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "fc") {} + + PDNode* operator()(PDNode* x, bool with_bias); + + // declare operator node's name + PATTERN_DECL_NODE(fc); + PATTERN_DECL_NODE(mul); + PATTERN_DECL_NODE(elementwise_add); + // declare variable node's name + PATTERN_DECL_NODE(w); + PATTERN_DECL_NODE(mul_out); // (x,w) -> mul_out + PATTERN_DECL_NODE(bias); + PATTERN_DECL_NODE(Out); +}; + +struct LSTM : public PatternBase { + LSTM(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "lstm") {} -PDNode* LSTM(PDPattern* pattern, const std::string& name_scope, PDNode* x); + PDNode* operator()(PDNode* x); -PDNode* GRU(PDPattern* pattern, const std::string& name_scope, PDNode* x); + // Operators + PATTERN_DECL_NODE(lstm); + + // Inputs + PATTERN_DECL_NODE(Input); + PATTERN_DECL_NODE(H0); + PATTERN_DECL_NODE(C0); + PATTERN_DECL_NODE(Weight); + PATTERN_DECL_NODE(Bias); + + // Outputs + PATTERN_DECL_NODE(Hidden); + PATTERN_DECL_NODE(Cell); + PATTERN_DECL_NODE(BatchGate); + PATTERN_DECL_NODE(BatchCellPreAct); +}; + +struct GRU : public PatternBase { + GRU(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "lstm") {} + + PDNode* operator()(PDNode* x); + + // Operators + PATTERN_DECL_NODE(gru); + + // Inputs + PATTERN_DECL_NODE(Bias); + PATTERN_DECL_NODE(Weight); + + // Outputs + PATTERN_DECL_NODE(BatchGate); + PATTERN_DECL_NODE(BatchResetHiddenPrev); + PATTERN_DECL_NODE(BatchHidden); + PATTERN_DECL_NODE(Hidden); +}; } // namespace patterns +// Link two ir::Nodes from each other. #define IR_NODE_LINK_TO(a, b) \ a->outputs.push_back(b); \ b->inputs.push_back(a); diff --git a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc index e1a441d09..a7d5161c3 100644 --- a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc @@ -192,6 +192,8 @@ std::unique_ptr SeqConcatFcFusePass::ApplyImpl( auto* id = subgraph.at(pattern.RetrieveNode(#id)); \ PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id); + int fuse_count{0}; + detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { VLOG(4) << "get one concat pattern"; @@ -239,8 +241,12 @@ std::unique_ptr SeqConcatFcFusePass::ApplyImpl( marked_nodes.erase(sequence_expand1_in); marked_nodes.erase(fc_out); GraphSafeRemoveNodes(graph, marked_nodes); + + ++fuse_count; }); + AddStatis(fuse_count); + return graph; } diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc index a496ae41a..dc1b03b2d 100644 --- a/paddle/fluid/inference/analysis/analyzer_tester.cc +++ b/paddle/fluid/inference/analysis/analyzer_tester.cc @@ -267,6 +267,7 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir, PADDLE_ENFORCE(config.ir_mode == AnalysisConfig::IrPassMode::kExclude); // default config.ir_passes.clear(); // Do not exclude any pass. + int batch_size = FLAGS_batch_size; int num_times = FLAGS_repeat; @@ -346,6 +347,7 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir, ASSERT_TRUE(fuse_statis.count("fc_fuse")); EXPECT_EQ(fuse_statis.at("fc_fuse"), 1); EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 2); // bi-directional LSTM + EXPECT_EQ(fuse_statis.at("seq_concat_fc_fuse"), 1); EXPECT_EQ(num_ops, 13); // After graph optimization, only 13 operators exists. } -- GitLab