From 991345b368142fd4ce60ce3cdfb8b93228cbde87 Mon Sep 17 00:00:00 2001
From: Wojciech Uss
Date: Tue, 24 Nov 2020 04:59:26 +0100
Subject: [PATCH] Add multi_gru_seq_fuse_pass and tests (#28604)

* Add multi_gru_seq_fuse_pass and tests

* fix date

* removed unused functions
---
 paddle/fluid/framework/ir/CMakeLists.txt      |   2 +
 .../framework/ir/graph_pattern_detector.cc    |  53 +++++
 .../framework/ir/graph_pattern_detector.h     |  27 +++
 .../ir/mkldnn/multi_gru_seq_fuse_pass.cc      | 139 +++++++++++++
 .../ir/mkldnn/multi_gru_seq_fuse_pass.h       |  40 ++++
 .../mkldnn/multi_gru_seq_fuse_pass_tester.cc  | 187 ++++++++++++++++++
 tools/static_mode_white_list.py               |   1 +
 7 files changed, 449 insertions(+)
 create mode 100644 paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass.cc
 create mode 100644 paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass.h
 create mode 100644 paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass_tester.cc

diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index f2f7e16ff2..1455f8a099 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -111,6 +111,7 @@ if(WITH_MKLDNN)
   pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn)
   pass_library(batch_norm_act_fuse_pass inference DIR mkldnn)
+  pass_library(multi_gru_seq_fuse_pass inference DIR mkldnn)
 endif()
 
 cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector )
@@ -169,4 +170,5 @@ endif()
   cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass)
   cc_test(test_cpu_bfloat16_placement_pass SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc DEPS cpu_bfloat16_placement_pass)
   cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass)
+  cc_test(test_multi_gru_seq_fuse_pass SRCS mkldnn/multi_gru_seq_fuse_pass_tester.cc DEPS multi_gru_seq_fuse_pass)
 endif ()
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 56dacdc6db..2fb506da39 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2511,6 +2511,59 @@ PDNode *patterns::FusionGru::operator()() {
   return out;
 }
 
+PDNode *patterns::MultiGruSeq::operator()() {
+  auto x = pattern->NewNode(x_repr())->AsInput()->assert_is_op_input(
+      "multi_gru", "X");
+  auto gru1 = pattern->NewNode(gru1_repr())->assert_is_op("multi_gru");
+  auto wx11 = pattern->NewNode(wx11_repr())
+                  ->AsInput()
+                  ->assert_is_op_nth_input("multi_gru", "WeightX", 0);
+  auto wx12 = pattern->NewNode(wx12_repr())
+                  ->AsInput()
+                  ->assert_is_op_nth_input("multi_gru", "WeightX", 1);
+  auto wh11 = pattern->NewNode(wh11_repr())
+                  ->AsInput()
+                  ->assert_is_op_nth_input("multi_gru", "WeightH", 0);
+  auto wh12 = pattern->NewNode(wh12_repr())
+                  ->AsInput()
+                  ->assert_is_op_nth_input("multi_gru", "WeightH", 1);
+  auto b11 = pattern->NewNode(b11_repr())
+                 ->AsInput()
+                 ->assert_is_op_nth_input("multi_gru", "Bias", 0);
+  auto b12 = pattern->NewNode(b12_repr())
+                 ->AsInput()
+                 ->assert_is_op_nth_input("multi_gru", "Bias", 1);
+  auto h1 = pattern->NewNode(h1_repr())
+                ->AsOutput()
+                ->assert_is_op_output("multi_gru", "Hidden")
+                ->assert_is_op_input("multi_gru", "X")
+                ->AsIntermediate();
+  auto gru2 = pattern->NewNode(gru2_repr())->assert_is_op("multi_gru");
+  auto wx21 = pattern->NewNode(wx21_repr())
+                  ->AsInput()
+                  ->assert_is_op_nth_input("multi_gru", "WeightX", 0);
+  auto wx22 = pattern->NewNode(wx22_repr())
+                  ->AsInput()
+                  ->assert_is_op_nth_input("multi_gru", "WeightX", 1);
+  auto wh21 = pattern->NewNode(wh21_repr())
+                  ->AsInput()
+                  ->assert_is_op_nth_input("multi_gru", "WeightH", 0);
+  auto wh22 = pattern->NewNode(wh22_repr())
+                  ->AsInput()
+                  ->assert_is_op_nth_input("multi_gru", "WeightH", 1);
+  auto b21 = pattern->NewNode(b21_repr())
+                 ->AsInput()
+                 ->assert_is_op_nth_input("multi_gru", "Bias", 0);
+  auto b22 = pattern->NewNode(b22_repr())
+                 ->AsInput()
+                 ->assert_is_op_nth_input("multi_gru", "Bias", 1);
+  auto h2 = pattern->NewNode(h2_repr())->AsOutput()->assert_is_op_output(
+      "multi_gru", "Hidden");
+  gru1->LinksFrom({x, wx11, wx12, wh11, wh12, b11, b12}).LinksTo({h1});
+  gru2->LinksFrom({h1, wx21, wx22, wh21, wh22, b21, b22}).LinksTo({h2});
+  return h2;
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 77a1b03407..28782b2965 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1420,6 +1420,33 @@ struct FusionGru : public PatternBase {
   PATTERN_DECL_NODE(out);
 };
 
+// two subsequent multi_gru ops
+// Forward pass for fusion of two subsequent multi_gru ops.
+// Hidden of the last multi_gru op is the result of the operator().
+struct MultiGruSeq : public PatternBase {
+  MultiGruSeq(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "multi_gru_seq") {}
+
+  PDNode* operator()();
+  PATTERN_DECL_NODE(x);
+  PATTERN_DECL_NODE(gru1);
+  PATTERN_DECL_NODE(wx11);
+  PATTERN_DECL_NODE(wx12);
+  PATTERN_DECL_NODE(wh11);
+  PATTERN_DECL_NODE(wh12);
+  PATTERN_DECL_NODE(b11);
+  PATTERN_DECL_NODE(b12);
+  PATTERN_DECL_NODE(h1);
+  PATTERN_DECL_NODE(gru2);
+  PATTERN_DECL_NODE(wx21);
+  PATTERN_DECL_NODE(wx22);
+  PATTERN_DECL_NODE(wh21);
+  PATTERN_DECL_NODE(wh22);
+  PATTERN_DECL_NODE(b21);
+  PATTERN_DECL_NODE(b22);
+  PATTERN_DECL_NODE(h2);
+};
+
 }  // namespace patterns
 
 // Link two ir::Nodes from each other.
diff --git a/paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass.cc
new file mode 100644
index 0000000000..105f812898
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass.cc
@@ -0,0 +1,139 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass.h"
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/platform/errors.h"
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#include "paddle/fluid/string/pretty_log.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+using EigenVectorArrayMap =
+    Eigen::Map<Eigen::Array<double, Eigen::Dynamic, 1>>;
+using string::PrettyLogDetail;
+
+namespace {
+
+std::vector<std::string> join_inputs(Node* op1, Node* op2,
+                                     std::string input_name) {
+  auto in1 = op1->Op()->Input(input_name);
+  auto& in2 = op2->Op()->Input(input_name);
+  in1.insert(in1.end(), in2.begin(), in2.end());
+  return in1;
+}
+
+}  // namespace
+
+void MultiGruSeqFusePass::ApplyImpl(ir::Graph* graph) const {
+  VLOG(3) << "Fusing two consecutive multi_gru ops.";
+  PADDLE_ENFORCE_NOT_NULL(graph,
+                          platform::errors::InvalidArgument(
+                              "Pointer to graph argument cannot be NULL."));
+  FusePassBase::Init(name_scope_, graph);
+  PADDLE_ENFORCE_NOT_NULL(param_scope(), platform::errors::InvalidArgument(
+                                             "Scope cannot be nullptr."));
+
+  GraphPatternDetector gpd;
+  patterns::MultiGruSeq pattern{gpd.mutable_pattern(), name_scope_};
+  pattern();
+
+  int fused_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    GET_IR_NODE_FROM_SUBGRAPH(x, x, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(gru1, gru1, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(wx11, wx11, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(wx12, wx12, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(wh11, wh11, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(wh12, wh12, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(b11, b11, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(b12, b12, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(h1, h1, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(gru2, gru2, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(wx21, wx21, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(wx22, wx22, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(wh21, wh21, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(wh22, wh22, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(b21, b21, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(b22, b22, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(h2, h2, pattern);
+
+    if (gru1->Op()->GetAttrIfExists<bool>("origin_mode") !=
+        gru2->Op()->GetAttrIfExists<bool>("origin_mode")) {
+      LOG(INFO) << "The two multi_gru ops have different values of the "
+                   "origin_mode attribute. Skipping fuse.";
+      return;
+    }
+
+    auto wx = join_inputs(gru1, gru2, "WeightX");
+    auto wh = join_inputs(gru1, gru2, "WeightH");
+    auto b = join_inputs(gru1, gru2, "Bias");
+
+    OpDesc multi_gru_desc;
+    multi_gru_desc.SetType("multi_gru");
+    multi_gru_desc.SetInput("X", std::vector<std::string>({x->Name()}));
+    multi_gru_desc.SetInput("WeightX", wx);
+    multi_gru_desc.SetInput("WeightH", wh);
+    multi_gru_desc.SetInput("Bias", b);
+    multi_gru_desc.SetOutput("Hidden", std::vector<std::string>({h2->Name()}));
+
+    for (auto& attr : gru1->Op()->GetAttrMap()) {
+      multi_gru_desc.SetAttr(attr.first, attr.second);
+    }
+
+    auto layers = BOOST_GET_CONST(int, gru1->Op()->GetAttr("layers")) +
+                  BOOST_GET_CONST(int, gru2->Op()->GetAttr("layers"));
+    multi_gru_desc.SetAttr("layers", layers);
+
+    auto multi_gru =
+        g->CreateOpNode(&multi_gru_desc);  // OpDesc will be copied.
+
+    IR_NODE_LINK_TO(x, multi_gru);
+    IR_NODE_LINK_TO(wx11, multi_gru);
+    IR_NODE_LINK_TO(wx12, multi_gru);
+    IR_NODE_LINK_TO(wx21, multi_gru);
+    IR_NODE_LINK_TO(wx22, multi_gru);
+    IR_NODE_LINK_TO(wh11, multi_gru);
+    IR_NODE_LINK_TO(wh12, multi_gru);
+    IR_NODE_LINK_TO(wh21, multi_gru);
+    IR_NODE_LINK_TO(wh22, multi_gru);
+    IR_NODE_LINK_TO(b11, multi_gru);
+    IR_NODE_LINK_TO(b12, multi_gru);
+    IR_NODE_LINK_TO(b21, multi_gru);
+    IR_NODE_LINK_TO(b22, multi_gru);
+    IR_NODE_LINK_TO(multi_gru, h2);
+
+    GraphSafeRemoveNodes(graph, {gru1, gru2, h1});
+
+    ++fused_count;
+  };
+  gpd(graph, handler);
+  AddStatis(fused_count);
+
+  PrettyLogDetail("--- fused %d sequences of two multi_gru ops",
+                  fused_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(multi_gru_seq_fuse_pass,
+              paddle::framework::ir::MultiGruSeqFusePass);
diff --git a/paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass.h
new file mode 100644
index 0000000000..546a3d6570
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass.h
@@ -0,0 +1,40 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class MultiGruSeqFusePass : public FusePassBase {
+ public:
+  virtual ~MultiGruSeqFusePass() {}
+
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+  const std::string name_scope_{"multi_gru_seq"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass_tester.cc
new file mode 100644
index 0000000000..3738e3ebd6
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass_tester.cc
@@ -0,0 +1,187 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass.h"
+#include <gtest/gtest.h>
+#include <initializer_list>
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+const std::vector<std::string> churn_out_vars(ProgramDesc* prog,
+                                              const std::string& prefix,
+                                              int number) {
+  auto v = std::vector<std::string>();
+  for (int i = 0; i < number; ++i) {
+    auto name = prefix + std::to_string(i);
+    prog->MutableBlock(0)->Var(name);
+    v.push_back(name);
+  }
+  return v;
+}
+
+void create_vars(ProgramDesc* prog,
+                 const std::initializer_list<std::string>& names) {
+  for (auto name : names) prog->MutableBlock(0)->Var(name);
+}
+
+void SetMultiGruOp(ProgramDesc* prog, const std::string x,
+                   const std::vector<std::string> wx,
+                   const std::vector<std::string> wh,
+                   const std::vector<std::string> b, const std::string h,
+                   int layers, bool origin_mode) {
+  auto* op = prog->MutableBlock(0)->AppendOp();
+  op->SetType("multi_gru");
+  op->SetInput("X", {x});
+  op->SetInput("WeightX", wx);
+  op->SetInput("WeightH", wh);
+  op->SetInput("Bias", b);
+  op->SetOutput("Hidden", {h});
+  op->SetAttr("layers", layers);
+  op->SetAttr("origin_mode", origin_mode);
+}
+
+// (x, wx1, wh1, b1) -> multi_gru1 -> h1
+// (h1, wx2, wh2, b2) -> multi_gru2 -> h2
+void MainTest(int layers1, int layers2, bool origin_mode1, bool origin_mode2) {
+  ProgramDesc prog;
+
+  // Create variables
+  create_vars(&prog, {"x", "h1", "h2"});
+  const std::vector<std::string> wx1 =
+      churn_out_vars(&prog, "wx1", 2 * layers1);
+  const std::vector<std::string> wx2 =
+      churn_out_vars(&prog, "wx2", 2 * layers2);
+  const std::vector<std::string> wh1 =
+      churn_out_vars(&prog, "wh1", 2 * layers1);
+  const std::vector<std::string> wh2 =
+      churn_out_vars(&prog, "wh2", 2 * layers2);
+  const std::vector<std::string> b1 = churn_out_vars(&prog, "b1", 2 * layers1);
+  const std::vector<std::string> b2 = churn_out_vars(&prog, "b2", 2 * layers2);
+
+  // Create program descriptor
+  SetMultiGruOp(&prog, "x", wx1, wh1, b1, "h1", layers1, origin_mode1);
+  SetMultiGruOp(&prog, "h1", wx2, wh2, b2, "h2", layers2, origin_mode2);
+
+  // Apply pass
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
+  Scope scope;
+  graph->SetNotOwned(kParamScopeAttr, &scope);
+  int original_nodes_num = graph->Nodes().size();
+  auto pass = PassRegistry::Instance().Get("multi_gru_seq_fuse_pass");
+  graph.reset(pass->Apply(graph.release()));
+  int current_nodes_num = graph->Nodes().size();
+
+  // Verify graph after fuse
+  bool should_fuse = origin_mode1 == origin_mode2;
+  int count_multi_gru = 0;
+  auto layers = layers1;
+  auto wx = wx1;
+  auto wh = wh1;
+  auto b = b1;
+  auto h = "h1";
+  if (should_fuse) {
+    layers += layers2;
+    wx.insert(wx.end(), wx2.begin(), wx2.end());
+    wh.insert(wh.end(), wh2.begin(), wh2.end());
+    b.insert(b.end(), b2.begin(), b2.end());
+    h = "h2";
+  }
+  for (auto* node : graph->Nodes()) {
+    if (node->IsOp()) {
+      auto* op = node->Op();
+      if (op->Type() == "multi_gru") {
+        if (op->Input("X")[0] == "x") {
+          EXPECT_EQ(op->GetAttrIfExists<int>("layers"), layers);
+          EXPECT_EQ(op->Input("WeightX").size(), 2u * layers);
+          EXPECT_EQ(op->Input("WeightH").size(), 2u * layers);
+          EXPECT_EQ(op->Input("Bias").size(), 2u * layers);
+          for (int i = 0; i < 2 * layers; ++i) {
+            EXPECT_EQ(op->Input("WeightX")[i], wx[i]);
+            EXPECT_EQ(op->Input("WeightH")[i], wh[i]);
+            EXPECT_EQ(op->Input("Bias")[i], b[i]);
+          }
+          EXPECT_EQ(op->Output("Hidden")[0], h);
+          EXPECT_EQ(op->GetAttrIfExists<bool>("origin_mode"), origin_mode1);
+        } else {
+          EXPECT_EQ(op->GetAttrIfExists<int>("layers"), layers2);
+          EXPECT_EQ(op->Input("X")[0], "h1");
+          EXPECT_EQ(op->Input("WeightX").size(), 2u * layers2);
+          EXPECT_EQ(op->Input("WeightH").size(), 2u * layers2);
+          EXPECT_EQ(op->Input("Bias").size(), 2u * layers2);
+          for (int i = 0; i < 2 * layers2; ++i) {
+            EXPECT_EQ(op->Input("WeightX")[i], wx2[i]);
+            EXPECT_EQ(op->Input("WeightH")[i], wh2[i]);
+            EXPECT_EQ(op->Input("Bias")[i], b2[i]);
+          }
+          EXPECT_EQ(op->Output("Hidden")[0], "h2");
+          EXPECT_EQ(op->GetAttrIfExists<bool>("origin_mode"), origin_mode2);
+        }
+        ++count_multi_gru;
+      }
+    }
+  }
+
+  // If the fuse is applied, then:
+  // nodes to be removed: 2x multi_gru + 1x hidden(output)
+  // nodes to be added: multi_gru
+  // If the fuse is not applied, then:
+  // nodes to be removed: none
+  // nodes to be added: none
+  const int removed_nodes_count = should_fuse ? 3 : 0;
+  const int added_nodes_count = should_fuse ? 1 : 0;
+
+  EXPECT_EQ(original_nodes_num - removed_nodes_count + added_nodes_count,
+            current_nodes_num);
+  EXPECT_EQ(count_multi_gru, should_fuse ? 1 : 2);
+}
+
+TEST(MultiGruSeqFusePass, same_origin_modes_1) {
+  int layers1 = 1;
+  int layers2 = 1;
+  bool origin_mode1 = false;
+  bool origin_mode2 = false;
+  MainTest(layers1, layers2, origin_mode1, origin_mode2);
+}
+
+TEST(MultiGruSeqFusePass, same_origin_modes_2) {
+  int layers1 = 2;
+  int layers2 = 3;
+  bool origin_mode1 = false;
+  bool origin_mode2 = false;
+  MainTest(layers1, layers2, origin_mode1, origin_mode2);
+}
+
+TEST(MultiGruSeqFusePass, same_origin_modes_3) {
+  int layers1 = 2;
+  int layers2 = 1;
+  bool origin_mode1 = true;
+  bool origin_mode2 = true;
+  MainTest(layers1, layers2, origin_mode1, origin_mode2);
+}
+
+TEST(MultiGruSeqFusePass, different_origin_modes) {
+  int layers1 = 2;
+  int layers2 = 2;
+  bool origin_mode1 = true;
+  bool origin_mode2 = false;
+  MainTest(layers1, layers2, origin_mode1, origin_mode2);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(multi_gru_seq_fuse_pass);
diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py
index 544c79fb13..2824fddc8f 100644
--- a/tools/static_mode_white_list.py
+++ b/tools/static_mode_white_list.py
@@ -603,6 +603,7 @@ STATIC_MODE_TESTING_LIST = [
     'test_matmul_bf16_mkldnn_op',
     'test_mul_int8_mkldnn_op',
     'test_multi_gru_mkldnn_op',
+    'test_multi_gru_seq_fuse_pass',
    'test_pool2d_int8_mkldnn_op',
     'test_pool2d_mkldnn_op',
     'test_quantize_mkldnn_op',
-- 
GitLab
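
Note for reviewers: the heart of the fuse is just list concatenation plus an attribute sum. The fused multi_gru keeps the first op's X, receives both ops' WeightX/WeightH/Bias lists back to back (via join_inputs above), outputs the second op's Hidden, and sets layers = layers1 + layers2. The standalone sketch below mirrors that arithmetic outside of Paddle; GruDesc, join, and try_fuse are illustrative stand-ins invented for this note, not Paddle APIs.

// Standalone sketch (illustrative only, not Paddle code) of the fuse
// arithmetic performed by multi_gru_seq_fuse_pass's handler.
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for the parts of OpDesc the pass touches.
struct GruDesc {
  std::vector<std::string> wx, wh, b;  // two entries per layer
  int layers;
  bool origin_mode;
};

// Mirrors join_inputs(): the second op's inputs follow the first op's.
std::vector<std::string> join(const std::vector<std::string>& a,
                              const std::vector<std::string>& b) {
  std::vector<std::string> out = a;
  out.insert(out.end(), b.begin(), b.end());
  return out;
}

// Mirrors the handler: fuse only if origin_mode matches; layers add up.
bool try_fuse(const GruDesc& g1, const GruDesc& g2, GruDesc* fused) {
  if (g1.origin_mode != g2.origin_mode) return false;  // skip, as in the pass
  fused->wx = join(g1.wx, g2.wx);
  fused->wh = join(g1.wh, g2.wh);
  fused->b = join(g1.b, g2.b);
  fused->layers = g1.layers + g2.layers;
  fused->origin_mode = g1.origin_mode;
  return true;
}

int main() {
  // A 1-layer and a 2-layer op; each layer carries two weight/bias entries,
  // matching the 2 * layers sizing asserted in the tester above.
  GruDesc g1{{"wx10", "wx11"}, {"wh10", "wh11"}, {"b10", "b11"}, 1, false};
  GruDesc g2{{"wx20", "wx21", "wx22", "wx23"},
             {"wh20", "wh21", "wh22", "wh23"},
             {"b20", "b21", "b22", "b23"},
             2,
             false};
  GruDesc fused;
  if (!try_fuse(g1, g2, &fused)) return 1;
  std::cout << "layers: " << fused.layers              // 3
            << ", WeightX size: " << fused.wx.size()   // 6 == 2 * layers
            << "\n";
  return 0;
}

A successful fuse removes three nodes (the two multi_gru ops and the intermediate hidden h1) and adds one fused op node, which is exactly the node-count check in the tester: original_nodes_num - 3 + 1 == current_nodes_num.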