未验证 提交 6945a80b 编写于 作者: W Wilber 提交者: GitHub

cherry-pick 22551. test=develop test=release/1.7 (#22609)

[cherry-pick] #22551

当一个模型中有多个fc_lstm子图的时候,且其中fc共用了同一个persistable的bias,此时不应该将bias节点删除,只将非persistable的节点去除即可。
上级 a06883cd
...@@ -116,6 +116,8 @@ cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_r ...@@ -116,6 +116,8 @@ cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_r
cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass) cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector) cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto) cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
cc_test(test_fc_lstm_fuse_pass SRCS fc_lstm_fuse_pass_tester.cc DEPS fc_lstm_fuse_pass framework_proto)
cc_test(test_fc_gru_fuse_pass SRCS fc_gru_fuse_pass_tester.cc DEPS fc_gru_fuse_pass framework_proto)
cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto) cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto)
cc_test(test_seqpool_cvm_concat_fuse_pass SRCS seqpool_cvm_concat_fuse_pass_tester.cc DEPS seqpool_cvm_concat_fuse_pass framework_proto) cc_test(test_seqpool_cvm_concat_fuse_pass SRCS seqpool_cvm_concat_fuse_pass_tester.cc DEPS seqpool_cvm_concat_fuse_pass framework_proto)
cc_test(test_repeated_fc_relu_fuse_pass SRCS repeated_fc_relu_fuse_pass_tester.cc DEPS repeated_fc_relu_fuse_pass framework_proto) cc_test(test_repeated_fc_relu_fuse_pass SRCS repeated_fc_relu_fuse_pass_tester.cc DEPS repeated_fc_relu_fuse_pass framework_proto)
......
...@@ -127,8 +127,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -127,8 +127,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, gru_pattern); GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, gru_pattern);
// nodes need be removed // nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH(BatchGate, BatchGate, gru_pattern); GET_IR_NODE_FROM_SUBGRAPH(BatchGate, BatchGate, gru_pattern);
GET_IR_NODE_FROM_SUBGRAPH(BatchResetHiddenPrev, BatchGate, gru_pattern); GET_IR_NODE_FROM_SUBGRAPH(BatchResetHiddenPrev, BatchResetHiddenPrev,
GET_IR_NODE_FROM_SUBGRAPH(BatchHidden, BatchGate, gru_pattern); gru_pattern);
GET_IR_NODE_FROM_SUBGRAPH(BatchHidden, BatchHidden, gru_pattern);
if (with_fc_bias) { if (with_fc_bias) {
GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
...@@ -138,7 +139,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -138,7 +139,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
gru_creater(gru, x_n, w, Weight, Bias, Hidden, fc_bias); gru_creater(gru, x_n, w, Weight, Bias, Hidden, fc_bias);
// Remove unneeded nodes. // Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes( std::unordered_set<const Node*> marked_nodes(
{mul, gru, elementwise_add, fc_bias, fc_out, mul_out, BatchGate, {mul, gru, elementwise_add, fc_out, mul_out, BatchGate,
BatchResetHiddenPrev, BatchHidden}); BatchResetHiddenPrev, BatchHidden});
GraphSafeRemoveNodes(graph, marked_nodes); GraphSafeRemoveNodes(graph, marked_nodes);
} else { } else {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_gru_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
void AddVarToScope(Scope* param_scope, const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(platform::CPUPlace());
}
Scope* CreateParamScope() {
auto param_scope = new Scope();
AddVarToScope(param_scope, "gru_fc_w", {});
AddVarToScope(param_scope, "gru_fc_b", {});
AddVarToScope(param_scope, "gru_w", {});
AddVarToScope(param_scope, "gru_b", {});
AddVarToScope(param_scope, "gru_batch_gate_0", {});
AddVarToScope(param_scope, "gru_batch_reset_hidden_prev_0", {});
AddVarToScope(param_scope, "gru_batch_hidden_0", {});
AddVarToScope(param_scope, "gru_hidden_0", {});
AddVarToScope(param_scope, "gru_batch_gate_1", {});
AddVarToScope(param_scope, "gru_batch_reset_hidden_prev_1", {});
AddVarToScope(param_scope, "gru_batch_hidden_1", {});
AddVarToScope(param_scope, "gru_hidden_1", {});
return param_scope;
}
TEST(FCFusePass, basic) {
// inputs operator output
// --------------------------------------------------------
// (a, gru_fc_w) mul -> fc_0_tmp_0
// (fc_0_tmp_0, gru_fc_b) elementwise_add -> fc_0_tmp_1
// (fc_0_tmp_1,gru_w,gru_b gru -> gru_out_0
// (b, gru_fc_w) mul -> fc_1_tmp_0
// (fc_1_tmp_0, gru_fc_b) elementwise_add -> fc_1_tmp_1
// (fc_1_tmp_1,gru_w,gru_b) gru -> gru_out_1
Layers layers;
auto* a = layers.data("a");
auto* b = layers.data("b");
auto* fc_w = layers.data("gru_fc_w", {}, true);
auto* fc_b = layers.data("gru_fc_b", {}, true);
auto* gru_w = layers.data("gru_w", {}, true);
auto* gru_b = layers.data("gru_b", {}, true);
auto* fc_0_tmp0 = layers.mul(a, fc_w);
auto* fc_0_tmp1 = layers.elementwise_add(fc_0_tmp0, fc_b);
auto* gru_batch_gate_0 = layers.data("gru_batch_gate_0", {}, false);
auto* gru_batch_reset_hidden_prev_0 =
layers.data("gru_batch_reset_hidden_prev_0", {}, false);
auto* gru_batch_hidden_0 = layers.data("gru_batch_hidden_0", {}, false);
auto* gru_hidden_0 = layers.data("gru_hidden_0", {}, false);
layers.gru(fc_0_tmp1, gru_w, gru_b, gru_batch_gate_0,
gru_batch_reset_hidden_prev_0, gru_batch_hidden_0, gru_hidden_0);
auto* fc_1_tmp0 = layers.mul(b, fc_w);
auto* fc_1_tmp1 = layers.elementwise_add(fc_1_tmp0, fc_b);
auto* gru_batch_gate_1 = layers.data("gru_batch_gate_1", {}, false);
auto* gru_batch_reset_hidden_prev_1 =
layers.data("gru_batch_reset_hidden_prev_1", {}, false);
auto* gru_batch_hidden_1 = layers.data("gru_batch_hidden_1", {}, false);
auto* gru_hidden_1 = layers.data("gru_hidden_1", {}, false);
layers.gru(fc_1_tmp1, gru_w, gru_b, gru_batch_gate_1,
gru_batch_reset_hidden_prev_1, gru_batch_hidden_1, gru_hidden_1);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("fc_gru_fuse_pass");
pass->Set("use_gpu", new bool(true));
graph->Set("__param_scope__", CreateParamScope());
int num_nodes_before = graph->Nodes().size();
int num_gru_nodes_before = GetNumOpNodes(graph, "gru");
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
int num_fuse_gru_nodes_after = GetNumOpNodes(graph, "fusion_gru");
VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 6,
platform::errors::PreconditionNotMet(
"The number of nodes before and after "
"the fuse does not meet expectations"));
PADDLE_ENFORCE_EQ(
num_fuse_gru_nodes_after, 2,
platform::errors::PreconditionNotMet("The number of gru nodes before the "
"fuse does not meet expectations"));
PADDLE_ENFORCE_EQ(num_gru_nodes_before, num_fuse_gru_nodes_after,
platform::errors::PreconditionNotMet(
"The number of fusion_gru nodes does not meet "
"expectations after fuse"));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(fc_gru_fuse_pass);
...@@ -133,26 +133,30 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -133,26 +133,30 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern); GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern); GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern); GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, lstm_pattern); GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(BatchCellPreAct, BatchCellPreAct, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(BatchGate, BatchGate, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
if (with_fc_bias) { if (with_fc_bias) {
GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc_out, elementwise_add_out, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out, lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out,
fc_bias); fc_bias);
// Remove unneeded nodes. // Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes( std::unordered_set<const Node*> marked_nodes(
{mul, lstm, elementwise_add, fc_bias}); {mul, lstm, elementwise_add, mul_out, BatchGate, BatchCellPreAct});
GraphSafeRemoveNodes(graph, marked_nodes); GraphSafeRemoveNodes(graph, marked_nodes);
} else { } else {
GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);
lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out, lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out,
nullptr); nullptr);
// Remove unneeded nodes. // Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes({mul, lstm}); std::unordered_set<const Node*> marked_nodes(
{mul, lstm, BatchGate, BatchCellPreAct});
GraphSafeRemoveNodes(graph, marked_nodes); GraphSafeRemoveNodes(graph, marked_nodes);
} }
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
void AddVarToScope(Scope* param_scope, const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(platform::CPUPlace());
}
Scope* CreateParamScope() {
auto param_scope = new Scope();
AddVarToScope(param_scope, "lstm_fc_w", {});
AddVarToScope(param_scope, "lstm_fc_b", {});
AddVarToScope(param_scope, "lstm_w", {});
AddVarToScope(param_scope, "lstm_b", {});
AddVarToScope(param_scope, "lstm_cell_0", {});
AddVarToScope(param_scope, "lstm_batch_gate_0", {});
AddVarToScope(param_scope, "lstm_batch_cell_pre_gate_0", {});
AddVarToScope(param_scope, "lstm_hidden_0", {});
AddVarToScope(param_scope, "lstm_cell_1", {});
AddVarToScope(param_scope, "lstm_batch_gate_1", {});
AddVarToScope(param_scope, "lstm_batch_cell_pre_gate_1", {});
AddVarToScope(param_scope, "lstm_hidden_1", {});
return param_scope;
}
TEST(FCLSTMFusePass, basic) {
// inputs operator output
// --------------------------------------------------------
// (a, lstm_fc_w) mul -> fc_0_tmp_0
// (fc_0_tmp_0, lstm_fc_b) elementwise_add -> fc_0_tmp_1
// fc_0_tmp_1,lstm_w,lstm_b lstm -> lstm_out_0
// (b, lstm_fc_w) mul -> fc_1_tmp_0
// (fc_1_tmp_0, lstm_fc_b) elementwise_add -> fc_1_tmp_1
// (fc_1_tmp_1,lstm_w,lstm_b) lstm -> lstm_out_1
Layers layers;
auto* a = layers.data("a");
auto* b = layers.data("b");
auto* fc_w = layers.data("lstm_fc_w", {}, true);
auto* fc_b = layers.data("lstm_fc_b", {}, true);
auto* lstm_w = layers.data("lstm_w", {}, true);
auto* lstm_b = layers.data("lstm_b", {}, true);
auto* fc_0_tmp0 = layers.mul(a, fc_w);
auto* fc_0_tmp1 = layers.elementwise_add(fc_0_tmp0, fc_b);
auto* lstm_cell_0 = layers.data("lstm_cell_0", {}, false);
auto* lstm_batch_gate_0 = layers.data("lstm_batch_gate_0", {}, false);
auto* lstm_batch_cell_pre_gate_0 =
layers.data("lstm_batch_cell_pre_gate_0", {}, false);
auto* lstm_hidden_0 = layers.data("lstm_hidden_0", {}, false);
layers.lstm(fc_0_tmp1, lstm_w, lstm_b, lstm_cell_0, lstm_batch_gate_0,
lstm_hidden_0, lstm_batch_cell_pre_gate_0);
auto* fc_1_tmp0 = layers.mul(b, fc_w);
auto* fc_1_tmp1 = layers.elementwise_add(fc_1_tmp0, fc_b);
auto* lstm_cell_1 = layers.data("lstm_cell_1", {}, false);
auto* lstm_batch_gate_1 = layers.data("lstm_batch_gate_1", {}, false);
auto* lstm_batch_cell_pre_gate_1 =
layers.data("lstm_batch_cell_pre_gate_1", {}, false);
auto* lstm_hidden_1 = layers.data("lstm_hidden_1", {}, false);
layers.lstm(fc_1_tmp1, lstm_w, lstm_b, lstm_cell_1, lstm_batch_gate_1,
lstm_hidden_1, lstm_batch_cell_pre_gate_1);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("fc_lstm_fuse_pass");
pass->Set("use_gpu", new bool(false));
graph->Set("__param_scope__", CreateParamScope());
int num_nodes_before = graph->Nodes().size();
int num_lstm_nodes_before = GetNumOpNodes(graph, "lstm");
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
int num_fusion_lstm_nodes_after = GetNumOpNodes(graph, "fusion_lstm");
VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after - 6,
platform::errors::PreconditionNotMet(
"The number of nodes before and after "
"the fuse does not meet expectations"));
PADDLE_ENFORCE_EQ(num_fusion_lstm_nodes_after, 2,
platform::errors::PreconditionNotMet(
"The number of lstm nodes before the "
"fuse does not meet expectations"));
PADDLE_ENFORCE_EQ(num_lstm_nodes_before, num_fusion_lstm_nodes_after,
platform::errors::PreconditionNotMet(
"The number of fusion_gru nodes does "
"not meet expectations after fuse"));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(fc_lstm_fuse_pass);
...@@ -120,6 +120,62 @@ struct Layers { ...@@ -120,6 +120,62 @@ struct Layers {
return out; return out;
} }
void lstm(VarDesc* input, VarDesc* w, VarDesc* bias, VarDesc* cell,
VarDesc* batch_gate, VarDesc* hidden, VarDesc* batch_cell_pre_act,
VarDesc* h0 = nullptr, VarDesc* c0 = nullptr,
bool use_peepholes = true, bool is_reverse = false,
std::string gate_activation = "sigmoid",
std::string cell_activation = "tanh",
std::string candidate_activation = "tanh") {
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("lstm");
op->SetInput("Input", {input->Name()});
op->SetInput("Weight", {w->Name()});
op->SetInput("Bias", {bias->Name()});
if (h0) {
op->SetInput("H0", {h0->Name()});
}
if (c0) {
op->SetInput("C0", {c0->Name()});
}
op->SetOutput("Hidden", {hidden->Name()});
op->SetOutput("Cell", {cell->Name()});
op->SetOutput("BatchGate", {batch_gate->Name()});
op->SetOutput("BatchCellPreAct", {batch_cell_pre_act->Name()});
op->SetAttr("use_peepholes", use_peepholes);
op->SetAttr("is_reverse", is_reverse);
op->SetAttr("gate_activation", gate_activation);
op->SetAttr("cell_activation", cell_activation);
op->SetAttr("candidate_activation", candidate_activation);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
}
void gru(VarDesc* input, VarDesc* w, VarDesc* bias, VarDesc* batch_gate,
VarDesc* batch_reset_hidden_prev, VarDesc* batch_hidden,
VarDesc* hidden, VarDesc* h0 = nullptr, bool origin_mode = false,
bool is_reverse = false, std::string activation = "tanh",
std::string gate_activation = "sigmoid") {
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("gru");
op->SetInput("Input", {input->Name()});
op->SetInput("Weight", {w->Name()});
op->SetInput("Bias", {bias->Name()});
if (h0) {
op->SetInput("H0", {h0->Name()});
}
op->SetOutput("BatchGate", {batch_gate->Name()});
op->SetOutput("BatchResetHiddenPrev", {batch_reset_hidden_prev->Name()});
op->SetOutput("BatchHidden", {batch_hidden->Name()});
op->SetOutput("Hidden", {hidden->Name()});
op->SetAttr("origin_mode", origin_mode);
op->SetAttr("is_reverse", is_reverse);
op->SetAttr("activation", activation);
op->SetAttr("gate_activation", gate_activation);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
}
VarDesc* mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr, VarDesc* mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr,
int x_num_col_dims = 1) { int x_num_col_dims = 1) {
AttributeMap attrs; AttributeMap attrs;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册