diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 9300573d7fbc7648f7e30ac97dc387e7249da1ff..bfc649017f19d67660bd11d590134cf56772bb27 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -5,14 +5,16 @@ cc_library(pass SRCS pass.cc DEPS graph node graph_helper) cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper) cc_library(graph_to_program_pass SRCS graph_to_program_pass.cc DEPS graph pass graph_helper) cc_library(graph_traits SRCS graph_traits.cc DEPS graph) -cc_library(graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits) -cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detecter) +cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits) +cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detector) +cc_library(attention_lstm_fuse_pass SRCS attention_lstm_fuse_pass.cc DEPS graph graph_pattern_detector) cc_library(infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass) - +cc_library(fc_lstm_fuse_pass SRCS fc_lstm_fuse_pass.cc DEPS graph graph_pattern_detector) +cc_library(seq_concat_fc_fuse_pass SRCS seq_concat_fc_fuse_pass.cc DEPS graph graph_pattern_detector) cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass) -cc_test(test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter) -cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detecter graph pass graph_traits framework_proto) +cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector) +cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detector graph pass graph_traits framework_proto) diff --git a/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..2876de88f174b1fa4ce0eacb8687e15e723bf1fc --- /dev/null +++ b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc @@ -0,0 +1,273 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/graph_viz_pass.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/inference/api/helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +struct Param { + std::string X = "concat_0.tmp_0"; + std::string C0 = "cell_init"; + std::string H0 = "hidden_init"; + std::string AttentionWeight = "attention_fc.w_0"; + std::string AttentionBias = "attention_fc.b_0"; + std::string AttentionScalar = "attention_output.w_0"; + std::string AttentionScalarBias = "attention_output.b_0"; + std::string LSTMWeight = "attention_w.new"; + std::string LSTMBias = "attention_b.new"; + std::string Hidden = "array_to_lod_tensor_0.tmp_0"; + std::string Cell = "at.cell.new"; + std::string AttentionedX = "at.x.new"; + std::string AttentionFCOut = "at.fc.new"; + std::string LSTMX = "at.lstmx.new"; + std::string LSTMOUT = "at.lstmout.new"; +}; + +void PrepareParameters(Graph* graph, const Param& param); + +void FindWhileOp(Graph* graph) { + GraphPatternDetector gpd; + std::unordered_set fused_external_ops( + {35, 36, 37, 38, 43, 44, 49, 45, 46, 47, 41, 42, 53, 54, 48, + 57, 55, 56, 52, 74, 80, 77, 78, 79, 50, 77, 39, 40, 51}); + + gpd.mutable_pattern()->NewNode( + [&](Node* n) { return fused_external_ops.count(n->id()); }, "while"); + + if (!graph->Has(kGraphvizMarkedNodeAttr)) { + graph->Set(kGraphvizMarkedNodeAttr, new GraphVizPass::marked_nodes_t); + } + auto& marked_nodes = + graph->Get(kGraphvizMarkedNodeAttr); + + auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + auto* while_pat_node = gpd.pattern().RetriveNode("while"); + auto* while_node = subgraph.at(while_pat_node); + marked_nodes.insert(while_node); + }; + gpd(graph, handle); + + Param param; + // Add AttentionLSTM node + OpDesc op_desc; + op_desc.SetType("attention_lstm"); + +#define OP_SET_IN(x) op_desc.SetInput(#x, {param.x}); +#define OP_SET_OUT(x) op_desc.SetOutput(#x, {param.x}); + OP_SET_IN(X); + OP_SET_IN(C0); + OP_SET_IN(H0); + OP_SET_IN(AttentionWeight); + OP_SET_IN(AttentionBias); + OP_SET_IN(AttentionScalar); + OP_SET_IN(AttentionScalarBias); + OP_SET_IN(LSTMWeight); + OP_SET_IN(LSTMBias); + + OP_SET_OUT(Hidden); + OP_SET_OUT(Cell); + OP_SET_OUT(AttentionedX); + OP_SET_OUT(AttentionFCOut); + OP_SET_OUT(LSTMX); + OP_SET_OUT(LSTMOUT); +#undef OP_SET_IN +#undef OP_SET_OUT + + auto* X = graph->RetriveNode(34); + auto* LSTMOUT = graph->RetriveNode(81); + auto* cell_init = graph->RetriveNode(6); + auto* hidden_init = graph->RetriveNode(8); + +#define LINK_TO(node0, node1) \ + node0->outputs.push_back(node1); \ + node1->inputs.push_back(node0); + + auto* lstm_op = graph->CreateOpNode(&op_desc); + PrepareParameters(graph, param); + + LINK_TO(X, lstm_op); + LINK_TO(cell_init, lstm_op); + LINK_TO(hidden_init, lstm_op); + LINK_TO(lstm_op, LSTMOUT); + + GraphSafeRemoveNodes(graph, marked_nodes); +} + +#define CHECK_P1(x) PADDLE_ENFORCE_NOT_NULL(x); +#define CHECK_P2(x0, x1) \ + CHECK_P1(x0); \ + CHECK_P1(x1); +#define CHECK_P3(x0, x1, x2) \ + CHECK_P2(x0, x1); \ + CHECK_P1(x2); +#define CHECK_P4(x0, x1, x2, x3) \ + CHECK_P3(x0, x1, x2); \ + CHECK_P1(x3); +#define CHECK_P5(x0, x1, x2, x3, x4) \ + CHECK_P4(x0, x1, x2, x3); \ + CHECK_P1(x4); + +void PrepareLSTMWeight(const LoDTensor& W_forget_w0, + const LoDTensor& W_forget_w1, + const LoDTensor& W_input_w0, const LoDTensor& W_input_w1, + const LoDTensor& W_output_w0, + const LoDTensor& W_output_w1, const LoDTensor& W_cell_w0, + const LoDTensor& W_cell_w1, LoDTensor* out); + +void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input, + const LoDTensor& B_output, const LoDTensor& B_cell, + LoDTensor* out); + +void PrepareParameters(Graph* graph, const Param& param) { + // Check parameters + PADDLE_ENFORCE(graph->Has(kParamScopeAttr)); + auto* scope = graph->Get(kParamScopeAttr); + + // Create new parameters. + scope->Var(param.LSTMWeight)->GetMutable(); + scope->Var(param.LSTMBias)->GetMutable(); + scope->Var(param.Hidden)->GetMutable(); + scope->Var(param.Cell)->GetMutable(); + scope->Var(param.AttentionedX)->GetMutable(); + scope->Var(param.AttentionFCOut)->GetMutable(); + scope->Var(param.LSTMX)->GetMutable(); + scope->Var(param.LSTMOUT)->GetMutable(); + +#define GATE_W(name__) \ + auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0"); \ + auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1"); \ + auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0"); \ + CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0); \ + VLOG(4) << #name__ "_w0" \ + << " shape: " << W_##name__##_w0->Get().dims(); \ + VLOG(4) << #name__ "_w1" \ + << " shape: " << W_##name__##_w1->Get().dims(); \ + VLOG(4) << #name__ "_b0" \ + << " shape: " << W_##name__##_b0->Get().dims(); \ + auto& W_##name__##_w0_t = W_##name__##_w0->Get(); \ + auto& W_##name__##_w1_t = W_##name__##_w1->Get(); \ + auto& W_##name__##_b0_t = W_##name__##_b0->Get(); + + GATE_W(forget); + GATE_W(input); + GATE_W(output); + GATE_W(c); +#undef GATE_W + + auto* attention_fc_w = scope->FindVar("attention_fc.w_0"); + auto* attention_fc_b = scope->FindVar("attention_fc.b_0"); + auto* attention_output_w = scope->FindVar("attention_output.w_0"); + auto* attention_output_b = scope->FindVar("attention_output.b_0"); + CHECK_P4(attention_fc_w, attention_fc_b, attention_output_w, + attention_output_b); + + auto* lstm_weight = scope->Var(param.LSTMWeight); + auto* lstm_weight_t = lstm_weight->GetMutable(); + auto* lstm_bias = scope->Var(param.LSTMBias); + auto* lstm_bias_t = lstm_bias->GetMutable(); + + // reshape attention_bias + auto* attention_bias_t = + scope->FindVar(param.AttentionBias)->GetMutable(); + PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1); + attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]})); + + auto* attention_scalar_bias_t = + scope->FindVar(param.AttentionScalarBias)->GetMutable(); + attention_scalar_bias_t->Resize( + make_ddim({1, attention_scalar_bias_t->dims()[0]})); + + PrepareLSTMWeight(W_forget_w0_t, W_forget_w1_t, W_input_w0_t, W_input_w1_t, + W_output_w0_t, W_output_w1_t, W_c_w0_t, W_c_w1_t, + lstm_weight_t); + PrepareLSTMBias(W_forget_b0_t, W_input_b0_t, W_output_b0_t, W_c_b0_t, + lstm_bias_t); +} + +// Prepare parameters +void PrepareLSTMWeight(const LoDTensor& W_forget_w0, + const LoDTensor& W_forget_w1, + const LoDTensor& W_input_w0, const LoDTensor& W_input_w1, + const LoDTensor& W_output_w0, + const LoDTensor& W_output_w1, const LoDTensor& W_cell_w0, + const LoDTensor& W_cell_w1, LoDTensor* out) { + int D = W_forget_w0.dims()[0]; + int M = W_forget_w1.dims()[0]; + out->Resize(make_ddim({D + M, 4 * D})); + VLOG(3) << "LSTMWeight resized to " << out->dims(); + + float* out_data = out->mutable_data(platform::CPUPlace()); + std::array tensors( + {W_forget_w0.data(), W_input_w0.data(), + W_output_w0.data(), W_cell_w0.data()}); + std::array tensors1( + {W_forget_w1.data(), W_input_w1.data(), + W_output_w1.data(), W_cell_w1.data()}); + + for (int row = 0; row < D; row++) { + for (int col = 0; col < 4; col++) { + float* dst = out_data + 4 * D * row + D * col; + const float* src = tensors[col] + D * row; + memcpy(dst, src, D * sizeof(float)); + } + } + + for (int row = 0; row < M; row++) { + for (int col = 0; col < 4; col++) { + float* dst = out_data + 4 * D * (D + row) + D * col; + const float* src = tensors1[col] + D * row; + memcpy(dst, src, D * sizeof(float)); + } + } +} + +void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input, + const LoDTensor& B_output, const LoDTensor& B_cell, + LoDTensor* out) { + std::array tensors( + {B_forget.data(), B_input.data(), B_output.data(), + B_cell.data()}); + + PADDLE_ENFORCE_EQ(B_forget.dims().size(), 1); + int D = B_forget.dims()[0]; + out->Resize(make_ddim({1, 4 * D})); + auto* out_data = out->mutable_data(platform::CPUPlace()); + for (size_t i = 0; i < tensors.size(); i++) { + memcpy(out_data + D * i, tensors[i], D * sizeof(float)); + } +} + +// Parameters + +std::unique_ptr AttentionLSTMFusePass::ApplyImpl( + std::unique_ptr graph) const { + PDPattern external_pattern, subblock_pattern; + + FindWhileOp(graph.get()); + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(attention_lstm_fuse_pass, + paddle::framework::ir::AttentionLSTMFusePass); diff --git a/paddle/fluid/inference/analysis/dot.cc b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.h similarity index 62% rename from paddle/fluid/inference/analysis/dot.cc rename to paddle/fluid/framework/ir/attention_lstm_fuse_pass.h index d5471ffcb594a6915e9e65c0fee5adc5f5bdf40c..a756dfc1b98e1de55c809c73e2c4df1e628950ae 100644 --- a/paddle/fluid/inference/analysis/dot.cc +++ b/paddle/fluid/framework/ir/attention_lstm_fuse_pass.h @@ -1,4 +1,4 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,12 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/inference/analysis/dot.h" +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" namespace paddle { -namespace inference { -namespace analysis { -size_t Dot::counter = 0; -} // namespace analysis -} // namespace inference +namespace framework { +namespace ir { + +class AttentionLSTMFusePass : public FusePassBase { + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; +}; + +} // namespace ir +} // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc index f4327742eac843f27385c165216ce48ceb97ea71..201160f29df1ee5473ba5e6cf434fa246e015a12 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc @@ -100,12 +100,10 @@ void BuildFCPattern(PDPattern* pattern) { }, "elementwise_add_out"); - pattern->AddEdge(mul_parameter_var, mul_op); - pattern->AddEdge(mul_tmp_input_var, mul_op); - pattern->AddEdge(mul_op, mul_out_var); - pattern->AddEdge(mul_out_var, elementwise_add_op); - pattern->AddEdge(elementwise_add_tmp_var, elementwise_add_op); - pattern->AddEdge(elementwise_add_op, elementwise_add_out_var); + mul_op->LinksFrom({mul_parameter_var, mul_tmp_input_var}) + .LinksTo({mul_out_var}); + elementwise_add_op->LinksFrom({mul_out_var, elementwise_add_tmp_var}) + .LinksTo({elementwise_add_out_var}); } // Replace the node `from` in the links to `to` @@ -125,7 +123,7 @@ std::unique_ptr FCFusePass::ApplyImpl( std::unordered_set nodes2delete; - GraphPatternDetecter gpd; + GraphPatternDetector gpd; BuildFCPattern(gpd.mutable_pattern()); #define GET_NODE(id) \ @@ -134,7 +132,7 @@ std::unique_ptr FCFusePass::ApplyImpl( auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \ PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id); - auto handler = [&](const GraphPatternDetecter::subgraph_t& subgraph, + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { VLOG(4) << "handle FC fuse"; // Currently, there is no FC op available, so I will just simulate the diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.h b/paddle/fluid/framework/ir/fc_fuse_pass.h index eb43dd4486cda578804fb9f6438c67e9e4a03091..31ed0e362f760319130135ad49fe2bb4e68e6786 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.h +++ b/paddle/fluid/framework/ir/fc_fuse_pass.h @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/graph_pattern_detecter.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/pass.h" namespace paddle { diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..daecf3b407c5b40c0ad6c3a75d7fbad3fe45c664 --- /dev/null +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc @@ -0,0 +1,126 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +std::unique_ptr FCLstmFusePass::ApplyImpl( + std::unique_ptr graph) const { + GraphPatternDetector gpd; + auto* pattern = gpd.mutable_pattern(); + + std::unordered_set fused_ops({// first lstm + 13, 15, 16, + // second lstm + 23, 25, 26}); + + pattern->NewNode([&](Node* x) { return fused_ops.count(x->id()); }, + "any_node"); + + std::unordered_set marked_nodes; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + + auto* id = subgraph.at(gpd.pattern().RetriveNode("any_node")); + marked_nodes.insert(id); + }; + gpd(graph.get(), handler); + + // Create New OpDesc + auto lstm_creator = [&](int lstm, int input, int weight_x, int weight_h, + int bias, int hidden, int cell, int xx) { +#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x); + GET_NODE(input); + GET_NODE(weight_x); + GET_NODE(weight_h); + GET_NODE(bias); + GET_NODE(hidden); + GET_NODE(cell); + GET_NODE(xx); + GET_NODE(lstm); + + OpDesc op_desc; + op_desc.SetType("fusion_lstm"); +#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__##_n->Name()}); + SET_IN(X, input); + SET_IN(WeightX, weight_x); + SET_IN(WeightH, weight_h); + SET_IN(Bias, bias); +#undef GET_NODE +#undef SET_IN + + LOG(INFO) << "hidden_n: " << hidden_n->Name(); + LOG(INFO) << "cell: " << cell_n->Name(); + LOG(INFO) << "xx: " << xx_n->Name(); + + op_desc.SetInput("H0", {}); + op_desc.SetInput("C0", {}); + op_desc.SetOutput("Hidden", {hidden_n->Name()}); + op_desc.SetOutput("Cell", {cell_n->Name()}); + op_desc.SetOutput("XX", {xx_n->Name()}); + op_desc.SetOutput("BatchedGate", {"blstm_0.tmp_2"}); + op_desc.SetOutput("BatchCellPreAct", {"blstm_1.tmp_2"}); + op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse")); + op_desc.SetAttr("use_peepholes", false); + auto* op = graph->CreateOpNode(&op_desc); + +#define LINK_TO(a, b) \ + a->outputs.push_back(b); \ + b->inputs.push_back(a); + LINK_TO(input_n, op); + LINK_TO(weight_x_n, op); + LINK_TO(weight_h_n, op); + LINK_TO(bias_n, op); + LINK_TO(op, hidden_n); +#undef LINK_TO + return op; + + }; + + lstm_creator(16, 12, 14, 18, 17, 22, 21, 19); + lstm_creator(26, 12, 24, 28, 27, 32, 31, 29); + + // remove all the nodes + + for (auto* node : marked_nodes) { + graph->RemoveNode(const_cast(node)); + } + + for (auto* node : graph->Nodes()) { + for (auto it = node->inputs.begin(); it != node->inputs.end();) { + if (marked_nodes.count(*it)) { + it = const_cast(node)->inputs.erase(it); + } else + it++; + } + for (auto it = node->outputs.begin(); it != node->outputs.end();) { + if (marked_nodes.count(*it)) { + it = const_cast(node)->outputs.erase(it); + } else + it++; + } + } + + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(fc_lstm_fuse_pass, paddle::framework::ir::FCLstmFusePass); diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..74b08ae558b12c9328db58687cd01edbc37291a8 --- /dev/null +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h @@ -0,0 +1,33 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class FCLstmFusePass : public Pass { + public: + virtual ~FCLstmFusePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/fuse_pass_base.h b/paddle/fluid/framework/ir/fuse_pass_base.h new file mode 100644 index 0000000000000000000000000000000000000000..bf6a0ae8274cecc785ffb269b0b574a42ee7d418 --- /dev/null +++ b/paddle/fluid/framework/ir/fuse_pass_base.h @@ -0,0 +1,44 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/scope.h" + +namespace paddle { +namespace framework { +namespace ir { + +static const char kParamScopeAttr[] = "param_scope"; + +class FusePassBase : public Pass { + public: + void Init(Graph* graph) const { graph_ = graph; } + + Scope* param_scope() const { + PADDLE_ENFORCE(graph_->Has(kParamScopeAttr)); + return graph_->Get(kParamScopeAttr); + } + + virtual ~FusePassBase() {} + + protected: + mutable Graph* graph_; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 0d27be5fc007746d6ca41ff0dbcea5c5f45599ef..b696489565ccaf83947d8ecd730c24c3bf22b5c8 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -99,13 +99,13 @@ class Graph { // Create a normal variable with non-null VarDesc. ir::Node *CreateVarNode(VarDesc *var_desc) { PADDLE_ENFORCE(var_desc); - return AddNode(new ir::Node(var_desc)); + return AddNode(new ir::Node(var_desc, node_count_++)); } // Create a normal runnable operator with OpDesc. ir::Node *CreateOpNode(OpDesc *op_desc) { PADDLE_ENFORCE(op_desc); - return AddNode(new ir::Node(op_desc)); + return AddNode(new ir::Node(op_desc, node_count_++)); } // Create a control dependency var that connects 2 operations. The @@ -115,13 +115,14 @@ class Graph { // TODO(panyx0718): control var name should be really unique. const std::string name = string::Sprintf( "%s@%llu", ir::Node::kControlDepVarName, node_set_.size()); - return AddNode(new ir::Node(name, ir::Node::Type::kVariable)); + return AddNode( + new ir::Node(name, ir::Node::Type::kVariable, node_count_++)); } // A more free style way of creating a graph node. Mostly use for test // or "copy" from another node. Avoid using it if possible. ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) { - return AddNode(new ir::Node(name, type)); + return AddNode(new ir::Node(name, type, node_count_++)); } // Clear all node information of the graph and return the ownership of the @@ -142,12 +143,20 @@ class Graph { nodes_.erase(node); } + Node *RetriveNode(int id) { + auto it = id2node_.find(id); + if (it != id2node_.end()) return it->second; + return nullptr; + } + private: // This method takes ownership of `node`. ir::Node *AddNode(ir::Node *node) { PADDLE_ENFORCE(node_set_.find(node) == node_set_.end()); nodes_[node].reset(node); node_set_.insert(node); + PADDLE_ENFORCE(!id2node_.count(node->id()), "duplicate id %d", node->id()); + id2node_[node->id()] = node; return node; } @@ -157,6 +166,8 @@ class Graph { std::map> attr_dels_; std::map> nodes_; std::unordered_set node_set_; + std::map id2node_; + int node_count_{0}; }; bool IsControlDepVar(const ir::Node &var); diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index dc81a2cac585b50b81f79f8f204ce1145d93eab0..62f94a1c0e5a300438bbe5fea34b9a07df5d9ebf 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -103,10 +103,10 @@ std::map> BuildOperationAdjList( for (auto &var : n->inputs) { for (auto &adj_n : var->inputs) { PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation); - adj_list[n].insert(adj_n); VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast(adj_n) << " -> " << n->Name() << reinterpret_cast(n) << " via " << var->Name() << reinterpret_cast(var); + adj_list[n].insert(adj_n); } } } diff --git a/paddle/fluid/framework/ir/graph_pattern_detecter.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc similarity index 71% rename from paddle/fluid/framework/ir/graph_pattern_detecter.cc rename to paddle/fluid/framework/ir/graph_pattern_detector.cc index e197861251fe5c9f98eaaba2a10b4af371dcbcba..dce4be8ff04204a134441410646c9a01b5dd40a3 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detecter.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -17,7 +17,7 @@ #include #include "paddle/fluid/framework/ir/graph_helper.h" -#include "paddle/fluid/framework/ir/graph_pattern_detecter.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/platform/enforce.h" @@ -34,7 +34,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) { name); } - nodes_.emplace_back(new PDNode(std::move(teller), name)); + nodes_.emplace_back(new PDNode(std::move(teller), this, name)); auto* cur = nodes_.back().get(); node_map_[name] = cur; return cur; @@ -56,19 +56,22 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) { edges_.emplace_back(a, b); } -void GraphPatternDetecter::operator()(Graph* graph, - GraphPatternDetecter::handle_t handler) { +void GraphPatternDetector::operator()(Graph* graph, + GraphPatternDetector::handle_t handler) { if (!MarkPDNodesInGraph(*graph)) return; auto subgraphs = DetectPatterns(); UniquePatterns(&subgraphs); RemoveOverlappedMatch(&subgraphs); + LOG(INFO) << "detect " << subgraphs.size() << " subgraph matches the pattern"; + int id = 0; for (auto& g : subgraphs) { + LOG(INFO) << "optimizing #" << id++ << " subgraph"; handler(g, graph); } } -bool GraphPatternDetecter::MarkPDNodesInGraph(const ir::Graph& graph) { +bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) { VLOG(4) << "mark pdnodes in graph"; if (graph.Nodes().empty()) return false; @@ -114,13 +117,15 @@ bool IsNodesLink(Node* a, Node* b) { return false; } -std::vector -GraphPatternDetecter::DetectPatterns() { +std::vector +GraphPatternDetector::DetectPatterns() { // Init empty subgraphs. - std::vector result; + std::vector result; std::vector init_groups; - PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed"); - auto* first_pnode = pattern_.edges().front().first; + std::array, 2> bi_records; + // PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed"); + auto* first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get() + : pattern_.edges().front().first; if (!pdnodes2nodes_.count(first_pnode)) return result; for (auto* node : pdnodes2nodes_[first_pnode]) { HitGroup group; @@ -129,7 +134,6 @@ GraphPatternDetecter::DetectPatterns() { } int step = 0; - std::array, 2> bi_records; bi_records[0] = std::move(init_groups); // Extend a PDNode to subgraphs by deducing the connection relations defined @@ -141,6 +145,7 @@ GraphPatternDetecter::DetectPatterns() { auto& pre_groups = bi_records[step % 2]; auto& cur_groups = bi_records[1 - (step++ % 2)]; cur_groups.clear(); + if (pre_groups.empty()) break; // source -> target for (Node* source : pdnodes2nodes_[edge.first]) { for (Node* target : pdnodes2nodes_[edge.second]) { @@ -163,7 +168,7 @@ GraphPatternDetecter::DetectPatterns() { } for (auto& group : bi_records[step % 2]) { - GraphPatternDetecter::subgraph_t subgraph; + GraphPatternDetector::subgraph_t subgraph; for (auto& role : group.roles) { subgraph.emplace(role.first, role.second); } @@ -172,10 +177,10 @@ GraphPatternDetecter::DetectPatterns() { return result; } -void GraphPatternDetecter::UniquePatterns( - std::vector* subgraphs) { +void GraphPatternDetector::UniquePatterns( + std::vector* subgraphs) { if (subgraphs->empty()) return; - std::vector result; + std::vector result; std::unordered_set set; for (auto& g : *subgraphs) { @@ -192,7 +197,7 @@ void GraphPatternDetecter::UniquePatterns( *subgraphs = result; } -void GraphPatternDetecter::RemoveOverlappedMatch( +void GraphPatternDetector::RemoveOverlappedMatch( std::vector* subgraphs) { std::vector result; std::unordered_set node_set; @@ -215,6 +220,46 @@ void GraphPatternDetecter::RemoveOverlappedMatch( *subgraphs = result; } +std::string PDPattern::DotString() const { + using inference::analysis::Dot; + Dot dot; + int id = 0; + // Create Nodes + std::unordered_map node2dot; + for (const auto& node : nodes()) { + std::string node_id = "Node" + std::to_string(id++); + dot.AddNode(node_id, {}, node->name()); + node2dot[node.get()] = node_id; + } + // Create Edges + for (const auto& edge : edges()) { + if (!node2dot.count(edge.first) || !node2dot.count(edge.second)) { + LOG(ERROR) << "no node " << edge.first << " " << edge.second; + continue; + } + auto& src = node2dot.at(edge.first); + auto& trg = node2dot.at(edge.second); + dot.AddEdge(src, trg, {}); + } + return dot.Build(); +} + +PDNode& PDNode::LinksTo(const std::vector& others) { + // extend outlinks. + for (PDNode* x : others) { + pattern_->AddEdge(this, x); + } + return *this; +} + +PDNode& PDNode::LinksFrom(const std::vector& others) { + // extend outlinks. + for (PDNode* x : others) { + pattern_->AddEdge(x, this); + } + return *this; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detecter.h b/paddle/fluid/framework/ir/graph_pattern_detector.h similarity index 72% rename from paddle/fluid/framework/ir/graph_pattern_detecter.h rename to paddle/fluid/framework/ir/graph_pattern_detector.h index 68c39902b5a79bf25ca7f08529a958274ac64e33..0ac34a57aacdc4fcd3d6bcaa0b72b1d6dabb3abd 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detecter.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -21,12 +21,14 @@ #include #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/inference/analysis/dot.h" namespace paddle { namespace framework { namespace ir { +class PDPattern; -// Some basic torminolygies: +// Some basic terminologies: // - PDPattern: a pattern defined as a data flow graph. // - PDNode: the node in the pattern, each PDNode represents an `ir::Node` // that meets some conditions defined in `PDNode.teller`. @@ -36,30 +38,43 @@ namespace ir { struct PDNode { // tell whether an ir::Node* is a candidation for a PDNode. using teller_t = std::function; + enum class Type { kOp, kVar }; - PDNode(teller_t&& teller, const std::string& name = "") - : teller_(teller), name_(name) { - PADDLE_ENFORCE(teller_ != nullptr, "invalid teller functer is set."); - } - - PDNode(PDNode&& other) = default; - - std::vector inlinks; - std::vector outlinks; + // this link to others + PDNode& LinksTo(const std::vector& others); + PDNode& LinksFrom(const std::vector& others); bool Tell(Node* node) const { PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode"); return teller_(node); } + bool IsOp() const { return type_ == Type::kOp; } + bool IsVar() const { return type_ == Type::kVar; } + const std::string& name() const { return name_; } PDNode(const PDNode&) = delete; PDNode& operator=(const PDNode&) = delete; private: + PDNode(teller_t&& teller, PDPattern* pattern, const std::string& name = "", + Type type = Type::kVar) + : teller_(std::move(teller)), + pattern_(pattern), + name_(name), + type_(type) { + PADDLE_ENFORCE(teller_ != nullptr, "invalid teller functer is set."); + } + + PDNode(PDNode&& other) = default; + + friend class PDPattern; + teller_t teller_; + PDPattern* pattern_; std::string name_; + Type type_; }; /* @@ -102,6 +117,8 @@ class PDPattern { const std::vector>& nodes() const { return nodes_; } const std::vector& edges() const { return edges_; } + std::string DotString() const; + private: #ifdef PADDLE_WITH_TESTING FRIEND_TEST(PDPattern, AddEdge); @@ -117,7 +134,7 @@ class PDPattern { }; /* - * GraphPatternDetecter helps to detect the specific patterns in the graph. + * GraphPatternDetector helps to detect the specific patterns in the graph. * Input a pattern, output a list of the matched subgraphs/nodes. * This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.). * @@ -129,7 +146,7 @@ class PDPattern { * * Usage: * // Create a detector - * GraphPatternDetecter detector; + * GraphPatternDetector detector; * // Define the detector's pattern, by adding PDNode and define the edges. * auto* node0 = detector.mutable_pattern().AddNode(...) * auto* node1 = detector.mutable_pattern().AddNode(...) @@ -138,11 +155,11 @@ class PDPattern { * detector.mutable_pattern().AddEdge(node0, node1); * // Create an handler, to define the behavior of treating the filtered * // subgraphs that comply with the patterns. - * GraphPatternDetecter::handle_t handler = some labmda + * GraphPatternDetector::handle_t handler = some labmda * // Execute the detector. * detector(&graph, handler); */ -class GraphPatternDetecter { +class GraphPatternDetector { public: using subgraph_t = std::unordered_map; @@ -177,10 +194,62 @@ class GraphPatternDetecter { using hit_rcd_t = std::pair; PDPattern pattern_; - std::vector marked_records_; std::unordered_map> pdnodes2nodes_; }; +// some helper methods. + +// Op's input. +static bool VarLinksToOp(Node* node, const std::string& op_type) { + for (auto* out : node->outputs) { + if (out->IsOp() && out->Op()->Type() == op_type) { + return true; + } + } + return false; +} + +// Op's output. +static bool VarLinksFromOp(Node* node, const std::string& op_type) { + for (auto* out : node->inputs) { + if (out->IsOp() && out->Op()->Type() == op_type) { + return true; + } + } + return false; +} + +// Check whether a var node is a op node's nth input. +static bool IsNthInput(Node* var, Node* op, const std::string& argument, + size_t nth) { + PADDLE_ENFORCE(var->IsVar()); + PADDLE_ENFORCE(op->IsOp()); + if (op->inputs.size() <= nth) return false; + return var->Name() == op->Op()->Input(argument)[nth]; +} + +static void GraphSafeRemoveNodes(Graph* graph, + const std::unordered_set& nodes) { + for (auto* node : nodes) { + graph->RemoveNode(const_cast(node)); + } + + for (auto* node : graph->Nodes()) { + for (auto it = node->inputs.begin(); it != node->inputs.end();) { + if (nodes.count(*it)) { + it = const_cast(node)->inputs.erase(it); + } else + it++; + } + for (auto it = node->outputs.begin(); it != node->outputs.end();) { + if (nodes.count(*it)) { + it = const_cast(node)->outputs.erase(it); + } else + it++; + } + } +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc b/paddle/fluid/framework/ir/graph_pattern_detector_tester.cc similarity index 95% rename from paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc rename to paddle/fluid/framework/ir/graph_pattern_detector_tester.cc index 06f9df5546910f492c9dd1da3e694623898d3d1d..a4d0646230c0fdfb7e1970523799e7db10c75538 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detecter_tester.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector_tester.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/ir/graph_pattern_detecter.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include @@ -82,7 +82,7 @@ TEST(PDPattern, AddEdge) { } TEST(GraphPatternDetecter, MarkPDNodesInGraph) { - GraphPatternDetecter x; + GraphPatternDetector x; // mark o2, o3, v2 // The pattern is a graph: @@ -131,7 +131,7 @@ TEST(GraphPatternDetecter, MultiSubgraph) { Graph graph(program); BuildGraph(&graph); - GraphPatternDetecter x; + GraphPatternDetector x; // The pattern is a graph: // op -> var @@ -149,8 +149,8 @@ TEST(GraphPatternDetecter, MultiSubgraph) { x.mutable_pattern()->AddEdge(any_var, any_op1); int count = 0; - GraphPatternDetecter::handle_t handle = [&]( - const GraphPatternDetecter::subgraph_t& s, Graph* g) { + GraphPatternDetector::handle_t handle = [&]( + const GraphPatternDetector::subgraph_t& s, Graph* g) { LOG(INFO) << "Detect " << s.at(any_op)->Name() << " -> " << s.at(any_var)->Name() << " -> " << s.at(any_op1)->Name(); count++; diff --git a/paddle/fluid/framework/ir/graph_viz_pass.cc b/paddle/fluid/framework/ir/graph_viz_pass.cc index e7ff0c1dac134334e3baad88886862ebff0fe367..3a114c6a237ea4411a8c4dd4b3ee6a00b7729d7c 100644 --- a/paddle/fluid/framework/ir/graph_viz_pass.cc +++ b/paddle/fluid/framework/ir/graph_viz_pass.cc @@ -16,11 +16,13 @@ limitations under the License. */ #include #include "paddle/fluid/framework/ir/graph_viz_pass.h" +#include "paddle/fluid/inference/analysis/dot.h" namespace paddle { namespace framework { namespace ir { static const char kGraphVizPath[] = "graph_viz_path"; +using inference::analysis::Dot; std::unique_ptr GraphVizPass::ApplyImpl( std::unique_ptr graph) const { @@ -30,41 +32,65 @@ std::unique_ptr GraphVizPass::ApplyImpl( PADDLE_ENFORCE(fout->good()); std::ostream& sout = *fout; - size_t var_id = 0; - std::unordered_map vars; - - sout << "digraph G {\n"; - - for (const ir::Node* n : graph->Nodes()) { - if (n->NodeType() != ir::Node::Type::kVariable) continue; - size_t cur_var_id = var_id++; - vars[n] = cur_var_id; - - sout << "var_" << cur_var_id << " [label=\"" << n->Name() << "\"]" - << std::endl; - } - - size_t op_id = 0; - for (const ir::Node* n : graph->Nodes()) { - if (n->NodeType() != ir::Node::Type::kOperation) continue; - std::string op_name = "op_" + std::to_string(op_id++); - sout << op_name << " [label=\"" << n->Name() << "\", shape=rect]" - << std::endl; - for (auto in : n->inputs) { - std::string var_name = "var_" + std::to_string(vars[in]); - sout << var_name << " -> " << op_name << std::endl; + std::unordered_map node2dot; + + Dot dot; + + std::vector op_attrs({Dot::Attr("style", "filled"), + Dot::Attr("shape", "box"), + Dot::Attr("fillcolor", "red")}); + std::vector var_attrs({Dot::Attr("style", "filled,rounded"), + // Dot::Attr("shape", "diamond"), + Dot::Attr("fillcolor", "yellow")}); + + std::vector marked_op_attrs({Dot::Attr("style", "filled"), + Dot::Attr("shape", "box"), + Dot::Attr("fillcolor", "lightgray")}); + std::vector marked_var_attrs( + {Dot::Attr("style", "filled,rounded"), + // Dot::Attr("shape", "diamond"), + Dot::Attr("fillcolor", "lightgray")}); + + auto marked_nodes = ConsumeMarkedNodes(graph.get()); + // Create nodes + for (const Node* n : graph->Nodes()) { + std::string node_id = n->Name() + "(" + std::to_string(n->id()) + ")"; + if (n->IsOp()) { + decltype(op_attrs) attr = + marked_nodes.count(n) ? marked_op_attrs : op_attrs; + dot.AddNode(node_id, attr, node_id); + } else if (n->IsVar()) { + decltype(op_attrs) attr = + marked_nodes.count(n) ? marked_var_attrs : var_attrs; + dot.AddNode(node_id, attr, node_id); } - - for (auto out : n->outputs) { - std::string var_name = "var_" + std::to_string(vars[out]); - sout << op_name << " -> " << var_name << std::endl; + node2dot[n] = node_id; + } + // Create edges + for (const Node* n : graph->Nodes()) { + const auto& src_id = node2dot.at(n); + for (auto* out : n->outputs) { + const auto& trg_id = node2dot.at(out); + dot.AddEdge(src_id, trg_id, {}); } } - sout << "}\n"; + sout << dot.Build(); + return graph; } +GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes( + Graph* graph) const { + marked_nodes_t res; + if (graph->Has(kGraphvizMarkedNodeAttr)) { + auto& attr = graph->Get(kGraphvizMarkedNodeAttr); + res = attr; + attr.clear(); + } + return res; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_viz_pass.h b/paddle/fluid/framework/ir/graph_viz_pass.h index 1fd8c8a26e9581ccf605d4271a49ec2e90d8b997..8d885cb9e4ee6e01de386b0f22423988dbe60ca6 100644 --- a/paddle/fluid/framework/ir/graph_viz_pass.h +++ b/paddle/fluid/framework/ir/graph_viz_pass.h @@ -27,10 +27,19 @@ namespace paddle { namespace framework { namespace ir { +const char kGraphvizMarkedNodeAttr[] = "__graphviz__marked_node__"; + class GraphVizPass : public Pass { + public: + using marked_nodes_t = std::unordered_set; + protected: std::unique_ptr ApplyImpl( std::unique_ptr graph) const override; + + // Tell whether there are any marked nodes in the graph. Consume the + // corresponding attribute. + marked_nodes_t ConsumeMarkedNodes(Graph* graph) const; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index aab3180e7e5ece5de5f5227e76f78687700fed87..6d40e3852295a722a4f99946202ecde2bc4b82e9 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -29,20 +29,26 @@ class Node { enum class Type { kOperation, kVariable }; static constexpr char kControlDepVarName[] = "__control_var"; - explicit Node(const std::string& name, Type type) - : name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {} + explicit Node(const std::string& name, Type type, int id = -1) + : name_(name), + var_desc_(nullptr), + op_desc_(nullptr), + type_(type), + id_(id) {} - explicit Node(VarDesc* var_desc) + explicit Node(VarDesc* var_desc, int id = -1) : name_(var_desc->Name()), var_desc_(new VarDesc(*var_desc)), op_desc_(nullptr), - type_(Type::kVariable) {} + type_(Type::kVariable), + id_(id) {} - explicit Node(OpDesc* op_desc) + explicit Node(OpDesc* op_desc, int id = -1) : name_(op_desc->Type()), var_desc_(nullptr), op_desc_(new OpDesc(*op_desc, op_desc->Block())), - type_(Type::kOperation) {} + type_(Type::kOperation), + id_(id) {} Type NodeType() const { return type_; } @@ -58,6 +64,8 @@ class Node { return op_desc_.get(); } + int id() const { return id_; } + bool IsOp() const { return type_ == Type::kOperation; } bool IsVar() const { return type_ == Type::kVariable; } @@ -69,6 +77,7 @@ class Node { std::unique_ptr var_desc_; std::unique_ptr op_desc_; Type type_; + int id_; private: DISABLE_COPY_AND_ASSIGN(Node); diff --git a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..9bb5c232e5c2269643ddef7ed9c938e0332f7274 --- /dev/null +++ b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.cc @@ -0,0 +1,256 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h" +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/graph_viz_pass.h" +#include "paddle/fluid/framework/lod_tensor.h" + +namespace paddle { +namespace framework { +namespace ir { + +struct FuseExpr {}; + +// sequence expand, concat fuse pattern, return concat's output +PDNode* BuildSeqExpandConcatPattern(PDPattern* pattern) { + // The following operators will be fused: + // concat + // sequence_expand + // sequence_expand + + // The following variables will be treat as inputs: + // concat mid input, 0th input for fused op + // sequence_expand input, 1th input for fused op + // sequence_expand input, 2th input for fused op + + // The following variables will be treat as outputs: + // concat output + + // So the following variables will be removed: + // sequence-expand output + // sequence-expand output + + // Three operators + auto* sequence_expand0 = pattern->NewNode( + [](Node* x) { + return x && x->IsOp() && x->Op()->Type() == "sequence_expand"; + }, + "sequence_expand0"); + + auto* sequence_expand1 = pattern->NewNode( + [](Node* x) { + return x && x->IsOp() && x->Op()->Type() == "sequence_expand"; + }, + "sequence_expand1"); + + auto* concat = pattern->NewNode( + [](Node* x) { + return x && x->IsOp() && x->Op()->Type() == "concat" && // basic check + x->Op()->Input("X").size() == 3; // Special case + }, + "concat"); + + auto* sequence_expand0_in = pattern->NewNode( + [](Node* x) { + return x && x->IsVar() && VarLinksToOp(x, "sequence_expand"); + }, + "sequence_expand0_in"); + auto* sequence_expand1_in = pattern->NewNode( + [](Node* x) { + return x && x->IsVar() && VarLinksToOp(x, "sequence_expand"); + }, + "sequence_expand1_in"); + + // The variables + auto* sequence_expand0_out = pattern->NewNode( + [](Node* x) { + return x && x->IsVar() && + VarLinksFromOp(x, "sequence_expand") && // basic check + VarLinksToOp(x, "concat") && // is concat's input + IsNthInput(x, x->outputs[0], "X", 1); // X[0] + }, + "sequence_expand0_out"); + + auto* sequence_expand1_out = pattern->NewNode( + [](Node* x) { + return x && x->IsVar() && + VarLinksFromOp(x, "sequence_expand") && // basic check + VarLinksToOp(x, "concat") && // is concat's input + IsNthInput(x, x->outputs[0], "X", 2); // x[2] + }, + "sequence_expand1_out"); + + auto* concat_in0 = pattern->NewNode( + [](Node* x) { return x && x->IsVar() && VarLinksToOp(x, "concat"); }, + "concat_in0"); + + auto* concat_out = pattern->NewNode( + [](Node* x) { return x && x->IsVar() && VarLinksFromOp(x, "concat"); }, + "concat_out"); + + // Links + sequence_expand0->LinksFrom({sequence_expand0_in}) + .LinksTo({sequence_expand0_out}); + sequence_expand1->LinksFrom({sequence_expand1_in}) + .LinksTo({sequence_expand1_out}); + concat->LinksFrom({sequence_expand0_out, sequence_expand1_out, concat_in0}) + .LinksTo({concat_out}); + return concat_out; +} + +PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) { + PDNode* fc_w = pattern->NewNode( + [](Node* x) { + return x && x->IsVar() && // basic + VarLinksToOp(x, "mul") && // link + x->Var()->Proto()->persistable(); // is a parameter + }, + "fc_w"); + + PDNode* mul_out = pattern->NewNode( + [](Node* x) { + return x && x->IsVar() && // basic + VarLinksFromOp(x, "mul") && // link + VarLinksToOp(x, "elementwise_add") && // + !x->Var()->Proto()->persistable(); // is a parameter + }, + "mul_out"); + + PDNode* fc_mul = pattern->NewNode( + [](Node* x) { + return x && x->IsOp() && x->Op()->Type() == "mul"; // basic + }, + "fc_mul"); + + PDNode* fc_bias = pattern->NewNode( + [](Node* x) { + return x && x->IsVar() && // basic + VarLinksToOp(x, "elementwise_add") && // link + x->Var()->Proto()->persistable(); // is a parameter + }, + "fc_bias"); + + PDNode* elementwise_add = pattern->NewNode( + [](Node* x) { + return x && x->IsOp() && x->Op()->Type() == "elementwise_add"; + }, + "elementwise_add"); + + PDNode* add_out = pattern->NewNode( + [](Node* x) { + return x && x->IsVar() && // basic + VarLinksFromOp(x, "elementwise_add") && // link + !x->Var()->Proto()->persistable(); // is a parameter + }, + "add_out"); + + std::set acts({"sigmoid", "tanh", "relu", "identity"}); + PDNode* act = pattern->NewNode( + [=](Node* x) { + return x && x->IsOp() && acts.count(x->Op()->Type()); + + }, + "act"); + + PDNode* fc_out = pattern->NewNode( + [](Node* x) { + return x && x->IsVar() && // basic + !x->Var()->Proto()->persistable(); // is a parameter + }, + "fc_out"); + + fc_mul->LinksFrom({fc_w, fc_x}).LinksTo({mul_out}); + elementwise_add->LinksFrom({mul_out, fc_bias}).LinksTo({add_out}); + act->LinksFrom({add_out}).LinksTo({fc_out}); + return fc_out; +} + +std::unique_ptr SeqConcatFcFusePass::ApplyImpl( + std::unique_ptr graph) const { + FusePassBase::Init(graph.get()); + GraphPatternDetector detector; + auto* pattern = detector.mutable_pattern(); + auto* concat_out = BuildSeqExpandConcatPattern(pattern); + BuildFCPattern(pattern, concat_out); + +#define GET_NODE(id, pattern) \ + PADDLE_ENFORCE(subgraph.count(pattern.RetriveNode(#id)), \ + "pattern has no Node called %s", #id); \ + auto* id = subgraph.at(pattern.RetriveNode(#id)); \ + PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id); + + detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "get one concat pattern"; + // fc + GET_NODE(fc_w, detector.pattern()); + GET_NODE(fc_bias, detector.pattern()); + GET_NODE(act, detector.pattern()); + GET_NODE(fc_out, detector.pattern()); + + // concat + GET_NODE(concat_in0, detector.pattern()); + GET_NODE(sequence_expand0_in, detector.pattern()); + GET_NODE(sequence_expand1_in, detector.pattern()); + + OpDesc op_desc; + op_desc.SetType("fusion_seqexpand_concat_fc"); + op_desc.SetInput("X", {concat_in0->Name(), sequence_expand0_in->Name(), + sequence_expand1_in->Name()}); + op_desc.SetInput("FCWeight", {fc_w->Name()}); + op_desc.SetInput("FCBias", {fc_bias->Name()}); + const std::string fc_out_tmp = fc_out->Name() + ".tmp"; + param_scope()->Var(fc_out_tmp)->GetMutable(); + op_desc.SetOutput("FCOut", {fc_out_tmp}); + op_desc.SetOutput("Out", {fc_out->Name()}); + op_desc.SetAttr("fc_activation", act->Op()->Type()); + + auto* op_node = graph->CreateOpNode(&op_desc); +// Add links +#define NODE_LINKS(a, b) \ + a->outputs.push_back(b); \ + b->inputs.push_back(a); + NODE_LINKS(fc_w, op_node); + NODE_LINKS(fc_bias, op_node); + NODE_LINKS(concat_in0, op_node); + NODE_LINKS(sequence_expand0_in, op_node); + NODE_LINKS(sequence_expand1_in, op_node); + NODE_LINKS(op_node, fc_out); + + // Clean nodes. + std::unordered_set marked_nodes; + for (auto& item : subgraph) { + marked_nodes.insert(item.second); + } + marked_nodes.erase(fc_w); + marked_nodes.erase(fc_bias); + marked_nodes.erase(concat_in0); + marked_nodes.erase(sequence_expand0_in); + marked_nodes.erase(sequence_expand1_in); + marked_nodes.erase(fc_out); + + GraphSafeRemoveNodes(graph, marked_nodes); + }); + + return graph; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(seq_concat_fc_fuse_pass, + paddle::framework::ir::SeqConcatFcFusePass); diff --git a/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..9f5fd1a29adf918806d8f30097d8c7f002f48f3e --- /dev/null +++ b/paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h @@ -0,0 +1,33 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class SeqConcatFcFusePass : public FusePassBase { + public: + virtual ~SeqConcatFcFusePass() {} + + protected: + std::unique_ptr ApplyImpl(std::unique_ptr graph) const; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt index 4feaed2b0d9cdec735bd3fadc98aa2bad715c209..779ede5e460d0ceb6fd404c4a32374f9f9d92088 100644 --- a/paddle/fluid/inference/analysis/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/CMakeLists.txt @@ -1,5 +1,8 @@ cc_library(ir_pass_manager SRCS ir_pass_manager.cc DEPS graph pass) -cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc +set(analysis_deps + framework_proto proto_desc ir_pass_manager graph pass paddle_fluid_api executor) + +cc_library(analysis SRCS pass_manager.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc analyzer.cc helper.cc # passes @@ -10,11 +13,11 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph tensorrt_subgraph_node_mark_pass.cc fluid_to_ir_pass.cc model_store_pass.cc - DEPS framework_proto proto_desc ir_pass_manager graph pass) + DEPS ${analysis_deps}) cc_test(test_node SRCS node_tester.cc DEPS analysis) cc_test(test_dot SRCS dot_tester.cc DEPS analysis) -cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis) +cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis paddle_fluid) set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) @@ -31,7 +34,7 @@ function (inference_analysis_test TARGET) endif() cc_test(${TARGET} SRCS "${analysis_test_SRCS}" - DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detecter pass ${analysis_test_EXTRA_DEPS} + DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detector pass ${analysis_test_EXTRA_DEPS} ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt}) set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec) endif(WITH_TESTING) @@ -58,20 +61,25 @@ endif() inference_analysis_test(test_analyzer SRCS analyzer_tester.cc EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis + analysis_predictor # ir fc_fuse_pass + fc_lstm_fuse_pass + seq_concat_fc_fuse_pass graph_viz_pass infer_clean_graph_pass - graph_pattern_detecter - infer_clean_graph_pass + graph_pattern_detector + infer_clean_graph_pass + attention_lstm_fuse_pass + paddle_inference_api pass ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model --infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model --infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt) inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc) -inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc) -inference_analysis_test(test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc) +inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc EXTRA_DEPS paddle_inference_api) +inference_analysis_test(test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc EXTRA_DEPS paddle_fluid) inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc) inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc) inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc) diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index 0d94ccb64e024dbdef20c1cc16aab76bc5db928c..05b606cd0fb1802dfe815b3242813f42264f6366 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -102,6 +102,19 @@ class DfgPassManagerImpl final : public DfgPassManager { Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); } void Analyzer::Run(Argument* argument) { + // Ungly support fluid-to-ir-pass + argument->Set(kFluidToIrPassesAttr, + new std::vector({ + // Manual update the passes here. + "graph_viz_pass", // + "infer_clean_graph_pass", "graph_viz_pass", // + "attention_lstm_fuse_pass", "graph_viz_pass", // + "fc_lstm_fuse_pass", "graph_viz_pass", // + "seq_concat_fc_fuse_pass", "graph_viz_pass", // + "fc_fuse_pass", "graph_viz_pass" // + + })); + for (auto& x : data_) { PADDLE_ENFORCE(x->Initialize(argument)); x->RunAll(); diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc index baa7600283a9bc0b81833b419a2ea64692ed2203..263fbb044902e886c357835ab298b4f646c7a3ed 100644 --- a/paddle/fluid/inference/analysis/analyzer_tester.cc +++ b/paddle/fluid/inference/analysis/analyzer_tester.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/inference/analysis/ut_helper.h" #include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" +#include "paddle/fluid/platform/profiler.h" DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN"); DEFINE_string(infer_ditu_rnn_data, "", "data path for ditu RNN"); @@ -264,39 +265,24 @@ void TestDituRNNPrediction(const std::string &model_path, const std::string &data_path, int batch_size, bool use_analysis, bool activate_ir, int num_times = 1) { - FLAGS_IA_enable_ir = activate_ir; - FLAGS_IA_enable_tensorrt_subgraph_engine = false; - FLAGS_IA_output_storage_path = "./analysis.out"; - - std::string model_out; - if (use_analysis) { - Argument argument(model_path); - argument.model_output_store_path.reset(new std::string("./analysis.out")); - - Analyzer analyzer; - analyzer.Run(&argument); - - // Should get the transformed model stored to ./analysis.out - model_out = "./analysis.out"; - ASSERT_TRUE(PathExists(model_out)); - } else { - model_out = FLAGS_infer_ditu_rnn_model; - } - NativeConfig config; - config.prog_file = model_out + "/__model__"; - config.param_file = model_out + "/param"; + config.prog_file = FLAGS_infer_ditu_rnn_model + "/__model__"; + config.param_file = FLAGS_infer_ditu_rnn_model + "/param"; config.use_gpu = false; config.device = 0; config.specify_input_name = true; - auto predictor = + auto base_predictor = CreatePaddlePredictor(config); + auto predictor = + CreatePaddlePredictor(config); std::vector input_slots; DataRecord data(data_path, batch_size); // Prepare inputs. PrepareInputs(&input_slots, &data, batch_size); - std::vector outputs; + std::vector outputs, base_outputs; + + base_predictor->Run(input_slots, &base_outputs); Timer timer; timer.tic(); @@ -308,37 +294,25 @@ void TestDituRNNPrediction(const std::string &model_path, << ", latency: " << timer.toc() / num_times << "ms"; LOG(INFO) << "====================================="; - for (auto &out : outputs) { + PADDLE_ENFORCE_GT(outputs.size(), 0); + PADDLE_ENFORCE_EQ(outputs.size(), base_outputs.size()); + for (size_t i = 0; i < outputs.size(); i++) { + auto &out = outputs[i]; + auto &base_out = base_outputs[i]; size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1, [](int a, int b) { return a * b; }); + size_t size1 = std::accumulate(base_out.shape.begin(), base_out.shape.end(), + 1, [](int a, int b) { return a * b; }); + PADDLE_ENFORCE_EQ(size, size1); + PADDLE_ENFORCE_GT(size, 0); float *data = static_cast(out.data.data()); - for (size_t i = 0; - i < std::min(sizeof(ditu_rnn_target_data) / sizeof(float), size); - i++) { - EXPECT_NEAR(data[i], ditu_rnn_target_data[i], 1e-3); + float *base_data = static_cast(base_out.data.data()); + for (size_t i = 0; i < size; i++) { + EXPECT_NEAR(data[i], base_data[i], 1e-3); } } } -// Turn on the IR pass supportion, run a real inference and check the result. -TEST(Analyzer, SupportIRPass) { - FLAGS_IA_enable_ir = true; - FLAGS_IA_enable_tensorrt_subgraph_engine = false; - FLAGS_IA_output_storage_path = "./analysis.out"; - - Argument argument(FLAGS_inference_model_dir); - argument.model_output_store_path.reset(new std::string("./analysis.out")); - - Analyzer analyzer; - analyzer.Run(&argument); - - // Should get the transformed model stored to ./analysis.out - ASSERT_TRUE(PathExists("./analysis.out")); - - // Inference from this path. - TestWord2vecPrediction("./analysis.out"); -} - // Directly infer with the original model. TEST(Analyzer, DituRNN_without_analysis) { TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data, @@ -365,5 +339,8 @@ TEST(Analyzer, DituRNN_with_analysis_with_IR) { } // namespace paddle USE_PASS(fc_fuse_pass); +USE_PASS(seq_concat_fc_fuse_pass); +USE_PASS(fc_lstm_fuse_pass); USE_PASS(graph_viz_pass); USE_PASS(infer_clean_graph_pass); +USE_PASS(attention_lstm_fuse_pass); diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index a17d6281a2976f0600c7ce94c2d43e65d30de265..4401d5c5a3ca8da1c04336de4be8397334d46d9e 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -26,6 +26,7 @@ #include #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h" +#include "paddle/fluid/platform/variant.h" namespace paddle { namespace inference { @@ -58,6 +59,46 @@ struct Argument { // The output storage path of ModelStorePass. std::unique_ptr model_output_store_path; + + // Support for any other attributes. + template + void Set(const std::string& key, T* data) { + PADDLE_ENFORCE_NOT_NULL(data); + PADDLE_ENFORCE(!attrs_.count(key), "duplicate attr called %s", key); + attrs_[key] = data; + attr_deleters_[key] = [data, key, this]() { + VLOG(3) << "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; + VLOG(3) << "argument delete attr: " << key; + delete data; + }; + } + + bool Has(const std::string& name) const { return attrs_.count(name); } + + template + T* Release(const std::string& key) { + PADDLE_ENFORCE(attrs_.count(key)); + auto* res = boost::any_cast(attrs_.at(key)); + attrs_.erase(key); + attr_deleters_.erase(key); + return res; + } + + template + T& Get(const std::string& key) { + PADDLE_ENFORCE(Has(key)); + return *boost::any_cast(attrs_.at(key)); + } + + ~Argument() { + for (auto& item : attr_deleters_) { + item.second(); + } + } + + private: + std::unordered_map attrs_; + std::unordered_map> attr_deleters_; }; #define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc index 8c7dd146e429a7f5cd28bdd418e457e8ea5680bd..8ca402da31f52f1a68a04b5de368c9c659a3a108 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc @@ -19,6 +19,7 @@ #include "paddle/fluid/framework/proto_desc.h" #include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h" +#include "paddle/fluid/inference/io.h" namespace paddle { namespace inference { @@ -65,6 +66,10 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) { } } + if (argument_->Has("param_scope")) { + LOG(WARNING) << "parameter changes in the scope takes effect"; + } + PADDLE_ENFORCE(argument_->transformed_program_desc.get()); } diff --git a/paddle/fluid/inference/analysis/dot.h b/paddle/fluid/inference/analysis/dot.h index 4bf1840fdda8508b52d7274a338c5b1c95baf354..4693729cb43d7a9df96b11c4bf3064a70d1db4c3 100644 --- a/paddle/fluid/inference/analysis/dot.h +++ b/paddle/fluid/inference/analysis/dot.h @@ -29,13 +29,13 @@ namespace paddle { namespace inference { namespace analysis { +static size_t dot_node_counter{0}; + /* * A Dot template that helps to build a DOT graph definition. */ class Dot { public: - static size_t counter; - struct Attr { std::string key; std::string value; @@ -57,7 +57,7 @@ class Dot { Node(const std::string& name, const std::vector& attrs) : name(name), attrs(attrs), - id_("node_" + std::to_string(Dot::counter++)) {} + id_("node_" + std::to_string(dot_node_counter++)) {} std::string id() const { return id_; } @@ -65,6 +65,10 @@ class Dot { std::stringstream ss; CHECK(!name.empty()); ss << id_; + if (attrs.empty()) { + ss << "[label=" << '"' << name << '"' << "]"; + return ss.str(); + } for (size_t i = 0; i < attrs.size(); i++) { if (i == 0) { ss << "[label=" << '"' << name << '"' << " "; @@ -108,9 +112,11 @@ class Dot { explicit Dot(const std::vector& attrs) : attrs_(attrs) {} - void AddNode(const std::string& name, const std::vector& attrs) { - CHECK(!nodes_.count(name)) << "duplicate Node '" << name << "'"; - nodes_.emplace(name, Node{name, attrs}); + void AddNode(const std::string& id, const std::vector& attrs, + std::string label = "") { + CHECK(!nodes_.count(id)) << "duplicate Node '" << id << "'"; + if (label.empty()) label = id; + nodes_.emplace(id, Node{label, attrs}); } void AddEdge(const std::string& source, const std::string& target, diff --git a/paddle/fluid/inference/analysis/fluid_to_ir_pass.cc b/paddle/fluid/inference/analysis/fluid_to_ir_pass.cc index 073f49752872cbb65fddc74be75ec28d4dd0bbaf..5e53fff39213b53bc78e9272a7efd26d7ee91023 100644 --- a/paddle/fluid/inference/analysis/fluid_to_ir_pass.cc +++ b/paddle/fluid/inference/analysis/fluid_to_ir_pass.cc @@ -13,3 +13,47 @@ // limitations under the License. #include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h" +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/inference/io.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace inference { +namespace analysis { + +void FluidToIrPass::EnableParamModify(const std::string &model_dir, + const std::string &prog_file, + const std::string ¶m_file) { + PADDLE_ENFORCE(argument_); + argument_->Set("param_scope", new framework::Scope); + // Load parameters. + VLOG(3) << "Loading parameters from " << model_dir; + LoadParams(&argument_->Get("param_scope"), model_dir, + prog_file, param_file); +} + +bool FluidToIrPass::LoadParams(framework::Scope *scope, const std::string &dir, + const std::string &prog_file, + const std::string ¶m_file) { + platform::CPUPlace place; + platform::CPUDeviceContext ctx(place); + framework::Executor executor(place); + PADDLE_ENFORCE(argument_->origin_program_desc.get()); + framework::ProgramDesc program(*argument_->origin_program_desc); + if ((!prog_file.empty()) && (!param_file.empty())) { + LOG(INFO) << "load single model file from " << prog_file; + Load(&executor, scope, prog_file, param_file); + } else if (!dir.empty()) { + LOG(INFO) << "load from dir " << dir; + Load(&executor, scope, dir); + } else { + LOG(ERROR) << "failed to load parameters"; + return false; + } + return true; +} + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/fluid_to_ir_pass.h b/paddle/fluid/inference/analysis/fluid_to_ir_pass.h index fa3f8d313bbdd6733fa3878dd7023e125b6ced36..29008105f82989f5797116e78990853880708936 100644 --- a/paddle/fluid/inference/analysis/fluid_to_ir_pass.h +++ b/paddle/fluid/inference/analysis/fluid_to_ir_pass.h @@ -21,12 +21,17 @@ namespace paddle { namespace inference { namespace analysis { +static const char kFluidToIrPassesAttr[] = "__fluid_to_ir_passes__"; + class FluidToIrPass final : public DataFlowGraphPass { public: FluidToIrPass() = default; bool Initialize(Argument *argument) override { ANALYSIS_ARGUMENT_CHECK_FIELD(argument); + PADDLE_ENFORCE(argument->Has(kFluidToIrPassesAttr), + "argument need the attr %s", kFluidToIrPassesAttr); + argument_ = argument; if (argument->origin_program_desc) { LOG(WARNING) << "argument's origin_program_desc is already set, might " "duplicate called"; @@ -46,12 +51,21 @@ class FluidToIrPass final : public DataFlowGraphPass { if (!argument->main_dfg) { argument->main_dfg.reset(new DataFlowGraph); } - // Persist the ProgramDesc in graph's attribute. The IR graph just keep the - // address, will segfault if the original ProgramDesc destroys. - auto &ir_program_p = argument->main_dfg->Attr("ir_program_desc").Pointer(); - ir_program_p = new framework::ProgramDesc(program); + argument->Set("ir_program_desc", new framework::ProgramDesc(program)); + + LOG(INFO) << "Loading parameters"; + // Load parameters to argument if needed. + if (argument->fluid_model_dir || (argument->fluid_model_program_path && + argument->fluid_model_param_path)) { +#define SAFE_GET(ATTR) std::string ATTR = argument->ATTR ? *argument->ATTR : ""; + SAFE_GET(fluid_model_dir); + SAFE_GET(fluid_model_program_path); + SAFE_GET(fluid_model_param_path); +#undef SAFE_GET + EnableParamModify(fluid_model_dir, fluid_model_program_path, + fluid_model_param_path); + } - argument_ = argument; return true; } @@ -59,20 +73,36 @@ class FluidToIrPass final : public DataFlowGraphPass { void Run(DataFlowGraph *graph) override { // Call all the IR Passes - IRPassManager ir_passes(*static_cast( - argument_->main_dfg->Attr("ir_program_desc").Pointer())); - ir_passes.Apply(std::vector( - {// Manual update the passes here. - "graph_viz_pass", "infer_clean_graph_pass", "graph_viz_pass", - "fc_fuse_pass", "graph_viz_pass"})); + IRPassManager ir_passes( + argument_->Get("ir_program_desc"), nullptr); + // Pass the scope from analysis to IR if needed. + if (argument_->Has("param_scope")) { + // Here the address is passed, attention that IR doesn't own the scope, so + // the real scope in analysis should live during the IR phase. + ir_passes.graph().Set( + "param_scope", new framework::Scope *( + &argument_->Get("param_scope"))); + } + + const auto &ir_passes_to_apply = + argument_->Get>(kFluidToIrPassesAttr); + ir_passes.Apply(ir_passes_to_apply); PADDLE_ENFORCE(argument_->main_dfg.get()); argument_->main_dfg->Build(ir_passes.graph()); - // PADDLE_ENFORCE(argument_->main_dfg->IsFullyConnected()); } + void EnableParamModify(const std::string &model_dir, + const std::string &prog_file, + const std::string ¶m_file); + std::string repr() const override { return "fluid-to-ir-pass"; } + private: + // Load parameters from a single file or from a directory. + bool LoadParams(framework::Scope *scope, const std::string &dir, + const std::string &prog_file, const std::string ¶m_file); + private: Argument *argument_{nullptr}; }; diff --git a/paddle/fluid/inference/analysis/fluid_to_ir_pass_tester.cc b/paddle/fluid/inference/analysis/fluid_to_ir_pass_tester.cc index af934f261baa3807059ce6ab036545594630df58..6a13c60e7b2ebf645b12d5ddf83ef6ab3a2e83bd 100644 --- a/paddle/fluid/inference/analysis/fluid_to_ir_pass_tester.cc +++ b/paddle/fluid/inference/analysis/fluid_to_ir_pass_tester.cc @@ -24,6 +24,8 @@ namespace analysis { TEST(FluidToIrPass, Test) { FluidToIrPass pass; Argument argument(FLAGS_inference_model_dir); + argument.Set(kFluidToIrPassesAttr, + new std::vector({"infer_clean_graph_pass"})); pass.Initialize(&argument); pass.Run(argument.main_dfg.get()); } @@ -32,6 +34,9 @@ TEST(FluidToIrPass, Test) { } // namespace inference } // namespace paddle -USE_PASS(fc_fuse_pass); USE_PASS(graph_viz_pass); USE_PASS(infer_clean_graph_pass); +USE_PASS(attention_lstm_fuse_pass); +USE_PASS(fc_lstm_fuse_pass); +USE_PASS(seq_concat_fc_fuse_pass); +USE_PASS(fc_fuse_pass); diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index d849b637bcf3fe3944ad11680bbe041e19a71e24..5da5241e49a2f7c8c0951e1a3c31784b8af65134 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -14,20 +14,24 @@ #include "paddle/fluid/inference/analysis/ir_pass_manager.h" #include +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/scope.h" namespace paddle { namespace inference { namespace analysis { -IRPassManager::IRPassManager(const ProgramDesc& program) { +IRPassManager::IRPassManager(const ProgramDesc &program, + framework::Scope *scope) + : program_(program) { graph_.reset(new framework::ir::Graph(program)); + if (scope) graph_->Set("param_scope", new framework::Scope *(scope)); } -void IRPassManager::Apply(const std::vector& passes) { - graph_->Set("graph_viz_path", new std::string("./1.dot")); +void IRPassManager::Apply(const std::vector &passes) { // Apply all the passes std::string pre_pass; - for (const std::string& pass_name : passes) { + for (const std::string &pass_name : passes) { LOG(WARNING) << "Running IR pass [" << pass_name << "]"; auto pass = framework::ir::PassRegistry::Instance().Get(pass_name); if (pass_name == "graph_viz_pass") { diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.h b/paddle/fluid/inference/analysis/ir_pass_manager.h index 3338e37ecf1c591a631fd829a05b07e562af703e..bb230283b7c2cc783d0b68ea0aa3cca1cabc75e6 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.h +++ b/paddle/fluid/inference/analysis/ir_pass_manager.h @@ -23,6 +23,7 @@ #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" namespace paddle { namespace inference { @@ -31,14 +32,15 @@ using framework::ProgramDesc; class IRPassManager final { public: - IRPassManager(const ProgramDesc& program); + IRPassManager(const ProgramDesc &program, framework::Scope *scope); - void Apply(const std::vector& passes); + void Apply(const std::vector &passes); - framework::ir::Graph& graph() const { return *graph_; } + framework::ir::Graph &graph() const { return *graph_; } private: std::unique_ptr graph_; + ProgramDesc program_; }; } // namespace analysis diff --git a/paddle/fluid/inference/analysis/pass_manager.cc b/paddle/fluid/inference/analysis/pass_manager.cc index cfdca33882ea00a28e3ea51ca5fd77ec9605bf3a..ff5ec94265a4f05c1294ad6c8ac5f86c249b84b6 100644 --- a/paddle/fluid/inference/analysis/pass_manager.cc +++ b/paddle/fluid/inference/analysis/pass_manager.cc @@ -33,9 +33,9 @@ bool PassManager::Initialize(Argument* argument) { void DfgPassManager::RunAll() { PADDLE_ENFORCE(argument_); - LOG(INFO) << "Total " << data_.size() << " passes"; + LOG(INFO) << "Total " << data_.size() << " Analysys passes"; for (auto& pass : data_) { - LOG(WARNING) << "Running pass [" << pass->repr() << "]"; + LOG(WARNING) << "Running Analysis pass [" << pass->repr() << "]"; pass->Run(argument_->main_dfg.get()); } } diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index 0ca1af455ca10fa6995ad3a1c33825108a3fd7ad..adfe4392448557a30cd834022b9a5d21d9086b95 100644 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -20,7 +20,7 @@ endif(APPLE) set(inference_deps paddle_inference_api paddle_fluid_api analysis pass ir_pass_manager graph_viz_pass fc_fuse_pass - infer_clean_graph_pass + infer_clean_graph_pass ) if(WITH_GPU AND TENSORRT_FOUND) @@ -46,7 +46,8 @@ function(inference_api_test TARGET_NAME) endif(WITH_TESTING) endfunction(inference_api_test) -cc_library(paddle_inference_api SRCS api.cc api_impl.cc DEPS lod_tensor) +cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor) +cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api) cc_test(test_paddle_inference_api SRCS api_tester.cc diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc new file mode 100644 index 0000000000000000000000000000000000000000..0b29b233822330e3c1441793ce036b9b9278721b --- /dev/null +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -0,0 +1,165 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/inference/analysis/analyzer.h" +#include "paddle/fluid/inference/api/api_impl.h" +#include "paddle/fluid/inference/api/paddle_inference_api.h" +#include "paddle/fluid/inference/utils/singleton.h" + +namespace paddle { + +using inference::analysis::Argument; +using inference::Singleton; +using inference::analysis::Analyzer; +using framework::proto::ProgramDesc; + +/* This predictor is based on the original native predictor with IR and Analysis + * support. It will optimize IR and Parameters in the runtime. + * TODO(Superjomn) Replace the Navive predictor? + */ +class AnalysisPredictor : public NativePaddlePredictor { + public: + explicit AnalysisPredictor(const NativeConfig& config) + : NativePaddlePredictor(config), config_(config) {} + + bool Init(const std::shared_ptr& parent_scope) { + VLOG(3) << "Predictor::init()"; + if (config_.use_gpu) { + place_ = paddle::platform::CUDAPlace(config_.device); + } else { + place_ = paddle::platform::CPUPlace(); + } + PADDLE_ENFORCE(!parent_scope); + if (parent_scope) { + scope_ = parent_scope; + sub_scope_ = &(parent_scope->NewScope()); + } else { + paddle::framework::InitDevices(false); + scope_.reset(new paddle::framework::Scope()); + } + + executor_.reset(new paddle::framework::Executor(place_)); + + // Initialize the inference program + if (!config_.model_dir.empty()) { + // Parameters are saved in separate files sited in + // the specified `dirname`. + inference_program_ = paddle::inference::Load( + executor_.get(), scope_.get(), config_.model_dir); + } else if (!config_.prog_file.empty() && !config_.param_file.empty()) { + // All parameters are saved in a single file. + // The file names should be consistent with that used + // in Python API `fluid.io.save_inference_model`. + inference_program_ = paddle::inference::Load( + executor_.get(), scope_.get(), config_.prog_file, config_.param_file); + } else { + LOG(ERROR) << "fail to load inference model."; + return false; + } + + OptimizeInferenceProgram(); + ctx_ = executor_->Prepare(*inference_program_, 0); + + VLOG(5) << "to create variables"; + PADDLE_ENFORCE(scope_.get()); + executor_->CreateVariables(*inference_program_, + sub_scope_ ? sub_scope_ : scope_.get(), 0); + + // Get the feed_target_names and fetch_target_names + feed_target_names_ = inference_program_->GetFeedTargetNames(); + fetch_target_names_ = inference_program_->GetFetchTargetNames(); + return true; + } + + bool Run(const std::vector& inputs, + std::vector* output_data, + int batch_size = -1) override { + return NativePaddlePredictor::Run(inputs, output_data, batch_size); + } + + void OptimizeInferenceProgram() { + LOG(INFO) << "optimize begin"; + FLAGS_IA_enable_ir = true; + FLAGS_IA_enable_tensorrt_subgraph_engine = false; + FLAGS_IA_output_storage_path = ""; // Don't output the model. + // Analyze inference_program + Argument argument; + if (!config_.model_dir.empty()) { + argument.fluid_model_dir.reset(new std::string(config_.model_dir)); + } else { + PADDLE_ENFORCE( + !config_.param_file.empty(), + "Either model_dir or (param_file, prog_file) should be set."); + PADDLE_ENFORCE(!config_.prog_file.empty()); + argument.fluid_model_program_path.reset( + new std::string(config_.prog_file)); + argument.fluid_model_param_path.reset( + new std::string(config_.param_file)); + } + argument.origin_program_desc.reset( + new ProgramDesc(*inference_program_->Proto())); + Singleton::Global().Run(&argument); + CHECK(argument.transformed_program_desc); + VLOG(5) << "to prepare executor"; + // LOG(INFO) << "transformed_parogram_desc " << + // argument.transformed_program_desc->DebugString(); + inference_program_.reset( + new framework::ProgramDesc(*argument.transformed_program_desc)); + PADDLE_ENFORCE(argument.Has("param_scope")); + // Update scope. + scope_.reset(argument.Release("param_scope")); + LOG(INFO) << "optimize end =="; + } + + private: + NativeConfig config_; +}; + +template <> +std::unique_ptr CreatePaddlePredictor< + NativeConfig, PaddleEngineKind::kAnalysis>(const NativeConfig& config) { + VLOG(3) << "create NativePredictor"; + if (config.use_gpu) { + // 1. GPU memeroy + PADDLE_ENFORCE_GT( + config.fraction_of_gpu_memory, 0.f, + "fraction_of_gpu_memory in the config should be set to range (0., 1.]"); + PADDLE_ENFORCE_GE(config.device, 0, "Invalid device id %d", config.device); + std::vector flags; + if (config.fraction_of_gpu_memory >= 0.0f || + config.fraction_of_gpu_memory <= 0.95f) { + flags.push_back("dummpy"); + std::string flag = "--fraction_of_gpu_memory_to_use=" + + std::to_string(config.fraction_of_gpu_memory); + flags.push_back(flag); + VLOG(3) << "set flag: " << flag; + framework::InitGflags(flags); + } + } + + std::unique_ptr predictor(new AnalysisPredictor(config)); + if (!dynamic_cast(predictor.get())->Init(nullptr)) { + return nullptr; + } + return predictor; +} + +} // namespace paddle + +USE_PASS(fc_fuse_pass); +USE_PASS(graph_viz_pass); +USE_PASS(infer_clean_graph_pass); diff --git a/paddle/fluid/inference/api/helper.cc b/paddle/fluid/inference/api/helper.cc new file mode 100644 index 0000000000000000000000000000000000000000..9cc491e10d691a206dd903b78c0ea570741da44c --- /dev/null +++ b/paddle/fluid/inference/api/helper.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/inference/api/helper.h" + +namespace paddle { +namespace inference { + +template <> +std::string to_string>( + const std::vector> &vec) { + std::stringstream ss; + for (const auto &piece : vec) { + ss << to_string(piece) << "\n"; + } + return ss.str(); +} + +template <> +std::string to_string>>( + const std::vector>> &vec) { + std::stringstream ss; + for (const auto &line : vec) { + for (const auto &rcd : line) { + ss << to_string(rcd) << ";\t"; + } + ss << '\n'; + } + return ss.str(); +} + +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/api/helper.h b/paddle/fluid/inference/api/helper.h index 2c166cc0622f68e6d527005795c21236ccf43c33..e44b1b74bc385c015fa6efcebac05359a810cbc1 100644 --- a/paddle/fluid/inference/api/helper.h +++ b/paddle/fluid/inference/api/helper.h @@ -44,7 +44,8 @@ class Timer { } }; -void split(const std::string &str, char sep, std::vector *pieces) { +static void split(const std::string &str, char sep, + std::vector *pieces) { pieces->clear(); if (str.empty()) { return; @@ -60,7 +61,8 @@ void split(const std::string &str, char sep, std::vector *pieces) { pieces->push_back(str.substr(pos)); } } -void split_to_float(const std::string &str, char sep, std::vector *fs) { +static void split_to_float(const std::string &str, char sep, + std::vector *fs) { std::vector pieces; split(str, sep, &pieces); std::transform(pieces.begin(), pieces.end(), std::back_inserter(*fs), @@ -76,27 +78,14 @@ std::string to_string(const std::vector &vec) { } template <> std::string to_string>( - const std::vector> &vec) { - std::stringstream ss; - for (const auto &piece : vec) { - ss << to_string(piece) << "\n"; - } - return ss.str(); -} + const std::vector> &vec); + template <> std::string to_string>>( - const std::vector>> &vec) { - std::stringstream ss; - for (const auto &line : vec) { - for (const auto &rcd : line) { - ss << to_string(rcd) << ";\t"; - } - ss << '\n'; - } - return ss.str(); -} + const std::vector>> &vec); + // clang-format off -void TensorAssignData(PaddleTensor *tensor, const std::vector> &data) { +static void TensorAssignData(PaddleTensor *tensor, const std::vector> &data) { // Assign buffer int dim = std::accumulate(tensor->shape.begin(), tensor->shape.end(), 1, [](int a, int b) { return a * b; }); tensor->data.Resize(sizeof(float) * dim); diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h index 36fd0727aa7beef4a06a5f2e63ec0c43583ddf84..1baa64c249f291ec1bc874be5031abe6d4368274 100644 --- a/paddle/fluid/inference/api/paddle_inference_api.h +++ b/paddle/fluid/inference/api/paddle_inference_api.h @@ -77,6 +77,7 @@ enum class PaddleEngineKind { kNative = 0, // Use the native Fluid facility. kAnakin, // Use Anakin for inference. kAutoMixedTensorRT, // Automatically mix Fluid with TensorRT. + kAnalysis // TODO(Superjomn) support following engines latter. // kTensorRT, // Use TensorRT for inference. // kAutoMixedAnakin, // Automatically mix Fluid with Anakin. diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc index 181868977dd8f2568486ed0c4e1f260a69795896..cef7b2a7e3a29da05628d7540f5545dc9adda27e 100644 --- a/paddle/fluid/inference/io.cc +++ b/paddle/fluid/inference/io.cc @@ -143,5 +143,21 @@ std::unique_ptr Load( return main_program; } +void SaveVars(const framework::Scope& scope, + const std::vector& vars, const std::string& dirname, + bool predicate) { + framework::ProgramDesc prog; + auto* block = prog.MutableBlock(0); + auto* op = block->AppendOp(); + op->SetType("save_combine"); + op->SetInput("X", vars); + op->SetAttr("file_path", dirname + "/param"); + op->CheckAttrs(); + + platform::CPUPlace place; + framework::Executor exe(place); + exe.Run(prog, const_cast(&scope), 0, true, true); +} + } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/io.h b/paddle/fluid/inference/io.h index 01b50b3670cb9da2e0be232a61ea6129dd83aa20..ab492577c1476abee30d6dd1c740394391e5a93a 100644 --- a/paddle/fluid/inference/io.h +++ b/paddle/fluid/inference/io.h @@ -41,5 +41,10 @@ std::unique_ptr Load(framework::Executor* executor, const std::string& prog_filename, const std::string& param_filename); +// Save the variables from a scope to disk. +void SaveVars(const framework::Scope& scope, + const std::vector& vars, const std::string& dirname, + bool predicate = true); + } // namespace inference } // namespace paddle diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 8bab37c5830dfdcd5d6ccf1cc049387b496b0d04..a02128c5a54c80ca7ccf9db347cd53f28bbb50f8 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -56,7 +56,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { const int D = w_dims[1] / 4; PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2."); PADDLE_ENFORCE_EQ(w_dims[0], D + M, - "LSTMWeight dims should be (%d + %d) * %d.", D + M, 4 * D); + "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D); auto b_dims = ctx->GetInputDim("LSTMBias"); PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2."); diff --git a/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc index 90aba5fe89d79b8aa7008bc803b0a646e93bd0fc..0cd3d3887cf5167c779a8b20442fdb458cd7eab4 100644 --- a/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc +++ b/paddle/fluid/operators/fusion_seqexpand_concat_fc_op.cc @@ -49,9 +49,14 @@ void FusionSeqExpandConcatFCOp::InferShape( "FC height should be sum of all inputs width."); if (ctx->HasInput("FCBias")) { auto b_dims = ctx->GetInputDim("FCBias"); - PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(FCBias)'s rank must be 2."); - PADDLE_ENFORCE_EQ(b_dims[0], 1, "FCBias shapes must be 1 * %d.", D); - PADDLE_ENFORCE_EQ(b_dims[1], D, "FCBias shapes must be 1 * %d.", D); + PADDLE_ENFORCE(b_dims.size() == 1 || b_dims.size() == 2, + "b_dims should be 1 or 2, get %d", b_dims.size()); + if (b_dims.size() == 1) { + PADDLE_ENFORCE_EQ(b_dims[0], D, "FCBias shapes must be %d.", D); + } else { + PADDLE_ENFORCE_EQ(b_dims[0], 1, "FCBias shapes must be 1x%d.", D); + PADDLE_ENFORCE_EQ(b_dims[1], D, "FCBias shapes must be 1x%d.", D); + } } ctx->SetOutputDim("Out", {ins_dims[0][0], D}); diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index 020ce4d6f59412490657767a096f1ce185287864..4c99f4be321160caf0ee2f89a655bdfb933408e3 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -85,9 +85,6 @@ void InitDevices(bool init_p2p) { } catch (const std::exception &exp) { LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime."; } -#else - LOG(WARNING) - << "'CUDA' is not supported, Please re-compile with WITH_GPU option"; #endif InitDevices(init_p2p, devices); } @@ -101,9 +98,6 @@ void InitDevices(bool init_p2p, const std::vector devices) { } catch (const std::exception &exp) { LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime."; } -#else - LOG(WARNING) - << "'CUDA' is not supported, Please re-compile with WITH_GPU option"; #endif for (size_t i = 0; i < devices.size(); ++i) {