diff --git a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1e7b49620c0102c213ecf336d9324ee9585bcb7c
--- /dev/null
+++ b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
@@ -0,0 +1,209 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/fc_gru_fuse_pass.h"
+#include <string>
+#include <unordered_set>
+#include "paddle/fluid/framework/lod_tensor.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+// Joins `prefix` and `name` with a '/' separator.
+std::string GenNodeName(const std::string& prefix, const std::string& name) {
+  return prefix + "/" + name;
+}
+
+// Describes the subgraph to fuse: a non-persistable variable `x` that is an
+// input of a `mul` op, the FC pattern built on `x` (with or without bias,
+// depending on `with_fc_bias`) and the GRU pattern consuming the FC output.
+static void BuildPattern(PDPattern* pattern, const std::string& name_scope,
+                         bool with_fc_bias) {
+  PDNode* x = pattern->NewNode(name_scope, "x")
+                  ->assert_is_op_input("mul")
+                  ->assert_var_not_persistable();
+  auto* fc_out = patterns::FC(pattern, name_scope, x, with_fc_bias);
+  fc_out->AsIntermediate();  // fc_out is a tmp var, will be removed after fuse.
+  patterns::GRU(pattern, name_scope, fc_out);
+  VLOG(3) << "\n" << pattern->DotString();
+}
+
+// Runs the pattern detector over `graph`. For every match a `fusion_gru` op
+// node is created and the matched `mul` and `gru` op nodes (plus
+// `elementwise_add` when `with_fc_bias` is set) are removed.
+//
+// `scope` receives the merged bias and the temporary output variables of the
+// fused op and must not be null. Returns the number of matches fused.
+static int BuildFusion(Graph* graph, const std::string& name_scope,
+                       Scope* scope, bool with_fc_bias) {
+  GraphPatternDetector gpd;
+  auto* pattern = gpd.mutable_pattern();
+
+  BuildPattern(pattern, name_scope, with_fc_bias);
+
+  // Builds the `fusion_gru` OpDesc from the ids of the matched nodes, creates
+  // its op node in `graph` and links it to the surviving variable nodes.
+  auto gru_creater = [&](int gru, int x, int weight_x, int weight_h, int bias,
+                         int hidden, int fc_bias) {
+    // The scope is dereferenced on both the bias and the no-bias path.
+    PADDLE_ENFORCE(scope);
+
+#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x);
+    GET_NODE(x);
+    GET_NODE(weight_x);
+    GET_NODE(weight_h);
+    GET_NODE(bias);
+    GET_NODE(hidden);
+    GET_NODE(gru);
+
+    OpDesc op_desc;
+    op_desc.SetType("fusion_gru");
+#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__##_n->Name()});
+    SET_IN(X, x);
+    SET_IN(WeightX, weight_x);
+    SET_IN(WeightH, weight_h);
+    SET_IN(Bias, bias);
+#undef SET_IN
+    if (with_fc_bias) {
+      // Add FC-bias with GRU-bias and create a new weight
+      // NOTE(review): the variable name depends only on name_scope, so all
+      // matches share it -- confirm a single match per graph is expected.
+      const std::string new_bias_var = name_scope + "_bias.new";
+      auto* bias_var = scope->Var(new_bias_var);
+      PADDLE_ENFORCE(bias_var);
+      auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
+      auto* gru_bias_var = scope->FindVar(bias_n->Name());
+      PADDLE_ENFORCE(gru_bias_var);
+      const auto& gru_bias_tensor = gru_bias_var->Get<framework::LoDTensor>();
+      bias_tensor->Resize(gru_bias_tensor.dims());
+
+      GET_NODE(fc_bias);
+      auto* fc_bias_var = scope->FindVar(fc_bias_n->Name());
+      PADDLE_ENFORCE(fc_bias_var);
+      const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
+      // Both inputs are indexed up to the merged bias size below.
+      PADDLE_ENFORCE(fc_bias_tensor.numel() == gru_bias_tensor.numel());
+      // new bias = fc bias + gru bias
+      // NOTE(review): element type assumed to be float -- confirm.
+      auto* data = bias_tensor->mutable_data<float>(platform::CPUPlace());
+      const float* fc_bias_data = fc_bias_tensor.data<float>();
+      const float* gru_bias_data = gru_bias_tensor.data<float>();
+      const int64_t bias_numel = bias_tensor->numel();
+      for (int64_t i = 0; i < bias_numel; ++i) {
+        data[i] = fc_bias_data[i] + gru_bias_data[i];
+      }
+      op_desc.SetInput("Bias", {new_bias_var});
+    }
+#undef GET_NODE
+
+    op_desc.SetInput("H0", {});
+    op_desc.SetOutput("Hidden", {hidden_n->Name()});
+    op_desc.SetAttr("is_reverse", gru_n->Op()->GetAttr("is_reverse"));
+    // TODO(TJ): This should be a option for infer
+    op_desc.SetAttr("use_seq", true);
+
+    // Create temp variables.
+    // TODO(TJ): clean code
+    for (const char* out_name :
+         {"ReorderedH0", "XX", "BatchedInput", "BatchedOut"}) {
+      const std::string var_name = name_scope + "/" + out_name + ".new";
+      scope->Var(var_name)->GetMutable<framework::LoDTensor>();
+      op_desc.SetOutput(out_name, {var_name});
+    }
+
+    auto* op = graph->CreateOpNode(&op_desc);
+    PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
+
+    IR_NODE_LINK_TO(x_n, op);
+    IR_NODE_LINK_TO(weight_x_n, op);
+    IR_NODE_LINK_TO(weight_h_n, op);
+    IR_NODE_LINK_TO(bias_n, op);
+    IR_NODE_LINK_TO(op, hidden_n);
+    // h0?
+    return op;
+  };
+
+  int fusion_count{0};
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+#define GET_NODE(name__)                                \
+  std::string name__##key = name_scope + "/" + #name__; \
+  auto* name__##n = pattern->RetrieveNode(name__##key); \
+  PADDLE_ENFORCE(name__##n);                            \
+  PADDLE_ENFORCE(subgraph.count(name__##n));            \
+  Node* name__##_n = subgraph.at(name__##n);            \
+  int name__ __attribute__((unused)) = name__##_n->id();
+
+    GET_NODE(x);
+    GET_NODE(w);
+    GET_NODE(mul);
+    GET_NODE(fc_out);
+    GET_NODE(Weight);
+    GET_NODE(gru);
+    GET_NODE(Bias);
+    GET_NODE(Hidden);
+
+    // Op nodes made redundant by the fused op.
+    std::unordered_set<const Node*> marked_nodes({mul_n, gru_n});
+    if (with_fc_bias) {
+      GET_NODE(fc_bias);
+      GET_NODE(elementwise_add);
+      gru_creater(gru, x, w, Weight, Bias, Hidden, fc_bias);
+      marked_nodes.insert(elementwise_add_n);
+    } else {
+      gru_creater(gru, x, w, Weight, Bias, Hidden, -1);
+    }
+    // Remove unneeded nodes.
+    GraphSafeRemoveNodes(graph, marked_nodes);
+#undef GET_NODE
+
+    ++fusion_count;
+  };
+
+  gpd(graph, handler);
+
+  return fusion_count;
+}
+
+// Fuses mul -> gru subgraphs (FC without bias) and records the fusion count.
+std::unique_ptr<ir::Graph> MulGRUFusePass::ApplyImpl(
+    std::unique_ptr<ir::Graph> graph) const {
+  FusePassBase::Init(name_scope_, graph.get());
+
+  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
+                                 false /*with_fc_bias*/);
+
+  AddStatis(fusion_count);
+  return graph;
+}
+
+// Fuses mul -> elementwise_add -> gru subgraphs and records the fusion count.
+std::unique_ptr<ir::Graph> FCGRUFusePass::ApplyImpl(
+    std::unique_ptr<ir::Graph> graph) const {
+  FusePassBase::Init(name_scope_, graph.get());
+
+  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
+                                 true /*with_fc_bias*/);
+
+  AddStatis(fusion_count);
+  return graph;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(mul_gru_fuse_pass, paddle::framework::ir::MulGRUFusePass);
+REGISTER_PASS(fc_gru_fuse_pass, paddle::framework::ir::FCGRUFusePass);
diff --git a/paddle/fluid/framework/ir/fc_gru_fuse_pass.h b/paddle/fluid/framework/ir/fc_gru_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..63e1c72bfb2e2641ae5d44858b342d5e427e9045
--- /dev/null
+++ b/paddle/fluid/framework/ir/fc_gru_fuse_pass.h
@@ -0,0 +1,50 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+// The FCGRUFusePass and MulGRUFusePass will fuse to the same FusionGRU op.
+class FCGRUFusePass : public FusePassBase {
+ public:
+  virtual ~FCGRUFusePass() {}
+
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
+
+  const std::string name_scope_{"fc_gru_fuse"};
+};
+
+// Just FC without bias
+class MulGRUFusePass : public FusePassBase {
+ public:
+  virtual ~MulGRUFusePass() {}
+
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
+  const std::string name_scope_{"fc_nobias_gru_fuse"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 434bee4ccee1c199088d09c934fe86435ec7d095..8dfe36f78195933f4c3867ebe28140b4376ed572 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -565,6 +565,7 @@ PDNode* patterns::FC(PDPattern* pattern, const std::string& name_scope,
   return fc_out;
 }
 
+
 PDNode* patterns::LSTM(PDPattern* pattern, const std::string& name_scope,
                        PDNode* x) {
   x->assert_is_op_input("lstm", "Input");
@@ -589,6 +590,37 @@ PDNode* patterns::LSTM(PDPattern* pattern, const std::string& name_scope,
   lstm_op->LinksTo({Hidden, Cell, BatchGate, BatchCellPreAct});
   return Hidden;
 }
+
+// Appends the `gru` op pattern to `pattern`. `x` is asserted to be the op's
+// `Input`; Weight and Bias are matched as inputs and Hidden, BatchGate,
+// BatchResetHiddenPrev and BatchHidden as outputs. Returns the Hidden output
+// node.
+PDNode* patterns::GRU(PDPattern* pattern, const std::string& name_scope,
+                      PDNode* x) {
+  x->assert_is_op_input("gru", "Input");
+  auto* gru_op = pattern->NewNode(name_scope, "gru")->assert_is_op("gru");
+#define NEW_NODE(arg__, io__)                        \
+  auto* arg__ = pattern->NewNode(name_scope, #arg__) \
+                    ->assert_is_op_##io__("gru", #arg__);
+
+  NEW_NODE(Weight, input);
+  // TODO(Superjomn): upgrade the fuse framework to support optional.
+  // H0 and bias are optional
+  NEW_NODE(Bias, input);  // also optional
+  // NEW_NODE(H0, input);
+
+  NEW_NODE(Hidden, output);
+  // below are intermediate
+  NEW_NODE(BatchGate, output);
+  NEW_NODE(BatchResetHiddenPrev, output);
+  NEW_NODE(BatchHidden, output);
+#undef NEW_NODE
+
+  gru_op->LinksFrom({x, Weight, Bias});
+  gru_op->LinksTo({Hidden, BatchGate, BatchResetHiddenPrev, BatchHidden});
+  return Hidden;
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index eacea1750f6f1e86a8fe79637c3bd757a7275398..71e4c36d9b6327ff419179ca7ed10332f448e245 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -298,6 +298,9 @@ PDNode* FC(PDPattern* pattern, const std::string& name_scope, PDNode* x,
 
 PDNode* LSTM(PDPattern* pattern, const std::string& name_scope, PDNode* x);
 
+// Appends the `gru` op pattern fed by `x`; returns its Hidden output node.
+PDNode* GRU(PDPattern* pattern, const std::string& name_scope, PDNode* x);
+
 }  // namespace patterns
 
 #define IR_NODE_LINK_TO(a, b) \