From 6945a80b55da2f81dc997ea5f3adadd1c38616bd Mon Sep 17 00:00:00 2001 From: Wilber Date: Mon, 17 Feb 2020 10:00:34 +0800 Subject: [PATCH] cherry-pick 22551. test=develop test=release/1.7 (#22609) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [cherry-pick] #22551 当一个模型中有多个fc_lstm子图的时候,且其中fc共用了同一个persistable的bias,此时不应该将bias节点删除,只将非persistable的节点去除即可。 --- paddle/fluid/framework/ir/CMakeLists.txt | 2 + paddle/fluid/framework/ir/fc_gru_fuse_pass.cc | 7 +- .../framework/ir/fc_gru_fuse_pass_tester.cc | 116 ++++++++++++++++++ .../fluid/framework/ir/fc_lstm_fuse_pass.cc | 10 +- .../framework/ir/fc_lstm_fuse_pass_tester.cc | 116 ++++++++++++++++++ .../fluid/framework/ir/pass_tester_helper.h | 56 +++++++++ 6 files changed, 301 insertions(+), 6 deletions(-) create mode 100644 paddle/fluid/framework/ir/fc_gru_fuse_pass_tester.cc create mode 100644 paddle/fluid/framework/ir/fc_lstm_fuse_pass_tester.cc diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index fe2e1ab30d..13800b5007 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -116,6 +116,8 @@ cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_r cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass) cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector) cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto) +cc_test(test_fc_lstm_fuse_pass SRCS fc_lstm_fuse_pass_tester.cc DEPS fc_lstm_fuse_pass framework_proto) +cc_test(test_fc_gru_fuse_pass SRCS fc_gru_fuse_pass_tester.cc DEPS fc_gru_fuse_pass framework_proto) cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto) cc_test(test_seqpool_cvm_concat_fuse_pass SRCS seqpool_cvm_concat_fuse_pass_tester.cc DEPS 
seqpool_cvm_concat_fuse_pass framework_proto) cc_test(test_repeated_fc_relu_fuse_pass SRCS repeated_fc_relu_fuse_pass_tester.cc DEPS repeated_fc_relu_fuse_pass framework_proto) diff --git a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc index 283c544889..d26998e6fc 100644 --- a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc @@ -127,8 +127,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, gru_pattern); // nodes need be removed GET_IR_NODE_FROM_SUBGRAPH(BatchGate, BatchGate, gru_pattern); - GET_IR_NODE_FROM_SUBGRAPH(BatchResetHiddenPrev, BatchGate, gru_pattern); - GET_IR_NODE_FROM_SUBGRAPH(BatchHidden, BatchGate, gru_pattern); + GET_IR_NODE_FROM_SUBGRAPH(BatchResetHiddenPrev, BatchResetHiddenPrev, + gru_pattern); + GET_IR_NODE_FROM_SUBGRAPH(BatchHidden, BatchHidden, gru_pattern); if (with_fc_bias) { GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern); @@ -138,7 +139,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, gru_creater(gru, x_n, w, Weight, Bias, Hidden, fc_bias); // Remove unneeded nodes. std::unordered_set marked_nodes( - {mul, gru, elementwise_add, fc_bias, fc_out, mul_out, BatchGate, + {mul, gru, elementwise_add, fc_out, mul_out, BatchGate, BatchResetHiddenPrev, BatchHidden}); GraphSafeRemoveNodes(graph, marked_nodes); } else { diff --git a/paddle/fluid/framework/ir/fc_gru_fuse_pass_tester.cc b/paddle/fluid/framework/ir/fc_gru_fuse_pass_tester.cc new file mode 100644 index 0000000000..70351b8aaf --- /dev/null +++ b/paddle/fluid/framework/ir/fc_gru_fuse_pass_tester.cc @@ -0,0 +1,116 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/fc_gru_fuse_pass.h"
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+void AddVarToScope(Scope* param_scope, const std::string& name,
+                   const DDim& dims) {
+  auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
+  tensor->Resize(dims);
+  tensor->mutable_data<float>(platform::CPUPlace());
+}
+
+Scope* CreateParamScope() {
+  auto param_scope = new Scope();
+  AddVarToScope(param_scope, "gru_fc_w", {});
+  AddVarToScope(param_scope, "gru_fc_b", {});
+  AddVarToScope(param_scope, "gru_w", {});
+  AddVarToScope(param_scope, "gru_b", {});
+  AddVarToScope(param_scope, "gru_batch_gate_0", {});
+  AddVarToScope(param_scope, "gru_batch_reset_hidden_prev_0", {});
+  AddVarToScope(param_scope, "gru_batch_hidden_0", {});
+  AddVarToScope(param_scope, "gru_hidden_0", {});
+  AddVarToScope(param_scope, "gru_batch_gate_1", {});
+  AddVarToScope(param_scope, "gru_batch_reset_hidden_prev_1", {});
+  AddVarToScope(param_scope, "gru_batch_hidden_1", {});
+  AddVarToScope(param_scope, "gru_hidden_1", {});
+  return param_scope;
+}
+
+TEST(FCGRUFusePass, basic) {
+  // inputs                        operator            output
+  // --------------------------------------------------------
+  // (a, gru_fc_w)                 mul             ->  fc_0_tmp_0
+  // (fc_0_tmp_0, gru_fc_b)        elementwise_add ->  fc_0_tmp_1
+  // (fc_0_tmp_1,gru_w,gru_b)      gru             ->  gru_out_0
+
+  // (b, gru_fc_w)                 mul             ->  fc_1_tmp_0
+  // (fc_1_tmp_0, gru_fc_b)        elementwise_add ->  fc_1_tmp_1
+  // (fc_1_tmp_1,gru_w,gru_b)      gru             ->  gru_out_1
+  Layers layers;
+  auto* a = layers.data("a");
+  auto* b = layers.data("b");
+  auto* fc_w = layers.data("gru_fc_w", {}, true);
+  auto* fc_b = layers.data("gru_fc_b", {}, true);
+  auto* gru_w = layers.data("gru_w", {}, true);
+  auto* gru_b = layers.data("gru_b", {}, true);
+  auto* fc_0_tmp0 = layers.mul(a, fc_w);
+  auto* fc_0_tmp1 = layers.elementwise_add(fc_0_tmp0, fc_b);
+  auto* gru_batch_gate_0 = layers.data("gru_batch_gate_0", {}, false);
+  auto* gru_batch_reset_hidden_prev_0 =
+      layers.data("gru_batch_reset_hidden_prev_0", {}, false);
+  auto* gru_batch_hidden_0 = layers.data("gru_batch_hidden_0", {}, false);
+  auto* gru_hidden_0 = layers.data("gru_hidden_0", {}, false);
+  layers.gru(fc_0_tmp1, gru_w, gru_b, gru_batch_gate_0,
+             gru_batch_reset_hidden_prev_0, gru_batch_hidden_0, gru_hidden_0);
+
+  auto* fc_1_tmp0 = layers.mul(b, fc_w);
+  auto* fc_1_tmp1 = layers.elementwise_add(fc_1_tmp0, fc_b);
+  auto* gru_batch_gate_1 = layers.data("gru_batch_gate_1", {}, false);
+  auto* gru_batch_reset_hidden_prev_1 =
+      layers.data("gru_batch_reset_hidden_prev_1", {}, false);
+  auto* gru_batch_hidden_1 = layers.data("gru_batch_hidden_1", {}, false);
+  auto* gru_hidden_1 = layers.data("gru_hidden_1", {}, false);
+  layers.gru(fc_1_tmp1, gru_w, gru_b, gru_batch_gate_1,
+             gru_batch_reset_hidden_prev_1, gru_batch_hidden_1, gru_hidden_1);
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+  auto pass = PassRegistry::Instance().Get("fc_gru_fuse_pass");
+  pass->Set("use_gpu", new bool(true));
+  graph->Set("__param_scope__", CreateParamScope());
+  int num_nodes_before = graph->Nodes().size();
+  int num_gru_nodes_before = GetNumOpNodes(graph, "gru");
+  VLOG(3) << DebugString(graph);
+
+  graph.reset(pass->Apply(graph.release()));
+  int num_nodes_after = graph->Nodes().size();
+  int num_fuse_gru_nodes_after = GetNumOpNodes(graph, "fusion_gru");
+  VLOG(3) << DebugString(graph);
+
+  PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 6,
+                    platform::errors::PreconditionNotMet(
+                        "The number of nodes before and after "
+                        "the fuse does not meet expectations"));
+  PADDLE_ENFORCE_EQ(
+      num_fuse_gru_nodes_after, 2,
+      platform::errors::PreconditionNotMet("The number of fusion_gru nodes "
+                                           "after the fuse should be 2"));
+  PADDLE_ENFORCE_EQ(num_gru_nodes_before, num_fuse_gru_nodes_after,
+                    platform::errors::PreconditionNotMet(
+                        "The number of gru nodes before the fuse does not "
+                        "match the number of fusion_gru nodes after it"));
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(fc_gru_fuse_pass);
diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
index 89fa5a75e9..44306a7295 100644
--- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
@@ -133,26 +133,30 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
     GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(BatchCellPreAct, BatchCellPreAct, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(BatchGate, BatchGate, lstm_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
     if (with_fc_bias) {
       GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern);
       GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
+      GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
       GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
       lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out,
                    fc_bias);
       // Remove unneeded nodes.
       std::unordered_set<const Node*> marked_nodes(
-          {mul, lstm, elementwise_add, fc_bias});
+          {mul, lstm, elementwise_add, mul_out, BatchGate, BatchCellPreAct});
       GraphSafeRemoveNodes(graph, marked_nodes);
     } else {
       GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);
       lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out,
                    nullptr);
       // Remove unneeded nodes.
-      std::unordered_set<const Node*> marked_nodes({mul, lstm});
+      std::unordered_set<const Node*> marked_nodes(
+          {mul, lstm, BatchGate, BatchCellPreAct});
       GraphSafeRemoveNodes(graph, marked_nodes);
     }
 
diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass_tester.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass_tester.cc
new file mode 100644
index 0000000000..0de8d4684f
--- /dev/null
+++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass_tester.cc
@@ -0,0 +1,116 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+void AddVarToScope(Scope* param_scope, const std::string& name,
+                   const DDim& dims) {
+  auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
+  tensor->Resize(dims);
+  tensor->mutable_data<float>(platform::CPUPlace());
+}
+
+Scope* CreateParamScope() {
+  auto param_scope = new Scope();
+  AddVarToScope(param_scope, "lstm_fc_w", {});
+  AddVarToScope(param_scope, "lstm_fc_b", {});
+  AddVarToScope(param_scope, "lstm_w", {});
+  AddVarToScope(param_scope, "lstm_b", {});
+  AddVarToScope(param_scope, "lstm_cell_0", {});
+  AddVarToScope(param_scope, "lstm_batch_gate_0", {});
+  AddVarToScope(param_scope, "lstm_batch_cell_pre_gate_0", {});
+  AddVarToScope(param_scope, "lstm_hidden_0", {});
+  AddVarToScope(param_scope, "lstm_cell_1", {});
+  AddVarToScope(param_scope, "lstm_batch_gate_1", {});
+  AddVarToScope(param_scope, "lstm_batch_cell_pre_gate_1", {});
+  AddVarToScope(param_scope, "lstm_hidden_1", {});
+  return param_scope;
+}
+
+TEST(FCLSTMFusePass, basic) {
+  // inputs                        operator            output
+  // --------------------------------------------------------
+  // (a, lstm_fc_w)                mul             ->  fc_0_tmp_0
+  // (fc_0_tmp_0, lstm_fc_b)       elementwise_add ->  fc_0_tmp_1
+  // (fc_0_tmp_1,lstm_w,lstm_b)    lstm            ->  lstm_out_0
+
+  // (b, lstm_fc_w)                mul             ->  fc_1_tmp_0
+  // (fc_1_tmp_0, lstm_fc_b)       elementwise_add ->  fc_1_tmp_1
+  // (fc_1_tmp_1,lstm_w,lstm_b)    lstm            ->  lstm_out_1
+  Layers layers;
+  auto* a = layers.data("a");
+  auto* b = layers.data("b");
+  auto* fc_w = layers.data("lstm_fc_w", {}, true);
+  auto* fc_b = layers.data("lstm_fc_b", {}, true);
+  auto* lstm_w = layers.data("lstm_w", {}, true);
+  auto* lstm_b = layers.data("lstm_b", {}, true);
+  auto* fc_0_tmp0 = layers.mul(a, fc_w);
+  auto* fc_0_tmp1 = layers.elementwise_add(fc_0_tmp0, fc_b);
+  auto* lstm_cell_0 = layers.data("lstm_cell_0", {}, false);
+  auto* lstm_batch_gate_0 = layers.data("lstm_batch_gate_0", {}, false);
+  auto* lstm_batch_cell_pre_gate_0 =
+      layers.data("lstm_batch_cell_pre_gate_0", {}, false);
+  auto* lstm_hidden_0 = layers.data("lstm_hidden_0", {}, false);
+  layers.lstm(fc_0_tmp1, lstm_w, lstm_b, lstm_cell_0, lstm_batch_gate_0,
+              lstm_hidden_0, lstm_batch_cell_pre_gate_0);
+
+  auto* fc_1_tmp0 = layers.mul(b, fc_w);
+  auto* fc_1_tmp1 = layers.elementwise_add(fc_1_tmp0, fc_b);
+  auto* lstm_cell_1 = layers.data("lstm_cell_1", {}, false);
+  auto* lstm_batch_gate_1 = layers.data("lstm_batch_gate_1", {}, false);
+  auto* lstm_batch_cell_pre_gate_1 =
+      layers.data("lstm_batch_cell_pre_gate_1", {}, false);
+  auto* lstm_hidden_1 = layers.data("lstm_hidden_1", {}, false);
+  layers.lstm(fc_1_tmp1, lstm_w, lstm_b, lstm_cell_1, lstm_batch_gate_1,
+              lstm_hidden_1, lstm_batch_cell_pre_gate_1);
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+  auto pass = PassRegistry::Instance().Get("fc_lstm_fuse_pass");
+  pass->Set("use_gpu", new bool(false));
+  graph->Set("__param_scope__", CreateParamScope());
+  int num_nodes_before = graph->Nodes().size();
+  int num_lstm_nodes_before = GetNumOpNodes(graph, "lstm");
+  VLOG(3) << DebugString(graph);
+
+  graph.reset(pass->Apply(graph.release()));
+  int num_nodes_after = graph->Nodes().size();
+  int num_fusion_lstm_nodes_after = GetNumOpNodes(graph, "fusion_lstm");
+  VLOG(3) << DebugString(graph);
+
+  PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after - 6,
+                    platform::errors::PreconditionNotMet(
+                        "The number of nodes before and after "
+                        "the fuse does not meet expectations"));
+  PADDLE_ENFORCE_EQ(num_fusion_lstm_nodes_after, 2,
+                    platform::errors::PreconditionNotMet(
+                        "The number of fusion_lstm nodes after the "
+                        "fuse should be 2"));
+  PADDLE_ENFORCE_EQ(num_lstm_nodes_before, num_fusion_lstm_nodes_after,
+                    platform::errors::PreconditionNotMet(
+                        "The number of lstm nodes before the fuse does not "
+                        "match the number of fusion_lstm nodes after it"));
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(fc_lstm_fuse_pass);
diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h
index 3c74612552..82f9e72661 100644
--- a/paddle/fluid/framework/ir/pass_tester_helper.h
+++ b/paddle/fluid/framework/ir/pass_tester_helper.h
@@ -120,6 +120,62 @@ struct Layers {
     return out;
   }
 
+  void lstm(VarDesc* input, VarDesc* w, VarDesc* bias, VarDesc* cell,
+            VarDesc* batch_gate, VarDesc* hidden, VarDesc* batch_cell_pre_act,
+            VarDesc* h0 = nullptr, VarDesc* c0 = nullptr,
+            bool use_peepholes = true, bool is_reverse = false,
+            std::string gate_activation = "sigmoid",
+            std::string cell_activation = "tanh",
+            std::string candidate_activation = "tanh") {
+    OpDesc* op = program_.MutableBlock(0)->AppendOp();
+    op->SetType("lstm");
+    op->SetInput("Input", {input->Name()});
+    op->SetInput("Weight", {w->Name()});
+    op->SetInput("Bias", {bias->Name()});
+    if (h0) {
+      op->SetInput("H0", {h0->Name()});
+    }
+    if (c0) {
+      op->SetInput("C0", {c0->Name()});
+    }
+    op->SetOutput("Hidden", {hidden->Name()});
+    op->SetOutput("Cell", {cell->Name()});
+    op->SetOutput("BatchGate", {batch_gate->Name()});
+    op->SetOutput("BatchCellPreAct", {batch_cell_pre_act->Name()});
+    op->SetAttr("use_peepholes", use_peepholes);
+    op->SetAttr("is_reverse", is_reverse);
+    op->SetAttr("gate_activation", gate_activation);
+    op->SetAttr("cell_activation", cell_activation);
+    op->SetAttr("candidate_activation", candidate_activation);
+    op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                static_cast<int>(OpRole::kForward));
+  }
+
+  void gru(VarDesc* input, VarDesc* w, VarDesc* bias, VarDesc* batch_gate,
+           VarDesc* batch_reset_hidden_prev, VarDesc* batch_hidden,
+           VarDesc* hidden, VarDesc* h0 = nullptr, bool origin_mode = false,
+           bool is_reverse = false, std::string activation = "tanh",
+           std::string gate_activation = "sigmoid") {
+    OpDesc* op = program_.MutableBlock(0)->AppendOp();
+    op->SetType("gru");
+    op->SetInput("Input", {input->Name()});
+    op->SetInput("Weight", {w->Name()});
+    op->SetInput("Bias", {bias->Name()});
+    if (h0) {
+      op->SetInput("H0", {h0->Name()});
+    }
+    op->SetOutput("BatchGate", {batch_gate->Name()});
+    op->SetOutput("BatchResetHiddenPrev", {batch_reset_hidden_prev->Name()});
+    op->SetOutput("BatchHidden", {batch_hidden->Name()});
+    op->SetOutput("Hidden", {hidden->Name()});
+    op->SetAttr("origin_mode", origin_mode);
+    op->SetAttr("is_reverse", is_reverse);
+    op->SetAttr("activation", activation);
+    op->SetAttr("gate_activation", gate_activation);
+    op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                static_cast<int>(OpRole::kForward));
+  }
+
   VarDesc* mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr,
                int x_num_col_dims = 1) {
     AttributeMap attrs;
-- 
GitLab