Unverified commit cd95ea82, authored by joanna.wozna.intel, committed by GitHub

Small fixes related to BF16 fusion_gru and fusion_lstm (#33295)

* Small changes related to BF16 fusion_gru and fusion_lstm

* Correct to pass arg by value

* Add conditions to rnn op

* Correct the spelling mistake

* Improve the test with activation checks

* Trigger CI
Parent abc17ef7
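For orientation: the fusion_gru / fusion_lstm ops that this change gates on use_mkldnn are the ones exercised when oneDNN BF16 inference is enabled in the analysis config. A minimal sketch of such a config is below; the model path is a placeholder and the EnableMkldnnBfloat16 switch is assumed to be available in this branch, as in recent Paddle releases:

    #include "paddle/fluid/inference/api/paddle_inference_api.h"

    paddle::AnalysisConfig MakeBf16CpuConfig() {
      paddle::AnalysisConfig cfg;
      cfg.SetModel("/path/to/model");  // placeholder model directory
      cfg.EnableMKLDNN();              // run oneDNN CPU kernels
      cfg.EnableMkldnnBfloat16();      // let BF16-capable ops (e.g. fusion_gru,
                                       // fusion_lstm after this change) run in BF16
      return cfg;
    }

With such a config, the fuse passes below only mark the fused op with use_mkldnn=true when its activations match what the oneDNN kernels support.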
@@ -188,4 +188,6 @@ endif()
cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass)
cc_test(test_multi_gru_fuse_pass SRCS mkldnn/multi_gru_fuse_pass_tester.cc DEPS multi_gru_fuse_pass)
cc_test(test_multi_gru_seq_fuse_pass SRCS mkldnn/multi_gru_seq_fuse_pass_tester.cc DEPS multi_gru_seq_fuse_pass)
+ set(TEST_FC_RNN_PASS_DEPS fc_gru_fuse_pass fc_lstm_fuse_pass mkldnn_placement_pass)
+ cc_test(test_fc_rnn_mkldnn_fuse_pass SRCS mkldnn/mkldnn_fc_rnn_fuse_pass_tester.cc DEPS ${TEST_FC_RNN_PASS_DEPS})
endif ()
@@ -47,8 +47,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
  gru_pattern(fc_out);
  // Create New OpDesc
- auto gru_creater = [&](Node* gru, Node* x, Node* weight_x, Node* weight_h,
-                        Node* bias, Node* hidden, Node* fc_bias) {
+ auto gru_creator = [&](Node* gru, Node* x, Node* weight_x, Node* weight_h,
+                        Node* bias, Node* hidden, Node* fc_bias,
+                        const bool use_mkldnn) {
    OpDesc op_desc;
    op_desc.SetType("fusion_gru");
@@ -67,6 +68,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
                   gru->Op()->GetAttrIfExists<bool>("origin_mode"));
    // TODO(TJ): This should be a option for infer
    op_desc.SetAttr("use_seq", true);
+   op_desc.SetAttr("use_mkldnn", use_mkldnn);
    op_desc.SetAttr("activation", gru->Op()->GetAttr("activation"));
    op_desc.SetAttr("gate_activation", gru->Op()->GetAttr("gate_activation"));
@@ -149,6 +151,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
      LOG(INFO) << "fc_gru_fuse_pass not supported when origin_mode=True.";
      return;
    }
+   const bool use_mkldnn =
+       (mul->Op()->GetAttrIfExists<bool>("use_mkldnn") &&
+        gru->Op()->GetAttrIfExists<std::string>("activation") == "tanh" &&
+        gru->Op()->GetAttrIfExists<std::string>("gate_activation") ==
+            "sigmoid");
    if (with_fc_bias) {
      GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
@@ -156,14 +163,14 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
      GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
      GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern);
-     gru_creater(gru, x_n, w, Weight, Bias, Hidden, fc_bias);
+     gru_creator(gru, x_n, w, Weight, Bias, Hidden, fc_bias, use_mkldnn);
      // Remove unneeded nodes.
      std::unordered_set<const Node*> marked_nodes(
          {mul, gru, elementwise_add, fc_out, mul_out, BatchGate,
           BatchResetHiddenPrev, BatchHidden});
      GraphSafeRemoveNodes(graph, marked_nodes);
    } else {
-     gru_creater(gru, x_n, w, Weight, Bias, Hidden, nullptr);
+     gru_creator(gru, x_n, w, Weight, Bias, Hidden, nullptr, use_mkldnn);
      // Remove unneeded nodes.
      std::unordered_set<const Node*> marked_nodes(
          {mul, gru, BatchGate, BatchResetHiddenPrev, BatchHidden});
......
@@ -12,77 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
- #include "paddle/fluid/framework/ir/fc_gru_fuse_pass.h"
- #include <gtest/gtest.h>
- #include "paddle/fluid/framework/ir/pass_tester_helper.h"
+ #include "paddle/fluid/framework/ir/fc_gru_fuse_pass_tester.h"
namespace paddle {
namespace framework {
namespace ir {
- void AddVarToScope(Scope* param_scope, const std::string& name,
-                    const DDim& dims) {
-   auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
-   tensor->Resize(dims);
-   tensor->mutable_data<float>(platform::CPUPlace());
- }
- Scope* CreateParamScope() {
-   auto param_scope = new Scope();
-   AddVarToScope(param_scope, "gru_fc_w", {});
-   AddVarToScope(param_scope, "gru_fc_b", {});
-   AddVarToScope(param_scope, "gru_w", {});
-   AddVarToScope(param_scope, "gru_b", {});
-   AddVarToScope(param_scope, "gru_batch_gate_0", {});
-   AddVarToScope(param_scope, "gru_batch_reset_hidden_prev_0", {});
-   AddVarToScope(param_scope, "gru_batch_hidden_0", {});
-   AddVarToScope(param_scope, "gru_hidden_0", {});
-   AddVarToScope(param_scope, "gru_batch_gate_1", {});
-   AddVarToScope(param_scope, "gru_batch_reset_hidden_prev_1", {});
-   AddVarToScope(param_scope, "gru_batch_hidden_1", {});
-   AddVarToScope(param_scope, "gru_hidden_1", {});
-   return param_scope;
- }
- TEST(FCFusePass, basic) {
-   // inputs                      operator         output
-   // --------------------------------------------------------
-   // (a, gru_fc_w)               mul              -> fc_0_tmp_0
-   // (fc_0_tmp_0, gru_fc_b)      elementwise_add  -> fc_0_tmp_1
-   // (fc_0_tmp_1, gru_w, gru_b)  gru              -> gru_out_0
-   // (b, gru_fc_w)               mul              -> fc_1_tmp_0
-   // (fc_1_tmp_0, gru_fc_b)      elementwise_add  -> fc_1_tmp_1
-   // (fc_1_tmp_1, gru_w, gru_b)  gru              -> gru_out_1
-   Layers layers;
-   auto* a = layers.data("a");
-   auto* b = layers.data("b");
-   auto* fc_w = layers.data("gru_fc_w", {}, true);
-   auto* fc_b = layers.data("gru_fc_b", {}, true);
-   auto* gru_w = layers.data("gru_w", {}, true);
-   auto* gru_b = layers.data("gru_b", {}, true);
-   auto* fc_0_tmp0 = layers.mul(a, fc_w);
-   auto* fc_0_tmp1 = layers.elementwise_add(fc_0_tmp0, fc_b);
-   auto* gru_batch_gate_0 = layers.data("gru_batch_gate_0", {}, false);
-   auto* gru_batch_reset_hidden_prev_0 =
-       layers.data("gru_batch_reset_hidden_prev_0", {}, false);
-   auto* gru_batch_hidden_0 = layers.data("gru_batch_hidden_0", {}, false);
-   auto* gru_hidden_0 = layers.data("gru_hidden_0", {}, false);
-   layers.gru(fc_0_tmp1, gru_w, gru_b, gru_batch_gate_0,
-              gru_batch_reset_hidden_prev_0, gru_batch_hidden_0, gru_hidden_0);
-   auto* fc_1_tmp0 = layers.mul(b, fc_w);
-   auto* fc_1_tmp1 = layers.elementwise_add(fc_1_tmp0, fc_b);
-   auto* gru_batch_gate_1 = layers.data("gru_batch_gate_1", {}, false);
-   auto* gru_batch_reset_hidden_prev_1 =
-       layers.data("gru_batch_reset_hidden_prev_1", {}, false);
-   auto* gru_batch_hidden_1 = layers.data("gru_batch_hidden_1", {}, false);
-   auto* gru_hidden_1 = layers.data("gru_hidden_1", {}, false);
-   layers.gru(fc_1_tmp1, gru_w, gru_b, gru_batch_gate_1,
-              gru_batch_reset_hidden_prev_1, gru_batch_hidden_1, gru_hidden_1);
-   std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+ namespace fc_gru_test {
+ TEST(FcGruFusePass, basic) {
+   std::unique_ptr<ir::Graph> graph = PrepareGraph();
  auto pass = PassRegistry::Instance().Get("fc_gru_fuse_pass");
  pass->Set("use_gpu", new bool(true));
  graph->Set("__param_scope__", CreateParamScope());
@@ -109,6 +47,7 @@ TEST(FCFusePass, basic) {
                      "expectations after fuse"));
}
+ }  // namespace fc_gru_test
}  // namespace ir
}  // namespace framework
}  // namespace paddle
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fc_gru_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
namespace fc_gru_test {
void AddVarToScope(Scope* param_scope, const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(platform::CPUPlace());
}
Scope* CreateParamScope() {
auto param_scope = new Scope();
AddVarToScope(param_scope, "gru_fc_w", {});
AddVarToScope(param_scope, "gru_fc_b", {});
AddVarToScope(param_scope, "gru_w", {});
AddVarToScope(param_scope, "gru_b", {});
AddVarToScope(param_scope, "gru_batch_gate_0", {});
AddVarToScope(param_scope, "gru_batch_reset_hidden_prev_0", {});
AddVarToScope(param_scope, "gru_batch_hidden_0", {});
AddVarToScope(param_scope, "gru_hidden_0", {});
AddVarToScope(param_scope, "gru_batch_gate_1", {});
AddVarToScope(param_scope, "gru_batch_reset_hidden_prev_1", {});
AddVarToScope(param_scope, "gru_batch_hidden_1", {});
AddVarToScope(param_scope, "gru_hidden_1", {});
return param_scope;
}
std::unique_ptr<ir::Graph> PrepareGraph(
std::string activation = "tanh", std::string gate_activation = "sigmoid") {
// inputs                      operator         output
// --------------------------------------------------------
// (a, gru_fc_w)               mul              -> fc_0_tmp_0
// (fc_0_tmp_0, gru_fc_b)      elementwise_add  -> fc_0_tmp_1
// (fc_0_tmp_1, gru_w, gru_b)  gru              -> gru_out_0
// (b, gru_fc_w)               mul              -> fc_1_tmp_0
// (fc_1_tmp_0, gru_fc_b)      elementwise_add  -> fc_1_tmp_1
// (fc_1_tmp_1, gru_w, gru_b)  gru              -> gru_out_1
Layers layers;
auto* a = layers.data("a");
auto* b = layers.data("b");
auto* fc_w = layers.data("gru_fc_w", {}, true);
auto* fc_b = layers.data("gru_fc_b", {}, true);
auto* gru_w = layers.data("gru_w", {}, true);
auto* gru_b = layers.data("gru_b", {}, true);
auto* fc_0_tmp0 = layers.mul(a, fc_w);
auto* fc_0_tmp1 = layers.elementwise_add(fc_0_tmp0, fc_b);
auto* gru_batch_gate_0 = layers.data("gru_batch_gate_0", {}, false);
auto* gru_batch_reset_hidden_prev_0 =
layers.data("gru_batch_reset_hidden_prev_0", {}, false);
auto* gru_batch_hidden_0 = layers.data("gru_batch_hidden_0", {}, false);
auto* gru_hidden_0 = layers.data("gru_hidden_0", {}, false);
layers.gru(fc_0_tmp1, gru_w, gru_b, gru_batch_gate_0,
gru_batch_reset_hidden_prev_0, gru_batch_hidden_0, gru_hidden_0,
nullptr, false, false, activation, gate_activation);
auto* fc_1_tmp0 = layers.mul(b, fc_w);
auto* fc_1_tmp1 = layers.elementwise_add(fc_1_tmp0, fc_b);
auto* gru_batch_gate_1 = layers.data("gru_batch_gate_1", {}, false);
auto* gru_batch_reset_hidden_prev_1 =
layers.data("gru_batch_reset_hidden_prev_1", {}, false);
auto* gru_batch_hidden_1 = layers.data("gru_batch_hidden_1", {}, false);
auto* gru_hidden_1 = layers.data("gru_hidden_1", {}, false);
layers.gru(fc_1_tmp1, gru_w, gru_b, gru_batch_gate_1,
gru_batch_reset_hidden_prev_1, gru_batch_hidden_1, gru_hidden_1,
nullptr, false, false, activation, gate_activation);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
return std::move(graph);
}
} // namespace fc_gru_test
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -47,7 +47,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
  // Create New OpDesc
  auto lstm_creator = [&](Node* lstm, Node* input, Node* weight_x,
                          Node* weight_h, Node* bias, Node* hidden, Node* cell,
-                         Node* xx, Node* fc_bias) {
+                         Node* xx, Node* fc_bias, const bool use_mkldnn) {
    OpDesc op_desc;
    op_desc.SetType("fusion_lstm");
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
@@ -88,6 +88,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
    op_desc.SetOutput("XX", {xx->Name()});
    op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
    op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
+   op_desc.SetAttr("use_mkldnn", use_mkldnn);
    // TODO(TJ): get from attr
    op_desc.SetAttr("use_seq", true);
@@ -148,13 +149,22 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
    GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
+   const bool use_mkldnn =
+       (mul->Op()->GetAttrIfExists<bool>("use_mkldnn") &&
+        lstm->Op()->GetAttrIfExists<std::string>("gate_activation") ==
+            "sigmoid" &&
+        lstm->Op()->GetAttrIfExists<std::string>("cell_activation") ==
+            "tanh" &&
+        lstm->Op()->GetAttrIfExists<std::string>("candidate_activation") ==
+            "tanh");
    if (with_fc_bias) {
      GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern);
      GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
      GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
      GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
      lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out,
-                  fc_bias);
+                  fc_bias, use_mkldnn);
      // Remove unneeded nodes.
      std::unordered_set<const Node*> marked_nodes(
          {mul, lstm, elementwise_add, mul_out, BatchGate, BatchCellPreAct});
@@ -162,7 +172,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
    } else {
      GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);
      lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out,
-                  nullptr);
+                  nullptr, use_mkldnn);
      // Remove unneeded nodes.
      std::unordered_set<const Node*> marked_nodes(
          {mul, lstm, BatchGate, BatchCellPreAct});
......
@@ -12,77 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
- #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
- #include <gtest/gtest.h>
- #include "paddle/fluid/framework/ir/pass_tester_helper.h"
+ #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass_tester.h"
namespace paddle {
namespace framework {
namespace ir {
- void AddVarToScope(Scope* param_scope, const std::string& name,
-                    const DDim& dims) {
-   auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
-   tensor->Resize(dims);
-   tensor->mutable_data<float>(platform::CPUPlace());
- }
- Scope* CreateParamScope() {
-   auto param_scope = new Scope();
-   AddVarToScope(param_scope, "lstm_fc_w", {});
-   AddVarToScope(param_scope, "lstm_fc_b", {});
-   AddVarToScope(param_scope, "lstm_w", {});
-   AddVarToScope(param_scope, "lstm_b", {});
-   AddVarToScope(param_scope, "lstm_cell_0", {});
-   AddVarToScope(param_scope, "lstm_batch_gate_0", {});
-   AddVarToScope(param_scope, "lstm_batch_cell_pre_gate_0", {});
-   AddVarToScope(param_scope, "lstm_hidden_0", {});
-   AddVarToScope(param_scope, "lstm_cell_1", {});
-   AddVarToScope(param_scope, "lstm_batch_gate_1", {});
-   AddVarToScope(param_scope, "lstm_batch_cell_pre_gate_1", {});
-   AddVarToScope(param_scope, "lstm_hidden_1", {});
-   return param_scope;
- }
- TEST(FCLSTMFusePass, basic) {
-   // inputs                        operator         output
-   // --------------------------------------------------------
-   // (a, lstm_fc_w)                mul              -> fc_0_tmp_0
-   // (fc_0_tmp_0, lstm_fc_b)       elementwise_add  -> fc_0_tmp_1
-   // (fc_0_tmp_1, lstm_w, lstm_b)  lstm             -> lstm_out_0
-   // (b, lstm_fc_w)                mul              -> fc_1_tmp_0
-   // (fc_1_tmp_0, lstm_fc_b)       elementwise_add  -> fc_1_tmp_1
-   // (fc_1_tmp_1, lstm_w, lstm_b)  lstm             -> lstm_out_1
-   Layers layers;
-   auto* a = layers.data("a");
-   auto* b = layers.data("b");
-   auto* fc_w = layers.data("lstm_fc_w", {}, true);
-   auto* fc_b = layers.data("lstm_fc_b", {}, true);
-   auto* lstm_w = layers.data("lstm_w", {}, true);
-   auto* lstm_b = layers.data("lstm_b", {}, true);
-   auto* fc_0_tmp0 = layers.mul(a, fc_w);
-   auto* fc_0_tmp1 = layers.elementwise_add(fc_0_tmp0, fc_b);
-   auto* lstm_cell_0 = layers.data("lstm_cell_0", {}, false);
-   auto* lstm_batch_gate_0 = layers.data("lstm_batch_gate_0", {}, false);
-   auto* lstm_batch_cell_pre_gate_0 =
-       layers.data("lstm_batch_cell_pre_gate_0", {}, false);
-   auto* lstm_hidden_0 = layers.data("lstm_hidden_0", {}, false);
-   layers.lstm(fc_0_tmp1, lstm_w, lstm_b, lstm_cell_0, lstm_batch_gate_0,
-               lstm_hidden_0, lstm_batch_cell_pre_gate_0);
-   auto* fc_1_tmp0 = layers.mul(b, fc_w);
-   auto* fc_1_tmp1 = layers.elementwise_add(fc_1_tmp0, fc_b);
-   auto* lstm_cell_1 = layers.data("lstm_cell_1", {}, false);
-   auto* lstm_batch_gate_1 = layers.data("lstm_batch_gate_1", {}, false);
-   auto* lstm_batch_cell_pre_gate_1 =
-       layers.data("lstm_batch_cell_pre_gate_1", {}, false);
-   auto* lstm_hidden_1 = layers.data("lstm_hidden_1", {}, false);
-   layers.lstm(fc_1_tmp1, lstm_w, lstm_b, lstm_cell_1, lstm_batch_gate_1,
-               lstm_hidden_1, lstm_batch_cell_pre_gate_1);
-   std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+ namespace fc_lstm_test {
+ TEST(FcLstmFusePass, basic) {
+   std::unique_ptr<ir::Graph> graph = PrepareGraph();
  auto pass = PassRegistry::Instance().Get("fc_lstm_fuse_pass");
  pass->Set("use_gpu", new bool(false));
  graph->Set("__param_scope__", CreateParamScope());
@@ -108,7 +47,7 @@ TEST(FCLSTMFusePass, basic) {
                      "The number of fusion_gru nodes does "
                      "not meet expectations after fuse"));
}
+ }  // namespace fc_lstm_test
}  // namespace ir
}  // namespace framework
}  // namespace paddle
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
namespace fc_lstm_test {
void AddVarToScope(Scope* param_scope, const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(platform::CPUPlace());
}
Scope* CreateParamScope() {
auto param_scope = new Scope();
AddVarToScope(param_scope, "lstm_fc_w", {});
AddVarToScope(param_scope, "lstm_fc_b", {});
AddVarToScope(param_scope, "lstm_w", {});
AddVarToScope(param_scope, "lstm_b", {});
AddVarToScope(param_scope, "lstm_cell_0", {});
AddVarToScope(param_scope, "lstm_batch_gate_0", {});
AddVarToScope(param_scope, "lstm_batch_cell_pre_gate_0", {});
AddVarToScope(param_scope, "lstm_hidden_0", {});
AddVarToScope(param_scope, "lstm_cell_1", {});
AddVarToScope(param_scope, "lstm_batch_gate_1", {});
AddVarToScope(param_scope, "lstm_batch_cell_pre_gate_1", {});
AddVarToScope(param_scope, "lstm_hidden_1", {});
return param_scope;
}
std::unique_ptr<ir::Graph> PrepareGraph(
std::string gate_activation = "sigmoid",
std::string cell_activation = "tanh",
std::string candidate_activation = "tanh") {
// inputs                        operator         output
// --------------------------------------------------------
// (a, lstm_fc_w)                mul              -> fc_0_tmp_0
// (fc_0_tmp_0, lstm_fc_b)       elementwise_add  -> fc_0_tmp_1
// (fc_0_tmp_1, lstm_w, lstm_b)  lstm             -> lstm_out_0
// (b, lstm_fc_w)                mul              -> fc_1_tmp_0
// (fc_1_tmp_0, lstm_fc_b)       elementwise_add  -> fc_1_tmp_1
// (fc_1_tmp_1, lstm_w, lstm_b)  lstm             -> lstm_out_1
Layers layers;
auto* a = layers.data("a");
auto* b = layers.data("b");
auto* fc_w = layers.data("lstm_fc_w", {}, true);
auto* fc_b = layers.data("lstm_fc_b", {}, true);
auto* lstm_w = layers.data("lstm_w", {}, true);
auto* lstm_b = layers.data("lstm_b", {}, true);
auto* fc_0_tmp0 = layers.mul(a, fc_w);
auto* fc_0_tmp1 = layers.elementwise_add(fc_0_tmp0, fc_b);
auto* lstm_cell_0 = layers.data("lstm_cell_0", {}, false);
auto* lstm_batch_gate_0 = layers.data("lstm_batch_gate_0", {}, false);
auto* lstm_batch_cell_pre_gate_0 =
layers.data("lstm_batch_cell_pre_gate_0", {}, false);
auto* lstm_hidden_0 = layers.data("lstm_hidden_0", {}, false);
layers.lstm(fc_0_tmp1, lstm_w, lstm_b, lstm_cell_0, lstm_batch_gate_0,
lstm_hidden_0, lstm_batch_cell_pre_gate_0, nullptr, nullptr, true,
false, gate_activation, cell_activation, candidate_activation);
auto* fc_1_tmp0 = layers.mul(b, fc_w);
auto* fc_1_tmp1 = layers.elementwise_add(fc_1_tmp0, fc_b);
auto* lstm_cell_1 = layers.data("lstm_cell_1", {}, false);
auto* lstm_batch_gate_1 = layers.data("lstm_batch_gate_1", {}, false);
auto* lstm_batch_cell_pre_gate_1 =
layers.data("lstm_batch_cell_pre_gate_1", {}, false);
auto* lstm_hidden_1 = layers.data("lstm_hidden_1", {}, false);
layers.lstm(fc_1_tmp1, lstm_w, lstm_b, lstm_cell_1, lstm_batch_gate_1,
lstm_hidden_1, lstm_batch_cell_pre_gate_1, nullptr, nullptr, true,
false, gate_activation, cell_activation, candidate_activation);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
return std::move(graph);
}
} // namespace fc_lstm_test
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -2262,11 +2262,11 @@ PDNode *patterns::QuantizePlacement::operator()(
PDNode *patterns::Bfloat16Placement::operator()(
    const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
  std::unordered_set<std::string> supported_op_types =
-     std::unordered_set<std::string>({"concat", "conv2d", "conv2d_transpose",
-                                      "elementwise_add", "elementwise_mul",
-                                      "fc", "fusion_gru", "gelu", "layer_norm",
-                                      "matmul", "pool2d", "relu", "reshape2",
-                                      "softmax", "sum", "transpose2"});
+     std::unordered_set<std::string>(
+         {"concat", "conv2d", "conv2d_transpose", "elementwise_add",
+          "elementwise_mul", "fc", "fusion_gru", "fusion_lstm", "gelu",
+          "layer_norm", "matmul", "pool2d", "relu", "reshape2", "softmax",
+          "sum", "transpose2"});
  if (!bfloat16_enabled_op_types.empty()) {
    supported_op_types = bfloat16_enabled_op_types;
  }
......
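The default list above is only a fallback; a non-empty bfloat16_enabled_op_types narrows BF16 placement to the requested ops. A hedged sketch of driving the placement pass directly from a test, mirroring how mkldnn_placement_pass is configured in the tester further below; the pass name cpu_bfloat16_placement_pass and the attribute key are assumptions inferred from that pattern:

    // Limit BF16 placement to the fused RNN ops only (illustrative sketch;
    // the pass name and attribute key are assumptions, not confirmed here).
    auto bf16_placement =
        PassRegistry::Instance().Get("cpu_bfloat16_placement_pass");
    bf16_placement->Set("bfloat16_enabled_op_types",
                        new std::unordered_set<std::string>(
                            {"fusion_gru", "fusion_lstm"}));
    graph.reset(bf16_placement->Apply(graph.release()));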
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/fc_gru_fuse_pass_tester.h"
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass_tester.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
void TestFcRNNFusePass(const std::string& pass_name,
std::string activation = "tanh",
std::string gate_activation = "sigmoid",
std::string candidate_activation = "tanh") {
std::unique_ptr<ir::Graph> graph =
(pass_name == "fc_gru_fuse_pass"
? fc_gru_test::PrepareGraph(activation, gate_activation)
: fc_lstm_test::PrepareGraph(gate_activation, activation,
candidate_activation));
auto mkldnn_placement_pass_ =
PassRegistry::Instance().Get("mkldnn_placement_pass");
mkldnn_placement_pass_->Set("mkldnn_enabled_op_types",
new std::unordered_set<std::string>({}));
graph->Set("__param_scope__", (pass_name == "fc_gru_fuse_pass"
? fc_gru_test::CreateParamScope()
: fc_lstm_test::CreateParamScope()));
graph.reset(mkldnn_placement_pass_->Apply(graph.release()));
auto check_num_mkldnn_nodes = [&](const std::unique_ptr<ir::Graph>& graph) {
int nodes_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op = node->Op();
if (op->GetAttrIfExists<bool>("use_mkldnn")) nodes_count++;
}
}
return nodes_count;
};
int num_mkldnn_nodes_before = check_num_mkldnn_nodes(graph);
int removed_mkldnn_nodes = 2;
// OneDNN fusion_gru and fusion_lstm support only sigmoid as the gate
// activation and tanh as the activation and candidate activation
if (activation != "tanh" || gate_activation != "sigmoid" ||
candidate_activation != "tanh")
removed_mkldnn_nodes += 2;
auto fc_rnn_fuse_pass_ = PassRegistry::Instance().Get(pass_name);
graph.reset(fc_rnn_fuse_pass_->Apply(graph.release()));
int num_mkldnn_nodes_after = check_num_mkldnn_nodes(graph);
PADDLE_ENFORCE_EQ(num_mkldnn_nodes_before - removed_mkldnn_nodes,
num_mkldnn_nodes_after,
platform::errors::PreconditionNotMet(
"The number of nodes with \"use_mkldnn\" attr after "
"passes is not as expected"));
}
TEST(FcGruFusePass, use_mkldnn) { TestFcRNNFusePass("fc_gru_fuse_pass"); }
TEST(FcGruFusePass, gru_unsupported_activations) {
TestFcRNNFusePass("fc_gru_fuse_pass", "relu", "sigmoid");
}
TEST(FcLstmFusePass, use_mkldnn) { TestFcRNNFusePass("fc_lstm_fuse_pass"); }
TEST(FcLstmFusePass, lstm_unsupported_activations) {
TestFcRNNFusePass("fc_lstm_fuse_pass", "tanh", "relu", "tanh");
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(mkldnn_placement_pass);
USE_PASS(fc_gru_fuse_pass);
USE_PASS(fc_lstm_fuse_pass);
@@ -194,17 +194,20 @@ struct Layers {
  }
  VarDesc* mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr,
-              int x_num_col_dims = 1, int y_num_col_dims = 1) {
+              int x_num_col_dims = 1, int y_num_col_dims = 1,
+              bool use_mkldnn = false) {
    AttributeMap attrs;
    attrs["x_num_col_dims"] = x_num_col_dims;
    attrs["y_num_col_dims"] = y_num_col_dims;
+   attrs["use_mkldnn"] = use_mkldnn;
    return binary_op("mul", x, y, out, &attrs);
  }
  VarDesc* elementwise_add(VarDesc* x, VarDesc* y, VarDesc* out = nullptr,
-                          int axis = -1) {
+                          int axis = -1, bool use_mkldnn = false) {
    AttributeMap attrs;
    attrs["axis"] = axis;
+   attrs["use_mkldnn"] = use_mkldnn;
    return binary_op("elementwise_add", x, y, out, &attrs);
  }
......
@@ -38,7 +38,6 @@ void SetAnalysisConfig(AnalysisConfig *cfg,
  cfg->SwitchSpecifyInputNames(false);
  cfg->SetCpuMathLibraryNumThreads(num_threads);
  cfg->EnableMKLDNN();
- cfg->pass_builder()->AppendPass("mkldnn_placement_pass");
}
std::vector<size_t> ReadSentenceLod(std::ifstream &file, size_t offset,
......
@@ -249,6 +249,11 @@ void FusionLSTMOpMaker::Make() {
  AddAttr<bool>("use_mkldnn",
                "(bool, default false) Only used in mkldnn kernel")
      .SetDefault(false);
+ AddAttr<std::string>(
+     "mkldnn_data_type",
+     "(string, default \"float32\"). Data type of mkldnn kernel")
+     .SetDefault("float32")
+     .InEnum({"float32", "int8", "bfloat16"});
  AddAttr<float>("Scale_data",
                 "Scale to be used for int8 input/output data."
                 "Only used with MKL-DNN INT8.")
......
@@ -27,7 +27,7 @@ from paddle.fluid.tests.unittests.test_fusion_lstm_op import fc, ACTIVATION
                  "place does not support BF16 evaluation")
class TestFusionGRUBF16MKLDNNOp(OpTest):
    def set_confs(self):
-       self.mkldnn_data_type = False
+       pass
    def test_check_output(self):
        for use_seq in {True, False}:
@@ -48,6 +48,7 @@ class TestFusionGRUBF16MKLDNNOp(OpTest):
        self.act_gate = 'sigmoid'
        self.origin_mode = False
        self.use_mkldnn = True
+       self.mkldnn_data_type = "bfloat16"
        self.force_fp32_output = False
        self.weights_dtype = 'fp32'
        self.set_confs()
@@ -113,7 +114,8 @@ class TestFusionGRUBF16MKLDNNOp(OpTest):
            'is_reverse': self.is_reverse,
            'origin_mode': self.origin_mode,
            'force_fp32_output': self.force_fp32_output,
-           'use_mkldnn': self.use_mkldnn
+           'use_mkldnn': self.use_mkldnn,
+           'mkldnn_data_type': self.mkldnn_data_type,
        }
......
@@ -35,6 +35,7 @@ class TestFusionGRUINT8MKLDNNOp(OpTest):
        self.act_gate = 'sigmoid'
        self.origin_mode = True
        self.use_mkldnn = True
+       self.mkldnn_data_type = "int8"
        self.force_fp32_output = True
        self.error_margin = 1e-5
        self.set_confs()
@@ -115,6 +116,7 @@ class TestFusionGRUINT8MKLDNNOp(OpTest):
            'is_reverse': self.is_reverse,
            'origin_mode': self.origin_mode,
            'use_mkldnn': self.use_mkldnn,
+           'mkldnn_data_type': self.mkldnn_data_type,
            'force_fp32_output': self.force_fp32_output,
            'Scale_data': scale_data,
            'Shift_data': shift_data,
......
@@ -27,7 +27,7 @@ from paddle.fluid.tests.unittests.test_fusion_gru_op import fusion_gru
                  "place does not support BF16 evaluation")
class TestFusionLSTMBF16ONEDNNOp(OpTest):
    def set_confs(self):
-       self.mkldnn_data_type = False
+       pass
    def test_check_output(self):
        for use_seq in {True, False}:
@@ -48,6 +48,7 @@ class TestFusionLSTMBF16ONEDNNOp(OpTest):
        self.act_cell = 'tanh'
        self.act_cand = 'tanh'
        self.use_mkldnn = True
+       self.mkldnn_data_type = "bfloat16"
        self.force_fp32_output = False
        self.weights_dtype = 'fp32'
        self.set_confs()
@@ -130,7 +131,8 @@ class TestFusionLSTMBF16ONEDNNOp(OpTest):
            'cell_activation': self.act_cell,
            'candidate_activation': self.act_cand,
            'force_fp32_output': self.force_fp32_output,
-           'use_mkldnn': self.use_mkldnn
+           'use_mkldnn': self.use_mkldnn,
+           'mkldnn_data_type': self.mkldnn_data_type,
        }
......
@@ -34,6 +34,7 @@ class TestFusionLSTMINT8MKLDNNOp(OpTest):
        self.act_cand = 'tanh'
        self.use_peepholes = False  # LSTM u8 doesn't support peepholes
        self.use_mkldnn = True
+       self.mkldnn_data_type = "int8"
        self.force_fp32_output = False
        self.error_margin = 1e-5
        self.set_confs()
@@ -117,6 +118,7 @@ class TestFusionLSTMINT8MKLDNNOp(OpTest):
            'is_reverse': self.is_reverse,
            'use_peepholes': self.use_peepholes,
            'use_mkldnn': self.use_mkldnn,
+           'mkldnn_data_type': self.mkldnn_data_type,
            'force_fp32_output': self.force_fp32_output,
            'Scale_data': scale_data,
            'Shift_data': shift_data,
......