From 7b5a8e46decf5cf35d6f9ff29fb24ff6bc0e79b9 Mon Sep 17 00:00:00 2001 From: Wojciech Uss Date: Wed, 25 Nov 2020 03:25:50 +0100 Subject: [PATCH] Add multi_gru_fuse_pass and tests (#28601) * Add multi_gru_fuse_pass and tests * fix date * cleaned up headers --- paddle/fluid/framework/ir/CMakeLists.txt | 2 + .../framework/ir/graph_pattern_detector.cc | 51 ++++++ .../framework/ir/graph_pattern_detector.h | 23 +++ .../ir/mkldnn/multi_gru_fuse_pass.cc | 123 ++++++++++++++ .../framework/ir/mkldnn/multi_gru_fuse_pass.h | 42 +++++ .../ir/mkldnn/multi_gru_fuse_pass_tester.cc | 156 ++++++++++++++++++ .../ir/mkldnn/multi_gru_seq_fuse_pass.cc | 10 +- tools/static_mode_white_list.py | 1 + 8 files changed, 403 insertions(+), 5 deletions(-) create mode 100644 paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass.h create mode 100644 paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass_tester.cc diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 1455f8a099..e1f9a236b7 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -111,6 +111,7 @@ if(WITH_MKLDNN) pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn) pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn) pass_library(batch_norm_act_fuse_pass inference DIR mkldnn) + pass_library(multi_gru_fuse_pass inference DIR mkldnn) pass_library(multi_gru_seq_fuse_pass inference DIR mkldnn) endif() @@ -170,5 +171,6 @@ endif() cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass) cc_test(test_cpu_bfloat16_placement_pass SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc DEPS cpu_bfloat16_placement_pass) cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass) + cc_test(test_multi_gru_fuse_pass SRCS 
mkldnn/multi_gru_fuse_pass_tester.cc DEPS multi_gru_fuse_pass) cc_test(test_multi_gru_seq_fuse_pass SRCS mkldnn/multi_gru_seq_fuse_pass_tester.cc DEPS multi_gru_seq_fuse_pass) endif () diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 2fb506da39..e163f6c352 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2511,6 +2511,57 @@ PDNode *patterns::FusionGru::operator()() { return out; } +PDNode *patterns::TwoFusionGruConcat::operator()() { + auto x = pattern->NewNode(x_repr())->AsInput()->assert_is_op_input( + "fusion_gru", "X"); + auto gru1 = + pattern->NewNode(gru1_repr()) + ->assert_is_op("fusion_gru") + ->assert_more([&](Node *node) { + return node->Op()->GetAttrIfExists("is_reverse") == false; + }); + auto gru2 = + pattern->NewNode(gru2_repr()) + ->assert_is_op("fusion_gru") + ->assert_more([&](Node *node) { + return node->Op()->GetAttrIfExists("is_reverse") == true; + }); + auto wh1 = pattern->NewNode(wh1_repr()) + ->AsInput() + ->assert_is_op_input("fusion_gru", "WeightH"); + auto wh2 = pattern->NewNode(wh2_repr()) + ->AsInput() + ->assert_is_op_input("fusion_gru", "WeightH"); + auto wx1 = pattern->NewNode(wx1_repr()) + ->AsInput() + ->assert_is_op_input("fusion_gru", "WeightX"); + auto wx2 = pattern->NewNode(wx2_repr()) + ->AsInput() + ->assert_is_op_input("fusion_gru", "WeightX"); + auto b1 = pattern->NewNode(b1_repr())->AsInput()->assert_is_op_input( + "fusion_gru", "Bias"); + auto b2 = pattern->NewNode(b2_repr())->AsInput()->assert_is_op_input( + "fusion_gru", "Bias"); + auto h1 = pattern->NewNode(h1_repr()) + ->AsOutput() + ->assert_is_op_output("fusion_gru", "Hidden") + ->assert_is_op_input("concat") + ->AsIntermediate(); + auto h2 = pattern->NewNode(h2_repr()) + ->AsOutput() + ->assert_is_op_output("fusion_gru", "Hidden") + ->assert_is_op_input("concat") + ->AsIntermediate(); + auto concat = 
pattern->NewNode(concat_repr())->assert_is_op("concat"); + auto out = pattern->NewNode(out_repr()) + ->AsOutput() + ->assert_is_op_output("concat", "Out"); + gru1->LinksFrom({x, wh1, wx1, b1}).LinksTo({h1}); + gru2->LinksFrom({x, wh2, wx2, b2}).LinksTo({h2}); + concat->LinksFrom({h1, h2}).LinksTo({out}); + return out; +} + PDNode *patterns::MultiGruSeq::operator()() { auto x = pattern->NewNode(x_repr())->AsInput()->assert_is_op_input( "multi_gru", "X"); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 28782b2965..a4e8d916e5 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1420,6 +1420,29 @@ struct FusionGru : public PatternBase { PATTERN_DECL_NODE(out); }; +// two concatenated fusion_gru ops +// Forward pass for fusion of two concatenated fusion_gru ops. +// concat_out is a result of the operator(). +struct TwoFusionGruConcat : public PatternBase { + TwoFusionGruConcat(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "bi_fusion_gru") {} + + PDNode* operator()(); + PATTERN_DECL_NODE(x); + PATTERN_DECL_NODE(gru1); + PATTERN_DECL_NODE(gru2); + PATTERN_DECL_NODE(wh1); + PATTERN_DECL_NODE(wh2); + PATTERN_DECL_NODE(wx1); + PATTERN_DECL_NODE(wx2); + PATTERN_DECL_NODE(b1); + PATTERN_DECL_NODE(b2); + PATTERN_DECL_NODE(h1); + PATTERN_DECL_NODE(h2); + PATTERN_DECL_NODE(concat); + PATTERN_DECL_NODE(out); +}; + // two subsequent bi_fusion_gru ops // Forward pass for fusion of two subsequent fusion_gru ops. // Hidden of the last fusion_gru op is a result of the operator(). diff --git a/paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass.cc new file mode 100644 index 0000000000..43c9849d5b --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass.cc @@ -0,0 +1,123 @@ +// Copyright (c) 2020 PaddlePaddle Authors. 
All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass.h"
+#include <vector>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/platform/errors.h"
+#include "paddle/fluid/string/pretty_log.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+using EigenVectorArrayMap = Eigen::Map<Eigen::Array<double, Eigen::Dynamic, 1>>;
+using string::PrettyLogDetail;
+
+namespace {
+
+std::vector<std::string> JoinInputs(Node* op1, Node* op2,
+                                    std::string input_name) {
+  auto in1 = op1->Op()->Input(input_name);
+  auto& in2 = op2->Op()->Input(input_name);
+  in1.insert(in1.end(), in2.begin(), in2.end());
+  return in1;
+}
+
+}  // namespace
+
+void MultiGRUFusePass::ApplyImpl(ir::Graph* graph) const {
+  VLOG(3) << "Fusing two concatenated multi_gru ops.";
+  PADDLE_ENFORCE_NOT_NULL(graph,
+                          platform::errors::InvalidArgument(
+                              "Pointer to graph argument cannot be NULL."));
+  FusePassBase::Init(name_scope_, graph);
+  PADDLE_ENFORCE_NOT_NULL(param_scope(), platform::errors::InvalidArgument(
+                                             "Scope cannot be nullptr."));
+
+  GraphPatternDetector gpd;
+  patterns::TwoFusionGruConcat pattern{gpd.mutable_pattern(), name_scope_};
+  pattern();
+
+  int fused_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    GET_IR_NODE_FROM_SUBGRAPH(x, x, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(gru1, gru1, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(gru2, gru2, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(wh1, wh1, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(wh2, wh2, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(wx1, wx1, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(wx2, wx2, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(b1, b1, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(b2, b2, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(h1, h1, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(h2, h2, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(concat, concat, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(out, out, pattern);
+
+    if (gru1->Op()->GetAttrIfExists<bool>("origin_mode") !=
+        gru2->Op()->GetAttrIfExists<bool>("origin_mode")) {
+      LOG(INFO) << "The two fusion_gru ops have different values of the "
+                   "origin_mode attribute. Skipping fuse.";
+      return;
+    }
+
+    auto wx = JoinInputs(gru1, gru2, "WeightX");
+    auto wh = JoinInputs(gru1, gru2, "WeightH");
+    auto b = JoinInputs(gru1, gru2, "Bias");
+
+    OpDesc multi_gru_desc;
+    multi_gru_desc.SetType("multi_gru");
+    multi_gru_desc.SetInput("X", std::vector<std::string>({x->Name()}));
+    multi_gru_desc.SetInput("WeightX", wx);
+    multi_gru_desc.SetInput("WeightH", wh);
+    multi_gru_desc.SetInput("Bias", b);
+    multi_gru_desc.SetOutput("Hidden",
+                             std::vector<std::string>({out->Name()}));
+
+    auto attrs_to_skip = {"is_reverse", "use_seq"};
+    for (auto& attr : gru1->Op()->GetAttrMap()) {
+      if (std::find(attrs_to_skip.begin(), attrs_to_skip.end(), attr.first) ==
+          attrs_to_skip.end())
+        multi_gru_desc.SetAttr(attr.first, attr.second);
+    }
+    multi_gru_desc.SetAttr("layers", 1);
+    auto multi_gru =
+        g->CreateOpNode(&multi_gru_desc);  // OpDesc will be copied.
+
+    IR_NODE_LINK_TO(x, multi_gru);
+    IR_NODE_LINK_TO(b1, multi_gru);
+    IR_NODE_LINK_TO(b2, multi_gru);
+    IR_NODE_LINK_TO(wh1, multi_gru);
+    IR_NODE_LINK_TO(wh2, multi_gru);
+    IR_NODE_LINK_TO(wx1, multi_gru);
+    IR_NODE_LINK_TO(wx2, multi_gru);
+    IR_NODE_LINK_TO(multi_gru, out);
+    GraphSafeRemoveNodes(graph, {gru1, gru2, h1, h2, concat});
+
+    ++fused_count;
+  };
+  gpd(graph, handler);
+  AddStatis(fused_count);
+
+  PrettyLogDetail("---    fused %d pairs of concatenated multi_gru ops",
+                  fused_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(multi_gru_fuse_pass, paddle::framework::ir::MultiGRUFusePass);
diff --git a/paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass.h
new file mode 100644
index 0000000000..70f88104b4
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass.h
@@ -0,0 +1,42 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+// This pass fuses two concatenated fusion_gru ops into a single multi_gru op.
+// It turns +// a -> fusion_gru -> c -> concat -> e +// \> fusion_gru -> d / +// into +// a -> multi_gru -> e +class MultiGRUFusePass : public FusePassBase { + public: + virtual ~MultiGRUFusePass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; + const std::string name_scope_{"multi_gru"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass_tester.cc new file mode 100644 index 0000000000..7b6681ff96 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass_tester.cc @@ -0,0 +1,156 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "paddle/fluid/framework/ir/mkldnn/multi_gru_fuse_pass.h"
+#include <gtest/gtest.h>
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+void SetOp(ProgramDesc* prog, const std::string& type,
+           const std::vector<std::string>& inputs,
+           const std::vector<std::string>& outputs, bool is_reverse = false,
+           bool origin_mode = false) {
+  auto* op = prog->MutableBlock(0)->AppendOp();
+
+  op->SetType(type);
+  if (type == "fusion_gru") {
+    op->SetInput("X", {inputs[0]});
+    op->SetInput("WeightX", {inputs[1]});
+    op->SetInput("WeightH", {inputs[2]});
+    op->SetInput("Bias", {inputs[3]});
+    op->SetOutput("Hidden", {outputs[0]});
+    op->SetAttr("is_reverse", is_reverse);
+    op->SetAttr("origin_mode", origin_mode);
+  } else if (type == "concat") {
+    op->SetInput("X", {inputs[0], inputs[1]});
+    op->SetOutput("Out", {outputs[0]});
+  } else {
+    FAIL() << "Unexpected operator type.";
+  }
+}
+
+static const std::initializer_list<std::string> variable_names = {
+    "x", "wx1", "wx2", "wh1", "wh2", "b1", "b2", "h1", "h2", "out"};
+
+// (x, wx1, wh1, b1) -> fusion_gru1 -> h1
+// (x, wx2, wh2, b2) -> fusion_gru2 -> h2
+// (h1, h2) -> concat -> out
+ProgramDesc BuildProgramDesc(bool origin_mode1, bool origin_mode2) {
+  ProgramDesc prog;
+
+  for (auto& v : variable_names) {
+    prog.MutableBlock(0)->Var(v);
+  }
+  SetOp(&prog, "fusion_gru", {"x", "wx1", "wh1", "b1"}, {"h1"}, false,
+        origin_mode1);
+  SetOp(&prog, "fusion_gru", {"x", "wx2", "wh2", "b2"}, {"h2"}, true,
+        origin_mode2);
+  SetOp(&prog, "concat", {"h1", "h2"}, {"out"});
+  return prog;
+}
+
+void MainTest(const ProgramDesc& prog, int removed_nodes_count,
+              int added_nodes_count,
+              const std::vector<std::string> multi_gru_inputs,
+              const std::string multi_gru_output, bool origin_mode) {
+  // Apply pass
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
+  Scope scope;
+  graph->SetNotOwned(kParamScopeAttr, &scope);
+  int original_nodes_num = graph->Nodes().size();
+  auto pass = PassRegistry::Instance().Get("multi_gru_fuse_pass");
+  graph.reset(pass->Apply(graph.release()));
+  int current_nodes_num = graph->Nodes().size();
+
+  // Verify graph after fuse
+  int count_multi_gru = 0;
+  for (auto* node : graph->Nodes()) {
+    if (node->IsOp()) {
+      auto* op = node->Op();
+      if (op->Type() == "multi_gru") {
+        EXPECT_EQ(op->Input("X")[0], multi_gru_inputs[0]);
+        EXPECT_EQ(op->Input("WeightX").size(), 2u);
+        EXPECT_EQ(op->Input("WeightX")[0], multi_gru_inputs[1]);
+        EXPECT_EQ(op->Input("WeightX")[1], multi_gru_inputs[2]);
+        EXPECT_EQ(op->Input("WeightH").size(), 2u);
+        EXPECT_EQ(op->Input("WeightH")[0], multi_gru_inputs[3]);
+        EXPECT_EQ(op->Input("WeightH")[1], multi_gru_inputs[4]);
+        EXPECT_EQ(op->Input("Bias").size(), 2u);
+        EXPECT_EQ(op->Input("Bias")[0], multi_gru_inputs[5]);
+        EXPECT_EQ(op->Input("Bias")[1], multi_gru_inputs[6]);
+        EXPECT_EQ(op->Output("Hidden")[0], multi_gru_output);
+        EXPECT_EQ(op->GetAttrIfExists<int>("layers"), 1);
+        EXPECT_EQ(op->GetAttrIfExists<bool>("origin_mode"), origin_mode);
+        ++count_multi_gru;
+      }
+    }
+  }
+  EXPECT_EQ(original_nodes_num - removed_nodes_count + added_nodes_count,
+            current_nodes_num);
+  EXPECT_EQ(count_multi_gru, added_nodes_count);
+}
+
+TEST(MultiGruFusePass, same_origin_modes_1) {
+  bool origin_mode1 = false;
+  bool origin_mode2 = false;
+
+  // nodes to be removed: 2x fusion_gru + 2x hidden(output) + concat
+  const int removed_nodes_count = 5;
+  // nodes to be added: multi_gru
+  const int added_nodes_count = 1;
+
+  const std::initializer_list<std::string> multi_gru_inputs = {
+      "x", "wx1", "wx2", "wh1", "wh2", "b1", "b2"};
+  MainTest(BuildProgramDesc(origin_mode1, origin_mode2), removed_nodes_count,
+           added_nodes_count, multi_gru_inputs, "out", origin_mode1);
+}
+
+TEST(MultiGruFusePass, same_origin_modes_2) {
+  bool origin_mode1 = true;
+  bool origin_mode2 = true;
+
+  // nodes to be removed: 2x fusion_gru + 2x hidden(output) + concat
+  const int removed_nodes_count = 5;
+  // nodes to be added: multi_gru
+  const int added_nodes_count = 1;
+
+  const std::initializer_list<std::string> multi_gru_inputs = {
+      "x", "wx1", "wx2", "wh1", "wh2", "b1", "b2"};
+  MainTest(BuildProgramDesc(origin_mode1, origin_mode2), removed_nodes_count,
+           added_nodes_count, multi_gru_inputs, "out", origin_mode1);
+}
+
+TEST(MultiGruFusePass, different_origin_modes) {
+  bool origin_mode1 = true;
+  bool origin_mode2 = false;
+
+  // the fuse should not be applied, so
+  // nodes to be removed: none
+  const int removed_nodes_count = 0;
+  // nodes to be added: none
+  const int added_nodes_count = 0;
+
+  const std::initializer_list<std::string> multi_gru_inputs = {
+      "x", "wx1", "wx2", "wh1", "wh2", "b1", "b2"};
+  MainTest(BuildProgramDesc(origin_mode1, origin_mode2), removed_nodes_count,
+           added_nodes_count, multi_gru_inputs, "out", origin_mode1);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(multi_gru_fuse_pass);
diff --git a/paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass.cc
index 105f812898..17770d26d7 100644
--- a/paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass.cc
@@ -32,8 +32,8 @@ using string::PrettyLogDetail;
 
 namespace {
 
-std::vector<std::string> join_inputs(Node* op1, Node* op2,
-                                     std::string input_name) {
+std::vector<std::string> JoinInputs(Node* op1, Node* op2,
+                                    std::string input_name) {
   auto in1 = op1->Op()->Input(input_name);
   auto& in2 = op2->Op()->Input(input_name);
   in1.insert(in1.end(), in2.begin(), in2.end());
@@ -83,9 +83,9 @@ void MultiGruSeqFusePass::ApplyImpl(ir::Graph* graph) const {
       return;
     }
 
-    auto wx = join_inputs(gru1, gru2, "WeightX");
-    auto wh = join_inputs(gru1, gru2, "WeightH");
-    auto b = join_inputs(gru1, gru2, "Bias");
+    auto wx = JoinInputs(gru1, gru2, "WeightX");
+    auto wh = JoinInputs(gru1, gru2, "WeightH");
+    auto b = JoinInputs(gru1, gru2, "Bias");
 
     OpDesc multi_gru_desc;
     multi_gru_desc.SetType("multi_gru");
diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py
index 2824fddc8f..b6e8203aa7 100644
---
a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -603,6 +603,7 @@ STATIC_MODE_TESTING_LIST = [ 'test_matmul_bf16_mkldnn_op', 'test_mul_int8_mkldnn_op', 'test_multi_gru_mkldnn_op', + 'test_multi_gru_fuse_pass', 'test_multi_gru_seq_fuse_pass', 'test_pool2d_int8_mkldnn_op', 'test_pool2d_mkldnn_op', -- GitLab