Unverified · Commit 991345b3 · Authored by Wojciech Uss, committed by GitHub

Add multi_gru_seq_fuse_pass and tests (#28604)

* Add multi_gru_seq_fuse_pass and tests

* fix date

* removed unused functions
Parent 83cee3c9
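In brief, the pass replaces two consecutive multi_gru ops with a single multi_gru op: the WeightX, WeightH and Bias input lists of both ops are concatenated, their layers attributes are summed, and the intermediate Hidden variable is dropped. Schematically (notation follows the tester below):

(x, wx1, wh1, b1) -> multi_gru(layers1) -> h1
(h1, wx2, wh2, b2) -> multi_gru(layers2) -> h2
    becomes
(x, wx1+wx2, wh1+wh2, b1+b2) -> multi_gru(layers1+layers2) -> h2

The fuse is skipped when the two ops disagree on the origin_mode attribute.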
@@ -111,6 +111,7 @@ if(WITH_MKLDNN)
pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(batch_norm_act_fuse_pass inference DIR mkldnn)
pass_library(multi_gru_seq_fuse_pass inference DIR mkldnn)
endif()
cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector )
@@ -169,4 +170,5 @@ endif()
cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass)
cc_test(test_cpu_bfloat16_placement_pass SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc DEPS cpu_bfloat16_placement_pass)
cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass)
cc_test(test_multi_gru_seq_fuse_pass SRCS mkldnn/multi_gru_seq_fuse_pass_tester.cc DEPS multi_gru_seq_fuse_pass)
endif ()
@@ -2511,6 +2511,59 @@ PDNode *patterns::FusionGru::operator()() {
  return out;
}
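// Matches two consecutive multi_gru ops: the Hidden output of the first op
// (h1) is the X input of the second, and the first two WeightX, WeightH and
// Bias inputs of each op are captured (indices 0 and 1). h1 is marked as
// intermediate so that the fuse pass can remove it.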
PDNode *patterns::MultiGruSeq::operator()() {
auto x = pattern->NewNode(x_repr())->AsInput()->assert_is_op_input(
"multi_gru", "X");
auto gru1 = pattern->NewNode(gru1_repr())->assert_is_op("multi_gru");
auto wx11 = pattern->NewNode(wx11_repr())
->AsInput()
->assert_is_op_nth_input("multi_gru", "WeightX", 0);
auto wx12 = pattern->NewNode(wx12_repr())
->AsInput()
->assert_is_op_nth_input("multi_gru", "WeightX", 1);
auto wh11 = pattern->NewNode(wh11_repr())
->AsInput()
->assert_is_op_nth_input("multi_gru", "WeightH", 0);
auto wh12 = pattern->NewNode(wh12_repr())
->AsInput()
->assert_is_op_nth_input("multi_gru", "WeightH", 1);
auto b11 = pattern->NewNode(b11_repr())
->AsInput()
->assert_is_op_nth_input("multi_gru", "Bias", 0);
auto b12 = pattern->NewNode(b12_repr())
->AsInput()
->assert_is_op_nth_input("multi_gru", "Bias", 1);
auto h1 = pattern->NewNode(h1_repr())
->AsOutput()
->assert_is_op_output("multi_gru", "Hidden")
->assert_is_op_input("multi_gru", "X")
->AsIntermediate();
auto gru2 = pattern->NewNode(gru2_repr())->assert_is_op("multi_gru");
auto wx21 = pattern->NewNode(wx21_repr())
->AsInput()
->assert_is_op_nth_input("multi_gru", "WeightX", 0);
auto wx22 = pattern->NewNode(wx22_repr())
->AsInput()
->assert_is_op_nth_input("multi_gru", "WeightX", 1);
auto wh21 = pattern->NewNode(wh21_repr())
->AsInput()
->assert_is_op_nth_input("multi_gru", "WeightH", 0);
auto wh22 = pattern->NewNode(wh22_repr())
->AsInput()
->assert_is_op_nth_input("multi_gru", "WeightH", 1);
auto b21 = pattern->NewNode(b21_repr())
->AsInput()
->assert_is_op_nth_input("multi_gru", "Bias", 0);
auto b22 = pattern->NewNode(b22_repr())
->AsInput()
->assert_is_op_nth_input("multi_gru", "Bias", 1);
auto h2 = pattern->NewNode(h2_repr())->AsOutput()->assert_is_op_output(
"multi_gru", "Hidden");
gru1->LinksFrom({x, wx11, wx12, wh11, wh12, b11, b12}).LinksTo({h1});
gru2->LinksFrom({h1, wx21, wx22, wh21, wh22, b21, b22}).LinksTo({h2});
return h2;
}
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -1420,6 +1420,33 @@ struct FusionGru : public PatternBase {
  PATTERN_DECL_NODE(out);
};
// Pattern of two subsequent multi_gru ops.
// The Hidden output of the first multi_gru op feeds the X input of the second;
// the Hidden output of the second multi_gru op is the result of operator().
struct MultiGruSeq : public PatternBase {
MultiGruSeq(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "multi_gru_seq") {}
PDNode* operator()();
PATTERN_DECL_NODE(x);
PATTERN_DECL_NODE(gru1);
PATTERN_DECL_NODE(wx11);
PATTERN_DECL_NODE(wx12);
PATTERN_DECL_NODE(wh11);
PATTERN_DECL_NODE(wh12);
PATTERN_DECL_NODE(b11);
PATTERN_DECL_NODE(b12);
PATTERN_DECL_NODE(h1);
PATTERN_DECL_NODE(gru2);
PATTERN_DECL_NODE(wx21);
PATTERN_DECL_NODE(wx22);
PATTERN_DECL_NODE(wh21);
PATTERN_DECL_NODE(wh22);
PATTERN_DECL_NODE(b21);
PATTERN_DECL_NODE(b22);
PATTERN_DECL_NODE(h2);
};
} // namespace patterns
// Link two ir::Nodes from each other.
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass.h"
#include <limits>
#include <sstream>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<double, Eigen::Dynamic, 1>>;
using string::PrettyLogDetail;
namespace {
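// Returns the input variable names of `input_name` (e.g. "WeightX") from both
// ops concatenated into one list, with op1's entries first.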
std::vector<std::string> join_inputs(Node* op1, Node* op2,
std::string input_name) {
auto in1 = op1->Op()->Input(input_name);
auto& in2 = op2->Op()->Input(input_name);
in1.insert(in1.end(), in2.begin(), in2.end());
return in1;
}
} // namespace
void MultiGruSeqFusePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Fusing two consecutive multi_gru ops.";
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument cannot be NULL."));
FusePassBase::Init(name_scope_, graph);
PADDLE_ENFORCE_NOT_NULL(param_scope(), platform::errors::InvalidArgument(
"Scope cannot be nullptr."));
GraphPatternDetector gpd;
patterns::MultiGruSeq pattern{gpd.mutable_pattern(), name_scope_};
pattern();
int fused_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(x, x, pattern);
GET_IR_NODE_FROM_SUBGRAPH(gru1, gru1, pattern);
GET_IR_NODE_FROM_SUBGRAPH(wx11, wx11, pattern);
GET_IR_NODE_FROM_SUBGRAPH(wx12, wx12, pattern);
GET_IR_NODE_FROM_SUBGRAPH(wh11, wh11, pattern);
GET_IR_NODE_FROM_SUBGRAPH(wh12, wh12, pattern);
GET_IR_NODE_FROM_SUBGRAPH(b11, b11, pattern);
GET_IR_NODE_FROM_SUBGRAPH(b12, b12, pattern);
GET_IR_NODE_FROM_SUBGRAPH(h1, h1, pattern);
GET_IR_NODE_FROM_SUBGRAPH(gru2, gru2, pattern);
GET_IR_NODE_FROM_SUBGRAPH(wx21, wx21, pattern);
GET_IR_NODE_FROM_SUBGRAPH(wx22, wx22, pattern);
GET_IR_NODE_FROM_SUBGRAPH(wh21, wh21, pattern);
GET_IR_NODE_FROM_SUBGRAPH(wh22, wh22, pattern);
GET_IR_NODE_FROM_SUBGRAPH(b21, b21, pattern);
GET_IR_NODE_FROM_SUBGRAPH(b22, b22, pattern);
GET_IR_NODE_FROM_SUBGRAPH(h2, h2, pattern);
if (gru1->Op()->GetAttrIfExists<bool>("origin_mode") !=
gru2->Op()->GetAttrIfExists<bool>("origin_mode")) {
LOG(INFO) << "The two multi_gru ops have different values of the "
"origin_mode attribute. Skipping fuse.";
return;
}
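    // Concatenate the WeightX, WeightH and Bias input lists of both ops; the
    // fused op reuses gru1's attributes except for "layers", which becomes the
    // sum of both ops' layer counts.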
auto wx = join_inputs(gru1, gru2, "WeightX");
auto wh = join_inputs(gru1, gru2, "WeightH");
auto b = join_inputs(gru1, gru2, "Bias");
OpDesc multi_gru_desc;
multi_gru_desc.SetType("multi_gru");
multi_gru_desc.SetInput("X", std::vector<std::string>({x->Name()}));
multi_gru_desc.SetInput("WeightX", wx);
multi_gru_desc.SetInput("WeightH", wh);
multi_gru_desc.SetInput("Bias", b);
multi_gru_desc.SetOutput("Hidden", std::vector<std::string>({h2->Name()}));
for (auto& attr : gru1->Op()->GetAttrMap()) {
multi_gru_desc.SetAttr(attr.first, attr.second);
}
auto layers = BOOST_GET_CONST(int, gru1->Op()->GetAttr("layers")) +
BOOST_GET_CONST(int, gru2->Op()->GetAttr("layers"));
multi_gru_desc.SetAttr("layers", layers);
auto multi_gru =
g->CreateOpNode(&multi_gru_desc); // OpDesc will be copied.
IR_NODE_LINK_TO(x, multi_gru);
IR_NODE_LINK_TO(wx11, multi_gru);
IR_NODE_LINK_TO(wx12, multi_gru);
IR_NODE_LINK_TO(wx21, multi_gru);
IR_NODE_LINK_TO(wx22, multi_gru);
IR_NODE_LINK_TO(wh11, multi_gru);
IR_NODE_LINK_TO(wh12, multi_gru);
IR_NODE_LINK_TO(wh21, multi_gru);
IR_NODE_LINK_TO(wh22, multi_gru);
IR_NODE_LINK_TO(b11, multi_gru);
IR_NODE_LINK_TO(b12, multi_gru);
IR_NODE_LINK_TO(b21, multi_gru);
IR_NODE_LINK_TO(b22, multi_gru);
IR_NODE_LINK_TO(multi_gru, h2);
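    // gru1, gru2 and the intermediate hidden variable h1 are replaced by the
    // fused op; all remaining variable nodes were relinked to it above.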
GraphSafeRemoveNodes(graph, {gru1, gru2, h1});
++fused_count;
};
gpd(graph, handler);
AddStatis(fused_count);
PrettyLogDetail("--- fused %d sequences of two multi_gru ops",
fused_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(multi_gru_seq_fuse_pass,
paddle::framework::ir::MultiGruSeqFusePass);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class MultiGruSeqFusePass : public FusePassBase {
public:
virtual ~MultiGruSeqFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"multi_gru_seq"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/multi_gru_seq_fuse_pass.h"
#include <gtest/gtest.h>
#include <initializer_list>
namespace paddle {
namespace framework {
namespace ir {
const std::vector<std::string> churn_out_vars(ProgramDesc* prog,
const std::string& prefix,
int number) {
auto v = std::vector<std::string>();
for (int i = 0; i < number; ++i) {
auto name = prefix + std::to_string(i);
prog->MutableBlock(0)->Var(name);
v.push_back(name);
}
return v;
}
void create_vars(ProgramDesc* prog,
const std::initializer_list<std::string>& names) {
for (auto name : names) prog->MutableBlock(0)->Var(name);
}
void SetMultiGruOp(ProgramDesc* prog, const std::string x,
const std::vector<std::string> wx,
const std::vector<std::string> wh,
const std::vector<std::string> b, const std::string h,
int layers, bool origin_mode) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType("multi_gru");
op->SetInput("X", {x});
op->SetInput("WeightX", wx);
op->SetInput("WeightH", wh);
op->SetInput("Bias", b);
op->SetOutput("Hidden", {h});
op->SetAttr("layers", layers);
op->SetAttr("origin_mode", origin_mode);
}
// (x, wx1, wh1, b1) -> multi_gru1 -> h1
// (h1, wx2, wh2, b2) -> multi_gru2 -> h2
void MainTest(int layers1, int layers2, bool origin_mode1, bool origin_mode2) {
ProgramDesc prog;
// Create variables
create_vars(&prog, {"x", "h1", "h2"});
const std::vector<std::string> wx1 =
churn_out_vars(&prog, "wx1", 2 * layers1);
const std::vector<std::string> wx2 =
churn_out_vars(&prog, "wx2", 2 * layers2);
const std::vector<std::string> wh1 =
churn_out_vars(&prog, "wh1", 2 * layers1);
const std::vector<std::string> wh2 =
churn_out_vars(&prog, "wh2", 2 * layers2);
const std::vector<std::string> b1 = churn_out_vars(&prog, "b1", 2 * layers1);
const std::vector<std::string> b2 = churn_out_vars(&prog, "b2", 2 * layers2);
// Create program descriptor
SetMultiGruOp(&prog, "x", wx1, wh1, b1, "h1", layers1, origin_mode1);
SetMultiGruOp(&prog, "h1", wx2, wh2, b2, "h2", layers2, origin_mode2);
// Apply pass
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
Scope scope;
graph->SetNotOwned(kParamScopeAttr, &scope);
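  // The pass requires a parameter scope attached to the graph (checked by
  // PADDLE_ENFORCE_NOT_NULL in ApplyImpl).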
int original_nodes_num = graph->Nodes().size();
auto pass = PassRegistry::Instance().Get("multi_gru_seq_fuse_pass");
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
// Verify graph after fuse
bool should_fuse = origin_mode1 == origin_mode2;
int count_multi_gru = 0;
auto layers = layers1;
auto wx = wx1;
auto wh = wh1;
auto b = b1;
auto h = "h1";
if (should_fuse) {
layers += layers2;
wx.insert(wx.end(), wx2.begin(), wx2.end());
wh.insert(wh.end(), wh2.begin(), wh2.end());
b.insert(b.end(), b2.begin(), b2.end());
h = "h2";
}
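  // After a successful fuse exactly one multi_gru op should read "x" and carry
  // the combined inputs; otherwise both original ops must remain unchanged.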
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op = node->Op();
if (op->Type() == "multi_gru") {
if (op->Input("X")[0] == "x") {
EXPECT_EQ(op->GetAttrIfExists<int>("layers"), layers);
EXPECT_EQ(op->Input("WeightX").size(), 2u * layers);
EXPECT_EQ(op->Input("WeightH").size(), 2u * layers);
EXPECT_EQ(op->Input("Bias").size(), 2u * layers);
for (int i = 0; i < 2 * layers; ++i) {
EXPECT_EQ(op->Input("WeightX")[i], wx[i]);
EXPECT_EQ(op->Input("WeightH")[i], wh[i]);
EXPECT_EQ(op->Input("Bias")[i], b[i]);
}
EXPECT_EQ(op->Output("Hidden")[0], h);
EXPECT_EQ(op->GetAttrIfExists<bool>("origin_mode"), origin_mode1);
} else {
EXPECT_EQ(op->GetAttrIfExists<int>("layers"), layers2);
EXPECT_EQ(op->Input("X")[0], "h1");
EXPECT_EQ(op->Input("WeightX").size(), 2u * layers2);
EXPECT_EQ(op->Input("WeightH").size(), 2u * layers2);
EXPECT_EQ(op->Input("Bias").size(), 2u * layers2);
for (int i = 0; i < 2 * layers2; ++i) {
EXPECT_EQ(op->Input("WeightX")[i], wx2[i]);
EXPECT_EQ(op->Input("WeightH")[i], wh2[i]);
EXPECT_EQ(op->Input("Bias")[i], b2[i]);
}
EXPECT_EQ(op->Output("Hidden")[0], "h2");
EXPECT_EQ(op->GetAttrIfExists<bool>("origin_mode"), origin_mode2);
}
++count_multi_gru;
}
}
}
// If the fuse is applied, then:
// nodes to be removed: 2x multi_gru + 1x hidden(output)
// nodes to be added: multi_gru
// If the fuse is not applied, then:
// nodes to be removed: none
// nodes to be added: none
const int removed_nodes_count = should_fuse ? 3 : 0;
const int added_nodes_count = should_fuse ? 1 : 0;
EXPECT_EQ(original_nodes_num - removed_nodes_count + added_nodes_count,
current_nodes_num);
EXPECT_EQ(count_multi_gru, should_fuse ? 1 : 2);
}
TEST(MultiGruSeqFusePass, same_origin_modes_1) {
int layers1 = 1;
int layers2 = 1;
bool origin_mode1 = false;
bool origin_mode2 = false;
MainTest(layers1, layers2, origin_mode1, origin_mode2);
}
TEST(MultiGruSeqFusePass, same_origin_modes_2) {
int layers1 = 2;
int layers2 = 3;
bool origin_mode1 = false;
bool origin_mode2 = false;
MainTest(layers1, layers2, origin_mode1, origin_mode2);
}
TEST(MultiGruSeqFusePass, same_origin_modes_3) {
int layers1 = 2;
int layers2 = 1;
bool origin_mode1 = true;
bool origin_mode2 = true;
MainTest(layers1, layers2, origin_mode1, origin_mode2);
}
TEST(MultiGruSeqFusePass, different_origin_modes) {
int layers1 = 2;
int layers2 = 2;
bool origin_mode1 = true;
bool origin_mode2 = false;
MainTest(layers1, layers2, origin_mode1, origin_mode2);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(multi_gru_seq_fuse_pass);
@@ -603,6 +603,7 @@ STATIC_MODE_TESTING_LIST = [
'test_matmul_bf16_mkldnn_op',
'test_mul_int8_mkldnn_op',
'test_multi_gru_mkldnn_op',
'test_multi_gru_seq_fuse_pass',
'test_pool2d_int8_mkldnn_op',
'test_pool2d_mkldnn_op',
'test_quantize_mkldnn_op',
......