未验证 提交 902f19b4 编写于 作者: Y Yan Chunwei 提交者: GitHub

fea/fuse attention lstm simplify.with fusion lstm.with sequnce expand (#13006)

上级 55f240ba
...@@ -5,14 +5,16 @@ cc_library(pass SRCS pass.cc DEPS graph node graph_helper) ...@@ -5,14 +5,16 @@ cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper) cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper)
cc_library(graph_to_program_pass SRCS graph_to_program_pass.cc DEPS graph pass graph_helper) cc_library(graph_to_program_pass SRCS graph_to_program_pass.cc DEPS graph pass graph_helper)
cc_library(graph_traits SRCS graph_traits.cc DEPS graph) cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits) cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detecter) cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detector)
cc_library(attention_lstm_fuse_pass SRCS attention_lstm_fuse_pass.cc DEPS graph graph_pattern_detector)
cc_library(infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass) cc_library(infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass)
cc_library(fc_lstm_fuse_pass SRCS fc_lstm_fuse_pass.cc DEPS graph graph_pattern_detector)
cc_library(seq_concat_fc_fuse_pass SRCS seq_concat_fc_fuse_pass.cc DEPS graph graph_pattern_detector)
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass) cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
cc_test(test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter) cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detecter graph pass graph_traits framework_proto) cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detector graph pass graph_traits framework_proto)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/api/helper.h"
namespace paddle {
namespace framework {
namespace ir {
struct Param {
std::string X = "concat_0.tmp_0";
std::string C0 = "cell_init";
std::string H0 = "hidden_init";
std::string AttentionWeight = "attention_fc.w_0";
std::string AttentionBias = "attention_fc.b_0";
std::string AttentionScalar = "attention_output.w_0";
std::string AttentionScalarBias = "attention_output.b_0";
std::string LSTMWeight = "attention_w.new";
std::string LSTMBias = "attention_b.new";
std::string Hidden = "array_to_lod_tensor_0.tmp_0";
std::string Cell = "at.cell.new";
std::string AttentionedX = "at.x.new";
std::string AttentionFCOut = "at.fc.new";
std::string LSTMX = "at.lstmx.new";
std::string LSTMOUT = "at.lstmout.new";
};
void PrepareParameters(Graph* graph, const Param& param);
void FindWhileOp(Graph* graph) {
GraphPatternDetector gpd;
std::unordered_set<int> fused_external_ops(
{35, 36, 37, 38, 43, 44, 49, 45, 46, 47, 41, 42, 53, 54, 48,
57, 55, 56, 52, 74, 80, 77, 78, 79, 50, 77, 39, 40, 51});
gpd.mutable_pattern()->NewNode(
[&](Node* n) { return fused_external_ops.count(n->id()); }, "while");
if (!graph->Has(kGraphvizMarkedNodeAttr)) {
graph->Set(kGraphvizMarkedNodeAttr, new GraphVizPass::marked_nodes_t);
}
auto& marked_nodes =
graph->Get<GraphVizPass::marked_nodes_t>(kGraphvizMarkedNodeAttr);
auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
auto* while_pat_node = gpd.pattern().RetriveNode("while");
auto* while_node = subgraph.at(while_pat_node);
marked_nodes.insert(while_node);
};
gpd(graph, handle);
Param param;
// Add AttentionLSTM node
OpDesc op_desc;
op_desc.SetType("attention_lstm");
#define OP_SET_IN(x) op_desc.SetInput(#x, {param.x});
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {param.x});
OP_SET_IN(X);
OP_SET_IN(C0);
OP_SET_IN(H0);
OP_SET_IN(AttentionWeight);
OP_SET_IN(AttentionBias);
OP_SET_IN(AttentionScalar);
OP_SET_IN(AttentionScalarBias);
OP_SET_IN(LSTMWeight);
OP_SET_IN(LSTMBias);
OP_SET_OUT(Hidden);
OP_SET_OUT(Cell);
OP_SET_OUT(AttentionedX);
OP_SET_OUT(AttentionFCOut);
OP_SET_OUT(LSTMX);
OP_SET_OUT(LSTMOUT);
#undef OP_SET_IN
#undef OP_SET_OUT
auto* X = graph->RetriveNode(34);
auto* LSTMOUT = graph->RetriveNode(81);
auto* cell_init = graph->RetriveNode(6);
auto* hidden_init = graph->RetriveNode(8);
#define LINK_TO(node0, node1) \
node0->outputs.push_back(node1); \
node1->inputs.push_back(node0);
auto* lstm_op = graph->CreateOpNode(&op_desc);
PrepareParameters(graph, param);
LINK_TO(X, lstm_op);
LINK_TO(cell_init, lstm_op);
LINK_TO(hidden_init, lstm_op);
LINK_TO(lstm_op, LSTMOUT);
GraphSafeRemoveNodes(graph, marked_nodes);
}
#define CHECK_P1(x) PADDLE_ENFORCE_NOT_NULL(x);
#define CHECK_P2(x0, x1) \
CHECK_P1(x0); \
CHECK_P1(x1);
#define CHECK_P3(x0, x1, x2) \
CHECK_P2(x0, x1); \
CHECK_P1(x2);
#define CHECK_P4(x0, x1, x2, x3) \
CHECK_P3(x0, x1, x2); \
CHECK_P1(x3);
#define CHECK_P5(x0, x1, x2, x3, x4) \
CHECK_P4(x0, x1, x2, x3); \
CHECK_P1(x4);
void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
const LoDTensor& W_forget_w1,
const LoDTensor& W_input_w0, const LoDTensor& W_input_w1,
const LoDTensor& W_output_w0,
const LoDTensor& W_output_w1, const LoDTensor& W_cell_w0,
const LoDTensor& W_cell_w1, LoDTensor* out);
void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
const LoDTensor& B_output, const LoDTensor& B_cell,
LoDTensor* out);
void PrepareParameters(Graph* graph, const Param& param) {
// Check parameters
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto* scope = graph->Get<Scope*>(kParamScopeAttr);
// Create new parameters.
scope->Var(param.LSTMWeight)->GetMutable<LoDTensor>();
scope->Var(param.LSTMBias)->GetMutable<LoDTensor>();
scope->Var(param.Hidden)->GetMutable<LoDTensor>();
scope->Var(param.Cell)->GetMutable<LoDTensor>();
scope->Var(param.AttentionedX)->GetMutable<LoDTensor>();
scope->Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
scope->Var(param.LSTMX)->GetMutable<LoDTensor>();
scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>();
#define GATE_W(name__) \
auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0"); \
auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1"); \
auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0"); \
CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0); \
VLOG(4) << #name__ "_w0" \
<< " shape: " << W_##name__##_w0->Get<LoDTensor>().dims(); \
VLOG(4) << #name__ "_w1" \
<< " shape: " << W_##name__##_w1->Get<LoDTensor>().dims(); \
VLOG(4) << #name__ "_b0" \
<< " shape: " << W_##name__##_b0->Get<LoDTensor>().dims(); \
auto& W_##name__##_w0_t = W_##name__##_w0->Get<LoDTensor>(); \
auto& W_##name__##_w1_t = W_##name__##_w1->Get<LoDTensor>(); \
auto& W_##name__##_b0_t = W_##name__##_b0->Get<LoDTensor>();
GATE_W(forget);
GATE_W(input);
GATE_W(output);
GATE_W(c);
#undef GATE_W
auto* attention_fc_w = scope->FindVar("attention_fc.w_0");
auto* attention_fc_b = scope->FindVar("attention_fc.b_0");
auto* attention_output_w = scope->FindVar("attention_output.w_0");
auto* attention_output_b = scope->FindVar("attention_output.b_0");
CHECK_P4(attention_fc_w, attention_fc_b, attention_output_w,
attention_output_b);
auto* lstm_weight = scope->Var(param.LSTMWeight);
auto* lstm_weight_t = lstm_weight->GetMutable<LoDTensor>();
auto* lstm_bias = scope->Var(param.LSTMBias);
auto* lstm_bias_t = lstm_bias->GetMutable<LoDTensor>();
// reshape attention_bias
auto* attention_bias_t =
scope->FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1);
attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]}));
auto* attention_scalar_bias_t =
scope->FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
attention_scalar_bias_t->Resize(
make_ddim({1, attention_scalar_bias_t->dims()[0]}));
PrepareLSTMWeight(W_forget_w0_t, W_forget_w1_t, W_input_w0_t, W_input_w1_t,
W_output_w0_t, W_output_w1_t, W_c_w0_t, W_c_w1_t,
lstm_weight_t);
PrepareLSTMBias(W_forget_b0_t, W_input_b0_t, W_output_b0_t, W_c_b0_t,
lstm_bias_t);
}
// Prepare parameters
void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
const LoDTensor& W_forget_w1,
const LoDTensor& W_input_w0, const LoDTensor& W_input_w1,
const LoDTensor& W_output_w0,
const LoDTensor& W_output_w1, const LoDTensor& W_cell_w0,
const LoDTensor& W_cell_w1, LoDTensor* out) {
int D = W_forget_w0.dims()[0];
int M = W_forget_w1.dims()[0];
out->Resize(make_ddim({D + M, 4 * D}));
VLOG(3) << "LSTMWeight resized to " << out->dims();
float* out_data = out->mutable_data<float>(platform::CPUPlace());
std::array<const float*, 4> tensors(
{W_forget_w0.data<float>(), W_input_w0.data<float>(),
W_output_w0.data<float>(), W_cell_w0.data<float>()});
std::array<const float*, 4> tensors1(
{W_forget_w1.data<float>(), W_input_w1.data<float>(),
W_output_w1.data<float>(), W_cell_w1.data<float>()});
for (int row = 0; row < D; row++) {
for (int col = 0; col < 4; col++) {
float* dst = out_data + 4 * D * row + D * col;
const float* src = tensors[col] + D * row;
memcpy(dst, src, D * sizeof(float));
}
}
for (int row = 0; row < M; row++) {
for (int col = 0; col < 4; col++) {
float* dst = out_data + 4 * D * (D + row) + D * col;
const float* src = tensors1[col] + D * row;
memcpy(dst, src, D * sizeof(float));
}
}
}
void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
const LoDTensor& B_output, const LoDTensor& B_cell,
LoDTensor* out) {
std::array<const float*, 4> tensors(
{B_forget.data<float>(), B_input.data<float>(), B_output.data<float>(),
B_cell.data<float>()});
PADDLE_ENFORCE_EQ(B_forget.dims().size(), 1);
int D = B_forget.dims()[0];
out->Resize(make_ddim({1, 4 * D}));
auto* out_data = out->mutable_data<float>(platform::CPUPlace());
for (size_t i = 0; i < tensors.size(); i++) {
memcpy(out_data + D * i, tensors[i], D * sizeof(float));
}
}
// Parameters
std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PDPattern external_pattern, subblock_pattern;
FindWhileOp(graph.get());
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(attention_lstm_fuse_pass,
paddle::framework::ir::AttentionLSTMFusePass);
...@@ -12,12 +12,19 @@ ...@@ -12,12 +12,19 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/inference/analysis/dot.h" #pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle { namespace paddle {
namespace inference { namespace framework {
namespace analysis { namespace ir {
size_t Dot::counter = 0;
} // namespace analysis class AttentionLSTMFusePass : public FusePassBase {
} // namespace inference protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle } // namespace paddle
...@@ -100,12 +100,10 @@ void BuildFCPattern(PDPattern* pattern) { ...@@ -100,12 +100,10 @@ void BuildFCPattern(PDPattern* pattern) {
}, },
"elementwise_add_out"); "elementwise_add_out");
pattern->AddEdge(mul_parameter_var, mul_op); mul_op->LinksFrom({mul_parameter_var, mul_tmp_input_var})
pattern->AddEdge(mul_tmp_input_var, mul_op); .LinksTo({mul_out_var});
pattern->AddEdge(mul_op, mul_out_var); elementwise_add_op->LinksFrom({mul_out_var, elementwise_add_tmp_var})
pattern->AddEdge(mul_out_var, elementwise_add_op); .LinksTo({elementwise_add_out_var});
pattern->AddEdge(elementwise_add_tmp_var, elementwise_add_op);
pattern->AddEdge(elementwise_add_op, elementwise_add_out_var);
} }
// Replace the node `from` in the links to `to` // Replace the node `from` in the links to `to`
...@@ -125,7 +123,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl( ...@@ -125,7 +123,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
std::unordered_set<Node*> nodes2delete; std::unordered_set<Node*> nodes2delete;
GraphPatternDetecter gpd; GraphPatternDetector gpd;
BuildFCPattern(gpd.mutable_pattern()); BuildFCPattern(gpd.mutable_pattern());
#define GET_NODE(id) \ #define GET_NODE(id) \
...@@ -134,7 +132,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl( ...@@ -134,7 +132,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \ auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id); PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
auto handler = [&](const GraphPatternDetecter::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
VLOG(4) << "handle FC fuse"; VLOG(4) << "handle FC fuse";
// Currently, there is no FC op available, so I will just simulate the // Currently, there is no FC op available, so I will just simulate the
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detecter.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass.h"
namespace paddle { namespace paddle {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
std::unordered_set<int> fused_ops({// first lstm
13, 15, 16,
// second lstm
23, 25, 26});
pattern->NewNode([&](Node* x) { return fused_ops.count(x->id()); },
"any_node");
std::unordered_set<Node*> marked_nodes;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
auto* id = subgraph.at(gpd.pattern().RetriveNode("any_node"));
marked_nodes.insert(id);
};
gpd(graph.get(), handler);
// Create New OpDesc
auto lstm_creator = [&](int lstm, int input, int weight_x, int weight_h,
int bias, int hidden, int cell, int xx) {
#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x);
GET_NODE(input);
GET_NODE(weight_x);
GET_NODE(weight_h);
GET_NODE(bias);
GET_NODE(hidden);
GET_NODE(cell);
GET_NODE(xx);
GET_NODE(lstm);
OpDesc op_desc;
op_desc.SetType("fusion_lstm");
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__##_n->Name()});
SET_IN(X, input);
SET_IN(WeightX, weight_x);
SET_IN(WeightH, weight_h);
SET_IN(Bias, bias);
#undef GET_NODE
#undef SET_IN
LOG(INFO) << "hidden_n: " << hidden_n->Name();
LOG(INFO) << "cell: " << cell_n->Name();
LOG(INFO) << "xx: " << xx_n->Name();
op_desc.SetInput("H0", {});
op_desc.SetInput("C0", {});
op_desc.SetOutput("Hidden", {hidden_n->Name()});
op_desc.SetOutput("Cell", {cell_n->Name()});
op_desc.SetOutput("XX", {xx_n->Name()});
op_desc.SetOutput("BatchedGate", {"blstm_0.tmp_2"});
op_desc.SetOutput("BatchCellPreAct", {"blstm_1.tmp_2"});
op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", false);
auto* op = graph->CreateOpNode(&op_desc);
#define LINK_TO(a, b) \
a->outputs.push_back(b); \
b->inputs.push_back(a);
LINK_TO(input_n, op);
LINK_TO(weight_x_n, op);
LINK_TO(weight_h_n, op);
LINK_TO(bias_n, op);
LINK_TO(op, hidden_n);
#undef LINK_TO
return op;
};
lstm_creator(16, 12, 14, 18, 17, 22, 21, 19);
lstm_creator(26, 12, 24, 28, 27, 32, 31, 29);
// remove all the nodes
for (auto* node : marked_nodes) {
graph->RemoveNode(const_cast<Node*>(node));
}
for (auto* node : graph->Nodes()) {
for (auto it = node->inputs.begin(); it != node->inputs.end();) {
if (marked_nodes.count(*it)) {
it = const_cast<Node*>(node)->inputs.erase(it);
} else
it++;
}
for (auto it = node->outputs.begin(); it != node->outputs.end();) {
if (marked_nodes.count(*it)) {
it = const_cast<Node*>(node)->outputs.erase(it);
} else
it++;
}
}
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fc_lstm_fuse_pass, paddle::framework::ir::FCLstmFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class FCLstmFusePass : public Pass {
public:
virtual ~FCLstmFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
namespace ir {
static const char kParamScopeAttr[] = "param_scope";
class FusePassBase : public Pass {
public:
void Init(Graph* graph) const { graph_ = graph; }
Scope* param_scope() const {
PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
return graph_->Get<framework::Scope*>(kParamScopeAttr);
}
virtual ~FusePassBase() {}
protected:
mutable Graph* graph_;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -99,13 +99,13 @@ class Graph { ...@@ -99,13 +99,13 @@ class Graph {
// Create a normal variable with non-null VarDesc. // Create a normal variable with non-null VarDesc.
ir::Node *CreateVarNode(VarDesc *var_desc) { ir::Node *CreateVarNode(VarDesc *var_desc) {
PADDLE_ENFORCE(var_desc); PADDLE_ENFORCE(var_desc);
return AddNode(new ir::Node(var_desc)); return AddNode(new ir::Node(var_desc, node_count_++));
} }
// Create a normal runnable operator with OpDesc. // Create a normal runnable operator with OpDesc.
ir::Node *CreateOpNode(OpDesc *op_desc) { ir::Node *CreateOpNode(OpDesc *op_desc) {
PADDLE_ENFORCE(op_desc); PADDLE_ENFORCE(op_desc);
return AddNode(new ir::Node(op_desc)); return AddNode(new ir::Node(op_desc, node_count_++));
} }
// Create a control dependency var that connects 2 operations. The // Create a control dependency var that connects 2 operations. The
...@@ -115,13 +115,14 @@ class Graph { ...@@ -115,13 +115,14 @@ class Graph {
// TODO(panyx0718): control var name should be really unique. // TODO(panyx0718): control var name should be really unique.
const std::string name = string::Sprintf( const std::string name = string::Sprintf(
"%s@%llu", ir::Node::kControlDepVarName, node_set_.size()); "%s@%llu", ir::Node::kControlDepVarName, node_set_.size());
return AddNode(new ir::Node(name, ir::Node::Type::kVariable)); return AddNode(
new ir::Node(name, ir::Node::Type::kVariable, node_count_++));
} }
// A more free style way of creating a graph node. Mostly use for test // A more free style way of creating a graph node. Mostly use for test
// or "copy" from another node. Avoid using it if possible. // or "copy" from another node. Avoid using it if possible.
ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) { ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
return AddNode(new ir::Node(name, type)); return AddNode(new ir::Node(name, type, node_count_++));
} }
// Clear all node information of the graph and return the ownership of the // Clear all node information of the graph and return the ownership of the
...@@ -142,12 +143,20 @@ class Graph { ...@@ -142,12 +143,20 @@ class Graph {
nodes_.erase(node); nodes_.erase(node);
} }
Node *RetriveNode(int id) {
auto it = id2node_.find(id);
if (it != id2node_.end()) return it->second;
return nullptr;
}
private: private:
// This method takes ownership of `node`. // This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) { ir::Node *AddNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) == node_set_.end()); PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
nodes_[node].reset(node); nodes_[node].reset(node);
node_set_.insert(node); node_set_.insert(node);
PADDLE_ENFORCE(!id2node_.count(node->id()), "duplicate id %d", node->id());
id2node_[node->id()] = node;
return node; return node;
} }
...@@ -157,6 +166,8 @@ class Graph { ...@@ -157,6 +166,8 @@ class Graph {
std::map<std::string, std::function<void(void)>> attr_dels_; std::map<std::string, std::function<void(void)>> attr_dels_;
std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_; std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
std::unordered_set<ir::Node *> node_set_; std::unordered_set<ir::Node *> node_set_;
std::map<int, Node *> id2node_;
int node_count_{0};
}; };
bool IsControlDepVar(const ir::Node &var); bool IsControlDepVar(const ir::Node &var);
......
...@@ -103,10 +103,10 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList( ...@@ -103,10 +103,10 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
for (auto &var : n->inputs) { for (auto &var : n->inputs) {
for (auto &adj_n : var->inputs) { for (auto &adj_n : var->inputs) {
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation); PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
adj_list[n].insert(adj_n);
VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n) VLOG(4) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
<< " -> " << n->Name() << reinterpret_cast<void *>(n) << " -> " << n->Name() << reinterpret_cast<void *>(n)
<< " via " << var->Name() << reinterpret_cast<void *>(var); << " via " << var->Name() << reinterpret_cast<void *>(var);
adj_list[n].insert(adj_n);
} }
} }
} }
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detecter.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -34,7 +34,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) { ...@@ -34,7 +34,7 @@ PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
name); name);
} }
nodes_.emplace_back(new PDNode(std::move(teller), name)); nodes_.emplace_back(new PDNode(std::move(teller), this, name));
auto* cur = nodes_.back().get(); auto* cur = nodes_.back().get();
node_map_[name] = cur; node_map_[name] = cur;
return cur; return cur;
...@@ -56,19 +56,22 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) { ...@@ -56,19 +56,22 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) {
edges_.emplace_back(a, b); edges_.emplace_back(a, b);
} }
void GraphPatternDetecter::operator()(Graph* graph, void GraphPatternDetector::operator()(Graph* graph,
GraphPatternDetecter::handle_t handler) { GraphPatternDetector::handle_t handler) {
if (!MarkPDNodesInGraph(*graph)) return; if (!MarkPDNodesInGraph(*graph)) return;
auto subgraphs = DetectPatterns(); auto subgraphs = DetectPatterns();
UniquePatterns(&subgraphs); UniquePatterns(&subgraphs);
RemoveOverlappedMatch(&subgraphs); RemoveOverlappedMatch(&subgraphs);
LOG(INFO) << "detect " << subgraphs.size() << " subgraph matches the pattern";
int id = 0;
for (auto& g : subgraphs) { for (auto& g : subgraphs) {
LOG(INFO) << "optimizing #" << id++ << " subgraph";
handler(g, graph); handler(g, graph);
} }
} }
bool GraphPatternDetecter::MarkPDNodesInGraph(const ir::Graph& graph) { bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
VLOG(4) << "mark pdnodes in graph"; VLOG(4) << "mark pdnodes in graph";
if (graph.Nodes().empty()) return false; if (graph.Nodes().empty()) return false;
...@@ -114,13 +117,15 @@ bool IsNodesLink(Node* a, Node* b) { ...@@ -114,13 +117,15 @@ bool IsNodesLink(Node* a, Node* b) {
return false; return false;
} }
std::vector<GraphPatternDetecter::subgraph_t> std::vector<GraphPatternDetector::subgraph_t>
GraphPatternDetecter::DetectPatterns() { GraphPatternDetector::DetectPatterns() {
// Init empty subgraphs. // Init empty subgraphs.
std::vector<GraphPatternDetecter::subgraph_t> result; std::vector<GraphPatternDetector::subgraph_t> result;
std::vector<HitGroup> init_groups; std::vector<HitGroup> init_groups;
PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed"); std::array<std::vector<HitGroup>, 2> bi_records;
auto* first_pnode = pattern_.edges().front().first; // PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed");
auto* first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get()
: pattern_.edges().front().first;
if (!pdnodes2nodes_.count(first_pnode)) return result; if (!pdnodes2nodes_.count(first_pnode)) return result;
for (auto* node : pdnodes2nodes_[first_pnode]) { for (auto* node : pdnodes2nodes_[first_pnode]) {
HitGroup group; HitGroup group;
...@@ -129,7 +134,6 @@ GraphPatternDetecter::DetectPatterns() { ...@@ -129,7 +134,6 @@ GraphPatternDetecter::DetectPatterns() {
} }
int step = 0; int step = 0;
std::array<std::vector<HitGroup>, 2> bi_records;
bi_records[0] = std::move(init_groups); bi_records[0] = std::move(init_groups);
// Extend a PDNode to subgraphs by deducing the connection relations defined // Extend a PDNode to subgraphs by deducing the connection relations defined
...@@ -141,6 +145,7 @@ GraphPatternDetecter::DetectPatterns() { ...@@ -141,6 +145,7 @@ GraphPatternDetecter::DetectPatterns() {
auto& pre_groups = bi_records[step % 2]; auto& pre_groups = bi_records[step % 2];
auto& cur_groups = bi_records[1 - (step++ % 2)]; auto& cur_groups = bi_records[1 - (step++ % 2)];
cur_groups.clear(); cur_groups.clear();
if (pre_groups.empty()) break;
// source -> target // source -> target
for (Node* source : pdnodes2nodes_[edge.first]) { for (Node* source : pdnodes2nodes_[edge.first]) {
for (Node* target : pdnodes2nodes_[edge.second]) { for (Node* target : pdnodes2nodes_[edge.second]) {
...@@ -163,7 +168,7 @@ GraphPatternDetecter::DetectPatterns() { ...@@ -163,7 +168,7 @@ GraphPatternDetecter::DetectPatterns() {
} }
for (auto& group : bi_records[step % 2]) { for (auto& group : bi_records[step % 2]) {
GraphPatternDetecter::subgraph_t subgraph; GraphPatternDetector::subgraph_t subgraph;
for (auto& role : group.roles) { for (auto& role : group.roles) {
subgraph.emplace(role.first, role.second); subgraph.emplace(role.first, role.second);
} }
...@@ -172,10 +177,10 @@ GraphPatternDetecter::DetectPatterns() { ...@@ -172,10 +177,10 @@ GraphPatternDetecter::DetectPatterns() {
return result; return result;
} }
void GraphPatternDetecter::UniquePatterns( void GraphPatternDetector::UniquePatterns(
std::vector<GraphPatternDetecter::subgraph_t>* subgraphs) { std::vector<GraphPatternDetector::subgraph_t>* subgraphs) {
if (subgraphs->empty()) return; if (subgraphs->empty()) return;
std::vector<GraphPatternDetecter::subgraph_t> result; std::vector<GraphPatternDetector::subgraph_t> result;
std::unordered_set<size_t> set; std::unordered_set<size_t> set;
for (auto& g : *subgraphs) { for (auto& g : *subgraphs) {
...@@ -192,7 +197,7 @@ void GraphPatternDetecter::UniquePatterns( ...@@ -192,7 +197,7 @@ void GraphPatternDetecter::UniquePatterns(
*subgraphs = result; *subgraphs = result;
} }
void GraphPatternDetecter::RemoveOverlappedMatch( void GraphPatternDetector::RemoveOverlappedMatch(
std::vector<subgraph_t>* subgraphs) { std::vector<subgraph_t>* subgraphs) {
std::vector<subgraph_t> result; std::vector<subgraph_t> result;
std::unordered_set<Node*> node_set; std::unordered_set<Node*> node_set;
...@@ -215,6 +220,46 @@ void GraphPatternDetecter::RemoveOverlappedMatch( ...@@ -215,6 +220,46 @@ void GraphPatternDetecter::RemoveOverlappedMatch(
*subgraphs = result; *subgraphs = result;
} }
std::string PDPattern::DotString() const {
using inference::analysis::Dot;
Dot dot;
int id = 0;
// Create Nodes
std::unordered_map<PDNode*, std::string> node2dot;
for (const auto& node : nodes()) {
std::string node_id = "Node" + std::to_string(id++);
dot.AddNode(node_id, {}, node->name());
node2dot[node.get()] = node_id;
}
// Create Edges
for (const auto& edge : edges()) {
if (!node2dot.count(edge.first) || !node2dot.count(edge.second)) {
LOG(ERROR) << "no node " << edge.first << " " << edge.second;
continue;
}
auto& src = node2dot.at(edge.first);
auto& trg = node2dot.at(edge.second);
dot.AddEdge(src, trg, {});
}
return dot.Build();
}
PDNode& PDNode::LinksTo(const std::vector<PDNode*>& others) {
// extend outlinks.
for (PDNode* x : others) {
pattern_->AddEdge(this, x);
}
return *this;
}
PDNode& PDNode::LinksFrom(const std::vector<PDNode*>& others) {
// extend outlinks.
for (PDNode* x : others) {
pattern_->AddEdge(x, this);
}
return *this;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -21,12 +21,14 @@ ...@@ -21,12 +21,14 @@
#include <numeric> #include <numeric>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/inference/analysis/dot.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
class PDPattern;
// Some basic torminolygies: // Some basic terminologies:
// - PDPattern: a pattern defined as a data flow graph. // - PDPattern: a pattern defined as a data flow graph.
// - PDNode: the node in the pattern, each PDNode represents an `ir::Node` // - PDNode: the node in the pattern, each PDNode represents an `ir::Node`
// that meets some conditions defined in `PDNode.teller`. // that meets some conditions defined in `PDNode.teller`.
...@@ -36,30 +38,43 @@ namespace ir { ...@@ -36,30 +38,43 @@ namespace ir {
struct PDNode { struct PDNode {
// tell whether an ir::Node* is a candidation for a PDNode. // tell whether an ir::Node* is a candidation for a PDNode.
using teller_t = std::function<bool(Node*)>; using teller_t = std::function<bool(Node*)>;
enum class Type { kOp, kVar };
PDNode(teller_t&& teller, const std::string& name = "") // this link to others
: teller_(teller), name_(name) { PDNode& LinksTo(const std::vector<PDNode*>& others);
PADDLE_ENFORCE(teller_ != nullptr, "invalid teller functer is set."); PDNode& LinksFrom(const std::vector<PDNode*>& others);
}
PDNode(PDNode&& other) = default;
std::vector<PDNode*> inlinks;
std::vector<PDNode*> outlinks;
bool Tell(Node* node) const { bool Tell(Node* node) const {
PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode"); PADDLE_ENFORCE(teller_ != nullptr, "teller should be set for a PDNode");
return teller_(node); return teller_(node);
} }
bool IsOp() const { return type_ == Type::kOp; }
bool IsVar() const { return type_ == Type::kVar; }
const std::string& name() const { return name_; } const std::string& name() const { return name_; }
PDNode(const PDNode&) = delete; PDNode(const PDNode&) = delete;
PDNode& operator=(const PDNode&) = delete; PDNode& operator=(const PDNode&) = delete;
private: private:
PDNode(teller_t&& teller, PDPattern* pattern, const std::string& name = "",
Type type = Type::kVar)
: teller_(std::move(teller)),
pattern_(pattern),
name_(name),
type_(type) {
PADDLE_ENFORCE(teller_ != nullptr, "invalid teller functer is set.");
}
PDNode(PDNode&& other) = default;
friend class PDPattern;
teller_t teller_; teller_t teller_;
PDPattern* pattern_;
std::string name_; std::string name_;
Type type_;
}; };
/* /*
...@@ -102,6 +117,8 @@ class PDPattern { ...@@ -102,6 +117,8 @@ class PDPattern {
const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; } const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; }
const std::vector<edge_t>& edges() const { return edges_; } const std::vector<edge_t>& edges() const { return edges_; }
std::string DotString() const;
private: private:
#ifdef PADDLE_WITH_TESTING #ifdef PADDLE_WITH_TESTING
FRIEND_TEST(PDPattern, AddEdge); FRIEND_TEST(PDPattern, AddEdge);
...@@ -117,7 +134,7 @@ class PDPattern { ...@@ -117,7 +134,7 @@ class PDPattern {
}; };
/* /*
* GraphPatternDetecter helps to detect the specific patterns in the graph. * GraphPatternDetector helps to detect the specific patterns in the graph.
* Input a pattern, output a list of the matched subgraphs/nodes. * Input a pattern, output a list of the matched subgraphs/nodes.
* This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.). * This helper can be used to support fuse(conv+batchnorm => batchnorm e.g.).
* *
...@@ -129,7 +146,7 @@ class PDPattern { ...@@ -129,7 +146,7 @@ class PDPattern {
* *
* Usage: * Usage:
* // Create a detector * // Create a detector
* GraphPatternDetecter detector; * GraphPatternDetector detector;
* // Define the detector's pattern, by adding PDNode and define the edges. * // Define the detector's pattern, by adding PDNode and define the edges.
* auto* node0 = detector.mutable_pattern().AddNode(...) * auto* node0 = detector.mutable_pattern().AddNode(...)
* auto* node1 = detector.mutable_pattern().AddNode(...) * auto* node1 = detector.mutable_pattern().AddNode(...)
...@@ -138,11 +155,11 @@ class PDPattern { ...@@ -138,11 +155,11 @@ class PDPattern {
* detector.mutable_pattern().AddEdge(node0, node1); * detector.mutable_pattern().AddEdge(node0, node1);
* // Create an handler, to define the behavior of treating the filtered * // Create an handler, to define the behavior of treating the filtered
* // subgraphs that comply with the patterns. * // subgraphs that comply with the patterns.
* GraphPatternDetecter::handle_t handler = some labmda * GraphPatternDetector::handle_t handler = some labmda
* // Execute the detector. * // Execute the detector.
* detector(&graph, handler); * detector(&graph, handler);
*/ */
class GraphPatternDetecter { class GraphPatternDetector {
public: public:
using subgraph_t = std::unordered_map<PDNode*, Node*>; using subgraph_t = std::unordered_map<PDNode*, Node*>;
...@@ -177,10 +194,62 @@ class GraphPatternDetecter { ...@@ -177,10 +194,62 @@ class GraphPatternDetecter {
using hit_rcd_t = using hit_rcd_t =
std::pair<Node* /*node in graph*/, PDNode* /*node in pattern*/>; std::pair<Node* /*node in graph*/, PDNode* /*node in pattern*/>;
PDPattern pattern_; PDPattern pattern_;
std::vector<hit_rcd_t> marked_records_;
std::unordered_map<const PDNode*, std::unordered_set<Node*>> pdnodes2nodes_; std::unordered_map<const PDNode*, std::unordered_set<Node*>> pdnodes2nodes_;
}; };
// some helper methods.
// Op's input.
static bool VarLinksToOp(Node* node, const std::string& op_type) {
for (auto* out : node->outputs) {
if (out->IsOp() && out->Op()->Type() == op_type) {
return true;
}
}
return false;
}
// Op's output.
static bool VarLinksFromOp(Node* node, const std::string& op_type) {
for (auto* out : node->inputs) {
if (out->IsOp() && out->Op()->Type() == op_type) {
return true;
}
}
return false;
}
// Check whether a var node is a op node's nth input.
static bool IsNthInput(Node* var, Node* op, const std::string& argument,
size_t nth) {
PADDLE_ENFORCE(var->IsVar());
PADDLE_ENFORCE(op->IsOp());
if (op->inputs.size() <= nth) return false;
return var->Name() == op->Op()->Input(argument)[nth];
}
static void GraphSafeRemoveNodes(Graph* graph,
const std::unordered_set<const Node*>& nodes) {
for (auto* node : nodes) {
graph->RemoveNode(const_cast<Node*>(node));
}
for (auto* node : graph->Nodes()) {
for (auto it = node->inputs.begin(); it != node->inputs.end();) {
if (nodes.count(*it)) {
it = const_cast<Node*>(node)->inputs.erase(it);
} else
it++;
}
for (auto it = node->outputs.begin(); it != node->outputs.end();) {
if (nodes.count(*it)) {
it = const_cast<Node*>(node)->outputs.erase(it);
} else
it++;
}
}
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/graph_pattern_detecter.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
...@@ -82,7 +82,7 @@ TEST(PDPattern, AddEdge) { ...@@ -82,7 +82,7 @@ TEST(PDPattern, AddEdge) {
} }
TEST(GraphPatternDetecter, MarkPDNodesInGraph) { TEST(GraphPatternDetecter, MarkPDNodesInGraph) {
GraphPatternDetecter x; GraphPatternDetector x;
// mark o2, o3, v2 // mark o2, o3, v2
// The pattern is a graph: // The pattern is a graph:
...@@ -131,7 +131,7 @@ TEST(GraphPatternDetecter, MultiSubgraph) { ...@@ -131,7 +131,7 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
Graph graph(program); Graph graph(program);
BuildGraph(&graph); BuildGraph(&graph);
GraphPatternDetecter x; GraphPatternDetector x;
// The pattern is a graph: // The pattern is a graph:
// op -> var // op -> var
...@@ -149,8 +149,8 @@ TEST(GraphPatternDetecter, MultiSubgraph) { ...@@ -149,8 +149,8 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
x.mutable_pattern()->AddEdge(any_var, any_op1); x.mutable_pattern()->AddEdge(any_var, any_op1);
int count = 0; int count = 0;
GraphPatternDetecter::handle_t handle = [&]( GraphPatternDetector::handle_t handle = [&](
const GraphPatternDetecter::subgraph_t& s, Graph* g) { const GraphPatternDetector::subgraph_t& s, Graph* g) {
LOG(INFO) << "Detect " << s.at(any_op)->Name() << " -> " LOG(INFO) << "Detect " << s.at(any_op)->Name() << " -> "
<< s.at(any_var)->Name() << " -> " << s.at(any_op1)->Name(); << s.at(any_var)->Name() << " -> " << s.at(any_op1)->Name();
count++; count++;
......
...@@ -16,11 +16,13 @@ limitations under the License. */ ...@@ -16,11 +16,13 @@ limitations under the License. */
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/inference/analysis/dot.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
static const char kGraphVizPath[] = "graph_viz_path"; static const char kGraphVizPath[] = "graph_viz_path";
using inference::analysis::Dot;
std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl( std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
...@@ -30,41 +32,65 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl( ...@@ -30,41 +32,65 @@ std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
PADDLE_ENFORCE(fout->good()); PADDLE_ENFORCE(fout->good());
std::ostream& sout = *fout; std::ostream& sout = *fout;
size_t var_id = 0; std::unordered_map<const ir::Node*, std::string> node2dot;
std::unordered_map<const ir::Node*, size_t> vars;
Dot dot;
sout << "digraph G {\n";
std::vector<Dot::Attr> op_attrs({Dot::Attr("style", "filled"),
for (const ir::Node* n : graph->Nodes()) { Dot::Attr("shape", "box"),
if (n->NodeType() != ir::Node::Type::kVariable) continue; Dot::Attr("fillcolor", "red")});
size_t cur_var_id = var_id++; std::vector<Dot::Attr> var_attrs({Dot::Attr("style", "filled,rounded"),
vars[n] = cur_var_id; // Dot::Attr("shape", "diamond"),
Dot::Attr("fillcolor", "yellow")});
sout << "var_" << cur_var_id << " [label=\"" << n->Name() << "\"]"
<< std::endl; std::vector<Dot::Attr> marked_op_attrs({Dot::Attr("style", "filled"),
Dot::Attr("shape", "box"),
Dot::Attr("fillcolor", "lightgray")});
std::vector<Dot::Attr> marked_var_attrs(
{Dot::Attr("style", "filled,rounded"),
// Dot::Attr("shape", "diamond"),
Dot::Attr("fillcolor", "lightgray")});
auto marked_nodes = ConsumeMarkedNodes(graph.get());
// Create nodes
for (const Node* n : graph->Nodes()) {
std::string node_id = n->Name() + "(" + std::to_string(n->id()) + ")";
if (n->IsOp()) {
decltype(op_attrs) attr =
marked_nodes.count(n) ? marked_op_attrs : op_attrs;
dot.AddNode(node_id, attr, node_id);
} else if (n->IsVar()) {
decltype(op_attrs) attr =
marked_nodes.count(n) ? marked_var_attrs : var_attrs;
dot.AddNode(node_id, attr, node_id);
} }
node2dot[n] = node_id;
size_t op_id = 0;
for (const ir::Node* n : graph->Nodes()) {
if (n->NodeType() != ir::Node::Type::kOperation) continue;
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << n->Name() << "\", shape=rect]"
<< std::endl;
for (auto in : n->inputs) {
std::string var_name = "var_" + std::to_string(vars[in]);
sout << var_name << " -> " << op_name << std::endl;
} }
// Create edges
for (auto out : n->outputs) { for (const Node* n : graph->Nodes()) {
std::string var_name = "var_" + std::to_string(vars[out]); const auto& src_id = node2dot.at(n);
sout << op_name << " -> " << var_name << std::endl; for (auto* out : n->outputs) {
const auto& trg_id = node2dot.at(out);
dot.AddEdge(src_id, trg_id, {});
} }
} }
sout << "}\n"; sout << dot.Build();
return graph; return graph;
} }
GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
Graph* graph) const {
marked_nodes_t res;
if (graph->Has(kGraphvizMarkedNodeAttr)) {
auto& attr = graph->Get<marked_nodes_t>(kGraphvizMarkedNodeAttr);
res = attr;
attr.clear();
}
return res;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -27,10 +27,19 @@ namespace paddle { ...@@ -27,10 +27,19 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
const char kGraphvizMarkedNodeAttr[] = "__graphviz__marked_node__";
class GraphVizPass : public Pass { class GraphVizPass : public Pass {
public:
using marked_nodes_t = std::unordered_set<const Node*>;
protected: protected:
std::unique_ptr<ir::Graph> ApplyImpl( std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override; std::unique_ptr<ir::Graph> graph) const override;
// Tell whether there are any marked nodes in the graph. Consume the
// corresponding attribute.
marked_nodes_t ConsumeMarkedNodes(Graph* graph) const;
}; };
} // namespace ir } // namespace ir
......
...@@ -29,20 +29,26 @@ class Node { ...@@ -29,20 +29,26 @@ class Node {
enum class Type { kOperation, kVariable }; enum class Type { kOperation, kVariable };
static constexpr char kControlDepVarName[] = "__control_var"; static constexpr char kControlDepVarName[] = "__control_var";
explicit Node(const std::string& name, Type type) explicit Node(const std::string& name, Type type, int id = -1)
: name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {} : name_(name),
var_desc_(nullptr),
op_desc_(nullptr),
type_(type),
id_(id) {}
explicit Node(VarDesc* var_desc) explicit Node(VarDesc* var_desc, int id = -1)
: name_(var_desc->Name()), : name_(var_desc->Name()),
var_desc_(new VarDesc(*var_desc)), var_desc_(new VarDesc(*var_desc)),
op_desc_(nullptr), op_desc_(nullptr),
type_(Type::kVariable) {} type_(Type::kVariable),
id_(id) {}
explicit Node(OpDesc* op_desc) explicit Node(OpDesc* op_desc, int id = -1)
: name_(op_desc->Type()), : name_(op_desc->Type()),
var_desc_(nullptr), var_desc_(nullptr),
op_desc_(new OpDesc(*op_desc, op_desc->Block())), op_desc_(new OpDesc(*op_desc, op_desc->Block())),
type_(Type::kOperation) {} type_(Type::kOperation),
id_(id) {}
Type NodeType() const { return type_; } Type NodeType() const { return type_; }
...@@ -58,6 +64,8 @@ class Node { ...@@ -58,6 +64,8 @@ class Node {
return op_desc_.get(); return op_desc_.get();
} }
int id() const { return id_; }
bool IsOp() const { return type_ == Type::kOperation; } bool IsOp() const { return type_ == Type::kOperation; }
bool IsVar() const { return type_ == Type::kVariable; } bool IsVar() const { return type_ == Type::kVariable; }
...@@ -69,6 +77,7 @@ class Node { ...@@ -69,6 +77,7 @@ class Node {
std::unique_ptr<VarDesc> var_desc_; std::unique_ptr<VarDesc> var_desc_;
std::unique_ptr<OpDesc> op_desc_; std::unique_ptr<OpDesc> op_desc_;
Type type_; Type type_;
int id_;
private: private:
DISABLE_COPY_AND_ASSIGN(Node); DISABLE_COPY_AND_ASSIGN(Node);
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/seq_concat_fc_fuse_pass.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle {
namespace framework {
namespace ir {
struct FuseExpr {};
// sequence expand, concat fuse pattern, return concat's output
PDNode* BuildSeqExpandConcatPattern(PDPattern* pattern) {
// The following operators will be fused:
// concat
// sequence_expand
// sequence_expand
// The following variables will be treat as inputs:
// concat mid input, 0th input for fused op
// sequence_expand input, 1th input for fused op
// sequence_expand input, 2th input for fused op
// The following variables will be treat as outputs:
// concat output
// So the following variables will be removed:
// sequence-expand output
// sequence-expand output
// Three operators
auto* sequence_expand0 = pattern->NewNode(
[](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "sequence_expand";
},
"sequence_expand0");
auto* sequence_expand1 = pattern->NewNode(
[](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "sequence_expand";
},
"sequence_expand1");
auto* concat = pattern->NewNode(
[](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "concat" && // basic check
x->Op()->Input("X").size() == 3; // Special case
},
"concat");
auto* sequence_expand0_in = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && VarLinksToOp(x, "sequence_expand");
},
"sequence_expand0_in");
auto* sequence_expand1_in = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && VarLinksToOp(x, "sequence_expand");
},
"sequence_expand1_in");
// The variables
auto* sequence_expand0_out = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() &&
VarLinksFromOp(x, "sequence_expand") && // basic check
VarLinksToOp(x, "concat") && // is concat's input
IsNthInput(x, x->outputs[0], "X", 1); // X[0]
},
"sequence_expand0_out");
auto* sequence_expand1_out = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() &&
VarLinksFromOp(x, "sequence_expand") && // basic check
VarLinksToOp(x, "concat") && // is concat's input
IsNthInput(x, x->outputs[0], "X", 2); // x[2]
},
"sequence_expand1_out");
auto* concat_in0 = pattern->NewNode(
[](Node* x) { return x && x->IsVar() && VarLinksToOp(x, "concat"); },
"concat_in0");
auto* concat_out = pattern->NewNode(
[](Node* x) { return x && x->IsVar() && VarLinksFromOp(x, "concat"); },
"concat_out");
// Links
sequence_expand0->LinksFrom({sequence_expand0_in})
.LinksTo({sequence_expand0_out});
sequence_expand1->LinksFrom({sequence_expand1_in})
.LinksTo({sequence_expand1_out});
concat->LinksFrom({sequence_expand0_out, sequence_expand1_out, concat_in0})
.LinksTo({concat_out});
return concat_out;
}
PDNode* BuildFCPattern(PDPattern* pattern, PDNode* fc_x) {
PDNode* fc_w = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && // basic
VarLinksToOp(x, "mul") && // link
x->Var()->Proto()->persistable(); // is a parameter
},
"fc_w");
PDNode* mul_out = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && // basic
VarLinksFromOp(x, "mul") && // link
VarLinksToOp(x, "elementwise_add") && //
!x->Var()->Proto()->persistable(); // is a parameter
},
"mul_out");
PDNode* fc_mul = pattern->NewNode(
[](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "mul"; // basic
},
"fc_mul");
PDNode* fc_bias = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && // basic
VarLinksToOp(x, "elementwise_add") && // link
x->Var()->Proto()->persistable(); // is a parameter
},
"fc_bias");
PDNode* elementwise_add = pattern->NewNode(
[](Node* x) {
return x && x->IsOp() && x->Op()->Type() == "elementwise_add";
},
"elementwise_add");
PDNode* add_out = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && // basic
VarLinksFromOp(x, "elementwise_add") && // link
!x->Var()->Proto()->persistable(); // is a parameter
},
"add_out");
std::set<std::string> acts({"sigmoid", "tanh", "relu", "identity"});
PDNode* act = pattern->NewNode(
[=](Node* x) {
return x && x->IsOp() && acts.count(x->Op()->Type());
},
"act");
PDNode* fc_out = pattern->NewNode(
[](Node* x) {
return x && x->IsVar() && // basic
!x->Var()->Proto()->persistable(); // is a parameter
},
"fc_out");
fc_mul->LinksFrom({fc_w, fc_x}).LinksTo({mul_out});
elementwise_add->LinksFrom({mul_out, fc_bias}).LinksTo({add_out});
act->LinksFrom({add_out}).LinksTo({fc_out});
return fc_out;
}
std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(graph.get());
GraphPatternDetector detector;
auto* pattern = detector.mutable_pattern();
auto* concat_out = BuildSeqExpandConcatPattern(pattern);
BuildFCPattern(pattern, concat_out);
#define GET_NODE(id, pattern) \
PADDLE_ENFORCE(subgraph.count(pattern.RetriveNode(#id)), \
"pattern has no Node called %s", #id); \
auto* id = subgraph.at(pattern.RetriveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "get one concat pattern";
// fc
GET_NODE(fc_w, detector.pattern());
GET_NODE(fc_bias, detector.pattern());
GET_NODE(act, detector.pattern());
GET_NODE(fc_out, detector.pattern());
// concat
GET_NODE(concat_in0, detector.pattern());
GET_NODE(sequence_expand0_in, detector.pattern());
GET_NODE(sequence_expand1_in, detector.pattern());
OpDesc op_desc;
op_desc.SetType("fusion_seqexpand_concat_fc");
op_desc.SetInput("X", {concat_in0->Name(), sequence_expand0_in->Name(),
sequence_expand1_in->Name()});
op_desc.SetInput("FCWeight", {fc_w->Name()});
op_desc.SetInput("FCBias", {fc_bias->Name()});
const std::string fc_out_tmp = fc_out->Name() + ".tmp";
param_scope()->Var(fc_out_tmp)->GetMutable<framework::LoDTensor>();
op_desc.SetOutput("FCOut", {fc_out_tmp});
op_desc.SetOutput("Out", {fc_out->Name()});
op_desc.SetAttr("fc_activation", act->Op()->Type());
auto* op_node = graph->CreateOpNode(&op_desc);
// Add links
#define NODE_LINKS(a, b) \
a->outputs.push_back(b); \
b->inputs.push_back(a);
NODE_LINKS(fc_w, op_node);
NODE_LINKS(fc_bias, op_node);
NODE_LINKS(concat_in0, op_node);
NODE_LINKS(sequence_expand0_in, op_node);
NODE_LINKS(sequence_expand1_in, op_node);
NODE_LINKS(op_node, fc_out);
// Clean nodes.
std::unordered_set<const Node*> marked_nodes;
for (auto& item : subgraph) {
marked_nodes.insert(item.second);
}
marked_nodes.erase(fc_w);
marked_nodes.erase(fc_bias);
marked_nodes.erase(concat_in0);
marked_nodes.erase(sequence_expand0_in);
marked_nodes.erase(sequence_expand1_in);
marked_nodes.erase(fc_out);
GraphSafeRemoveNodes(graph, marked_nodes);
});
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(seq_concat_fc_fuse_pass,
paddle::framework::ir::SeqConcatFcFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class SeqConcatFcFusePass : public FusePassBase {
public:
virtual ~SeqConcatFcFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
cc_library(ir_pass_manager SRCS ir_pass_manager.cc DEPS graph pass) cc_library(ir_pass_manager SRCS ir_pass_manager.cc DEPS graph pass)
cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc set(analysis_deps
framework_proto proto_desc ir_pass_manager graph pass paddle_fluid_api executor)
cc_library(analysis SRCS pass_manager.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc
analyzer.cc analyzer.cc
helper.cc helper.cc
# passes # passes
...@@ -10,11 +13,11 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph ...@@ -10,11 +13,11 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph
tensorrt_subgraph_node_mark_pass.cc tensorrt_subgraph_node_mark_pass.cc
fluid_to_ir_pass.cc fluid_to_ir_pass.cc
model_store_pass.cc model_store_pass.cc
DEPS framework_proto proto_desc ir_pass_manager graph pass) DEPS ${analysis_deps})
cc_test(test_node SRCS node_tester.cc DEPS analysis) cc_test(test_node SRCS node_tester.cc DEPS analysis)
cc_test(test_dot SRCS dot_tester.cc DEPS analysis) cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis) cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis paddle_fluid)
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests) set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests)
...@@ -31,7 +34,7 @@ function (inference_analysis_test TARGET) ...@@ -31,7 +34,7 @@ function (inference_analysis_test TARGET)
endif() endif()
cc_test(${TARGET} cc_test(${TARGET}
SRCS "${analysis_test_SRCS}" SRCS "${analysis_test_SRCS}"
DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detecter pass ${analysis_test_EXTRA_DEPS} DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detector pass ${analysis_test_EXTRA_DEPS}
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt}) ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt})
set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec) set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec)
endif(WITH_TESTING) endif(WITH_TESTING)
...@@ -58,20 +61,25 @@ endif() ...@@ -58,20 +61,25 @@ endif()
inference_analysis_test(test_analyzer SRCS analyzer_tester.cc inference_analysis_test(test_analyzer SRCS analyzer_tester.cc
EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis
analysis_predictor
# ir # ir
fc_fuse_pass fc_fuse_pass
fc_lstm_fuse_pass
seq_concat_fc_fuse_pass
graph_viz_pass graph_viz_pass
infer_clean_graph_pass infer_clean_graph_pass
graph_pattern_detecter graph_pattern_detector
infer_clean_graph_pass infer_clean_graph_pass
attention_lstm_fuse_pass
paddle_inference_api
pass pass
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model
--infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model --infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model
--infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt) --infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt)
inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc) inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc) inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc EXTRA_DEPS paddle_inference_api)
inference_analysis_test(test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc) inference_analysis_test(test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc EXTRA_DEPS paddle_fluid)
inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc) inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc)
inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc) inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc)
inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc) inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc)
......
...@@ -102,6 +102,19 @@ class DfgPassManagerImpl final : public DfgPassManager { ...@@ -102,6 +102,19 @@ class DfgPassManagerImpl final : public DfgPassManager {
Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); } Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); }
void Analyzer::Run(Argument* argument) { void Analyzer::Run(Argument* argument) {
// Ungly support fluid-to-ir-pass
argument->Set(kFluidToIrPassesAttr,
new std::vector<std::string>({
// Manual update the passes here.
"graph_viz_pass", //
"infer_clean_graph_pass", "graph_viz_pass", //
"attention_lstm_fuse_pass", "graph_viz_pass", //
"fc_lstm_fuse_pass", "graph_viz_pass", //
"seq_concat_fc_fuse_pass", "graph_viz_pass", //
"fc_fuse_pass", "graph_viz_pass" //
}));
for (auto& x : data_) { for (auto& x : data_) {
PADDLE_ENFORCE(x->Initialize(argument)); PADDLE_ENFORCE(x->Initialize(argument));
x->RunAll(); x->RunAll();
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/fluid/inference/analysis/ut_helper.h" #include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN"); DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN");
DEFINE_string(infer_ditu_rnn_data, "", "data path for ditu RNN"); DEFINE_string(infer_ditu_rnn_data, "", "data path for ditu RNN");
...@@ -264,39 +265,24 @@ void TestDituRNNPrediction(const std::string &model_path, ...@@ -264,39 +265,24 @@ void TestDituRNNPrediction(const std::string &model_path,
const std::string &data_path, int batch_size, const std::string &data_path, int batch_size,
bool use_analysis, bool activate_ir, bool use_analysis, bool activate_ir,
int num_times = 1) { int num_times = 1) {
FLAGS_IA_enable_ir = activate_ir;
FLAGS_IA_enable_tensorrt_subgraph_engine = false;
FLAGS_IA_output_storage_path = "./analysis.out";
std::string model_out;
if (use_analysis) {
Argument argument(model_path);
argument.model_output_store_path.reset(new std::string("./analysis.out"));
Analyzer analyzer;
analyzer.Run(&argument);
// Should get the transformed model stored to ./analysis.out
model_out = "./analysis.out";
ASSERT_TRUE(PathExists(model_out));
} else {
model_out = FLAGS_infer_ditu_rnn_model;
}
NativeConfig config; NativeConfig config;
config.prog_file = model_out + "/__model__"; config.prog_file = FLAGS_infer_ditu_rnn_model + "/__model__";
config.param_file = model_out + "/param"; config.param_file = FLAGS_infer_ditu_rnn_model + "/param";
config.use_gpu = false; config.use_gpu = false;
config.device = 0; config.device = 0;
config.specify_input_name = true; config.specify_input_name = true;
auto predictor = auto base_predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config); CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);
auto predictor =
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis>(config);
std::vector<PaddleTensor> input_slots; std::vector<PaddleTensor> input_slots;
DataRecord data(data_path, batch_size); DataRecord data(data_path, batch_size);
// Prepare inputs. // Prepare inputs.
PrepareInputs(&input_slots, &data, batch_size); PrepareInputs(&input_slots, &data, batch_size);
std::vector<PaddleTensor> outputs; std::vector<PaddleTensor> outputs, base_outputs;
base_predictor->Run(input_slots, &base_outputs);
Timer timer; Timer timer;
timer.tic(); timer.tic();
...@@ -308,37 +294,25 @@ void TestDituRNNPrediction(const std::string &model_path, ...@@ -308,37 +294,25 @@ void TestDituRNNPrediction(const std::string &model_path,
<< ", latency: " << timer.toc() / num_times << "ms"; << ", latency: " << timer.toc() / num_times << "ms";
LOG(INFO) << "====================================="; LOG(INFO) << "=====================================";
for (auto &out : outputs) { PADDLE_ENFORCE_GT(outputs.size(), 0);
PADDLE_ENFORCE_EQ(outputs.size(), base_outputs.size());
for (size_t i = 0; i < outputs.size(); i++) {
auto &out = outputs[i];
auto &base_out = base_outputs[i];
size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1, size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
[](int a, int b) { return a * b; }); [](int a, int b) { return a * b; });
size_t size1 = std::accumulate(base_out.shape.begin(), base_out.shape.end(),
1, [](int a, int b) { return a * b; });
PADDLE_ENFORCE_EQ(size, size1);
PADDLE_ENFORCE_GT(size, 0);
float *data = static_cast<float *>(out.data.data()); float *data = static_cast<float *>(out.data.data());
for (size_t i = 0; float *base_data = static_cast<float *>(base_out.data.data());
i < std::min(sizeof(ditu_rnn_target_data) / sizeof(float), size); for (size_t i = 0; i < size; i++) {
i++) { EXPECT_NEAR(data[i], base_data[i], 1e-3);
EXPECT_NEAR(data[i], ditu_rnn_target_data[i], 1e-3);
} }
} }
} }
// Turn on the IR pass supportion, run a real inference and check the result.
TEST(Analyzer, SupportIRPass) {
FLAGS_IA_enable_ir = true;
FLAGS_IA_enable_tensorrt_subgraph_engine = false;
FLAGS_IA_output_storage_path = "./analysis.out";
Argument argument(FLAGS_inference_model_dir);
argument.model_output_store_path.reset(new std::string("./analysis.out"));
Analyzer analyzer;
analyzer.Run(&argument);
// Should get the transformed model stored to ./analysis.out
ASSERT_TRUE(PathExists("./analysis.out"));
// Inference from this path.
TestWord2vecPrediction("./analysis.out");
}
// Directly infer with the original model. // Directly infer with the original model.
TEST(Analyzer, DituRNN_without_analysis) { TEST(Analyzer, DituRNN_without_analysis) {
TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data, TestDituRNNPrediction(FLAGS_infer_ditu_rnn_model, FLAGS_infer_ditu_rnn_data,
...@@ -365,5 +339,8 @@ TEST(Analyzer, DituRNN_with_analysis_with_IR) { ...@@ -365,5 +339,8 @@ TEST(Analyzer, DituRNN_with_analysis_with_IR) {
} // namespace paddle } // namespace paddle
USE_PASS(fc_fuse_pass); USE_PASS(fc_fuse_pass);
USE_PASS(seq_concat_fc_fuse_pass);
USE_PASS(fc_lstm_fuse_pass);
USE_PASS(graph_viz_pass); USE_PASS(graph_viz_pass);
USE_PASS(infer_clean_graph_pass); USE_PASS(infer_clean_graph_pass);
USE_PASS(attention_lstm_fuse_pass);
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
#include <string> #include <string>
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -58,6 +59,46 @@ struct Argument { ...@@ -58,6 +59,46 @@ struct Argument {
// The output storage path of ModelStorePass. // The output storage path of ModelStorePass.
std::unique_ptr<std::string> model_output_store_path; std::unique_ptr<std::string> model_output_store_path;
// Support for any other attributes.
template <typename T>
void Set(const std::string& key, T* data) {
PADDLE_ENFORCE_NOT_NULL(data);
PADDLE_ENFORCE(!attrs_.count(key), "duplicate attr called %s", key);
attrs_[key] = data;
attr_deleters_[key] = [data, key, this]() {
VLOG(3) << "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
VLOG(3) << "argument delete attr: " << key;
delete data;
};
}
bool Has(const std::string& name) const { return attrs_.count(name); }
template <typename T>
T* Release(const std::string& key) {
PADDLE_ENFORCE(attrs_.count(key));
auto* res = boost::any_cast<T*>(attrs_.at(key));
attrs_.erase(key);
attr_deleters_.erase(key);
return res;
}
template <typename T>
T& Get(const std::string& key) {
PADDLE_ENFORCE(Has(key));
return *boost::any_cast<T*>(attrs_.at(key));
}
~Argument() {
for (auto& item : attr_deleters_) {
item.second();
}
}
private:
std::unordered_map<std::string, boost::any> attrs_;
std::unordered_map<std::string, std::function<void()>> attr_deleters_;
}; };
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0) #define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/proto_desc.h" #include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h" #include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/io.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -65,6 +66,10 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) { ...@@ -65,6 +66,10 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
} }
} }
if (argument_->Has("param_scope")) {
LOG(WARNING) << "parameter changes in the scope takes effect";
}
PADDLE_ENFORCE(argument_->transformed_program_desc.get()); PADDLE_ENFORCE(argument_->transformed_program_desc.get());
} }
......
...@@ -29,13 +29,13 @@ namespace paddle { ...@@ -29,13 +29,13 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
static size_t dot_node_counter{0};
/* /*
* A Dot template that helps to build a DOT graph definition. * A Dot template that helps to build a DOT graph definition.
*/ */
class Dot { class Dot {
public: public:
static size_t counter;
struct Attr { struct Attr {
std::string key; std::string key;
std::string value; std::string value;
...@@ -57,7 +57,7 @@ class Dot { ...@@ -57,7 +57,7 @@ class Dot {
Node(const std::string& name, const std::vector<Attr>& attrs) Node(const std::string& name, const std::vector<Attr>& attrs)
: name(name), : name(name),
attrs(attrs), attrs(attrs),
id_("node_" + std::to_string(Dot::counter++)) {} id_("node_" + std::to_string(dot_node_counter++)) {}
std::string id() const { return id_; } std::string id() const { return id_; }
...@@ -65,6 +65,10 @@ class Dot { ...@@ -65,6 +65,10 @@ class Dot {
std::stringstream ss; std::stringstream ss;
CHECK(!name.empty()); CHECK(!name.empty());
ss << id_; ss << id_;
if (attrs.empty()) {
ss << "[label=" << '"' << name << '"' << "]";
return ss.str();
}
for (size_t i = 0; i < attrs.size(); i++) { for (size_t i = 0; i < attrs.size(); i++) {
if (i == 0) { if (i == 0) {
ss << "[label=" << '"' << name << '"' << " "; ss << "[label=" << '"' << name << '"' << " ";
...@@ -108,9 +112,11 @@ class Dot { ...@@ -108,9 +112,11 @@ class Dot {
explicit Dot(const std::vector<Attr>& attrs) : attrs_(attrs) {} explicit Dot(const std::vector<Attr>& attrs) : attrs_(attrs) {}
void AddNode(const std::string& name, const std::vector<Attr>& attrs) { void AddNode(const std::string& id, const std::vector<Attr>& attrs,
CHECK(!nodes_.count(name)) << "duplicate Node '" << name << "'"; std::string label = "") {
nodes_.emplace(name, Node{name, attrs}); CHECK(!nodes_.count(id)) << "duplicate Node '" << id << "'";
if (label.empty()) label = id;
nodes_.emplace(id, Node{label, attrs});
} }
void AddEdge(const std::string& source, const std::string& target, void AddEdge(const std::string& source, const std::string& target,
......
...@@ -13,3 +13,47 @@ ...@@ -13,3 +13,47 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h" #include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace inference {
namespace analysis {
void FluidToIrPass::EnableParamModify(const std::string &model_dir,
const std::string &prog_file,
const std::string &param_file) {
PADDLE_ENFORCE(argument_);
argument_->Set("param_scope", new framework::Scope);
// Load parameters.
VLOG(3) << "Loading parameters from " << model_dir;
LoadParams(&argument_->Get<framework::Scope>("param_scope"), model_dir,
prog_file, param_file);
}
bool FluidToIrPass::LoadParams(framework::Scope *scope, const std::string &dir,
const std::string &prog_file,
const std::string &param_file) {
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
framework::Executor executor(place);
PADDLE_ENFORCE(argument_->origin_program_desc.get());
framework::ProgramDesc program(*argument_->origin_program_desc);
if ((!prog_file.empty()) && (!param_file.empty())) {
LOG(INFO) << "load single model file from " << prog_file;
Load(&executor, scope, prog_file, param_file);
} else if (!dir.empty()) {
LOG(INFO) << "load from dir " << dir;
Load(&executor, scope, dir);
} else {
LOG(ERROR) << "failed to load parameters";
return false;
}
return true;
}
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -21,12 +21,17 @@ namespace paddle { ...@@ -21,12 +21,17 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
static const char kFluidToIrPassesAttr[] = "__fluid_to_ir_passes__";
class FluidToIrPass final : public DataFlowGraphPass { class FluidToIrPass final : public DataFlowGraphPass {
public: public:
FluidToIrPass() = default; FluidToIrPass() = default;
bool Initialize(Argument *argument) override { bool Initialize(Argument *argument) override {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument); ANALYSIS_ARGUMENT_CHECK_FIELD(argument);
PADDLE_ENFORCE(argument->Has(kFluidToIrPassesAttr),
"argument need the attr %s", kFluidToIrPassesAttr);
argument_ = argument;
if (argument->origin_program_desc) { if (argument->origin_program_desc) {
LOG(WARNING) << "argument's origin_program_desc is already set, might " LOG(WARNING) << "argument's origin_program_desc is already set, might "
"duplicate called"; "duplicate called";
...@@ -46,12 +51,21 @@ class FluidToIrPass final : public DataFlowGraphPass { ...@@ -46,12 +51,21 @@ class FluidToIrPass final : public DataFlowGraphPass {
if (!argument->main_dfg) { if (!argument->main_dfg) {
argument->main_dfg.reset(new DataFlowGraph); argument->main_dfg.reset(new DataFlowGraph);
} }
// Persist the ProgramDesc in graph's attribute. The IR graph just keep the argument->Set("ir_program_desc", new framework::ProgramDesc(program));
// address, will segfault if the original ProgramDesc destroys.
auto &ir_program_p = argument->main_dfg->Attr("ir_program_desc").Pointer(); LOG(INFO) << "Loading parameters";
ir_program_p = new framework::ProgramDesc(program); // Load parameters to argument if needed.
if (argument->fluid_model_dir || (argument->fluid_model_program_path &&
argument->fluid_model_param_path)) {
#define SAFE_GET(ATTR) std::string ATTR = argument->ATTR ? *argument->ATTR : "";
SAFE_GET(fluid_model_dir);
SAFE_GET(fluid_model_program_path);
SAFE_GET(fluid_model_param_path);
#undef SAFE_GET
EnableParamModify(fluid_model_dir, fluid_model_program_path,
fluid_model_param_path);
}
argument_ = argument;
return true; return true;
} }
...@@ -59,20 +73,36 @@ class FluidToIrPass final : public DataFlowGraphPass { ...@@ -59,20 +73,36 @@ class FluidToIrPass final : public DataFlowGraphPass {
void Run(DataFlowGraph *graph) override { void Run(DataFlowGraph *graph) override {
// Call all the IR Passes // Call all the IR Passes
IRPassManager ir_passes(*static_cast<framework::ProgramDesc *>( IRPassManager ir_passes(
argument_->main_dfg->Attr("ir_program_desc").Pointer())); argument_->Get<framework::ProgramDesc>("ir_program_desc"), nullptr);
ir_passes.Apply(std::vector<std::string>( // Pass the scope from analysis to IR if needed.
{// Manual update the passes here. if (argument_->Has("param_scope")) {
"graph_viz_pass", "infer_clean_graph_pass", "graph_viz_pass", // Here the address is passed, attention that IR doesn't own the scope, so
"fc_fuse_pass", "graph_viz_pass"})); // the real scope in analysis should live during the IR phase.
ir_passes.graph().Set(
"param_scope", new framework::Scope *(
&argument_->Get<framework::Scope>("param_scope")));
}
const auto &ir_passes_to_apply =
argument_->Get<std::vector<std::string>>(kFluidToIrPassesAttr);
ir_passes.Apply(ir_passes_to_apply);
PADDLE_ENFORCE(argument_->main_dfg.get()); PADDLE_ENFORCE(argument_->main_dfg.get());
argument_->main_dfg->Build(ir_passes.graph()); argument_->main_dfg->Build(ir_passes.graph());
// PADDLE_ENFORCE(argument_->main_dfg->IsFullyConnected());
} }
void EnableParamModify(const std::string &model_dir,
const std::string &prog_file,
const std::string &param_file);
std::string repr() const override { return "fluid-to-ir-pass"; } std::string repr() const override { return "fluid-to-ir-pass"; }
private:
// Load parameters from a single file or from a directory.
bool LoadParams(framework::Scope *scope, const std::string &dir,
const std::string &prog_file, const std::string &param_file);
private: private:
Argument *argument_{nullptr}; Argument *argument_{nullptr};
}; };
......
...@@ -24,6 +24,8 @@ namespace analysis { ...@@ -24,6 +24,8 @@ namespace analysis {
TEST(FluidToIrPass, Test) { TEST(FluidToIrPass, Test) {
FluidToIrPass pass; FluidToIrPass pass;
Argument argument(FLAGS_inference_model_dir); Argument argument(FLAGS_inference_model_dir);
argument.Set(kFluidToIrPassesAttr,
new std::vector<std::string>({"infer_clean_graph_pass"}));
pass.Initialize(&argument); pass.Initialize(&argument);
pass.Run(argument.main_dfg.get()); pass.Run(argument.main_dfg.get());
} }
...@@ -32,6 +34,9 @@ TEST(FluidToIrPass, Test) { ...@@ -32,6 +34,9 @@ TEST(FluidToIrPass, Test) {
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
USE_PASS(fc_fuse_pass);
USE_PASS(graph_viz_pass); USE_PASS(graph_viz_pass);
USE_PASS(infer_clean_graph_pass); USE_PASS(infer_clean_graph_pass);
USE_PASS(attention_lstm_fuse_pass);
USE_PASS(fc_lstm_fuse_pass);
USE_PASS(seq_concat_fc_fuse_pass);
USE_PASS(fc_fuse_pass);
...@@ -14,20 +14,24 @@ ...@@ -14,20 +14,24 @@
#include "paddle/fluid/inference/analysis/ir_pass_manager.h" #include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include <string> #include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
IRPassManager::IRPassManager(const ProgramDesc& program) { IRPassManager::IRPassManager(const ProgramDesc &program,
framework::Scope *scope)
: program_(program) {
graph_.reset(new framework::ir::Graph(program)); graph_.reset(new framework::ir::Graph(program));
if (scope) graph_->Set("param_scope", new framework::Scope *(scope));
} }
void IRPassManager::Apply(const std::vector<std::string>& passes) { void IRPassManager::Apply(const std::vector<std::string> &passes) {
graph_->Set("graph_viz_path", new std::string("./1.dot"));
// Apply all the passes // Apply all the passes
std::string pre_pass; std::string pre_pass;
for (const std::string& pass_name : passes) { for (const std::string &pass_name : passes) {
LOG(WARNING) << "Running IR pass [" << pass_name << "]"; LOG(WARNING) << "Running IR pass [" << pass_name << "]";
auto pass = framework::ir::PassRegistry::Instance().Get(pass_name); auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
if (pass_name == "graph_viz_pass") { if (pass_name == "graph_viz_pass") {
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -31,14 +32,15 @@ using framework::ProgramDesc; ...@@ -31,14 +32,15 @@ using framework::ProgramDesc;
class IRPassManager final { class IRPassManager final {
public: public:
IRPassManager(const ProgramDesc& program); IRPassManager(const ProgramDesc &program, framework::Scope *scope);
void Apply(const std::vector<std::string>& passes); void Apply(const std::vector<std::string> &passes);
framework::ir::Graph& graph() const { return *graph_; } framework::ir::Graph &graph() const { return *graph_; }
private: private:
std::unique_ptr<framework::ir::Graph> graph_; std::unique_ptr<framework::ir::Graph> graph_;
ProgramDesc program_;
}; };
} // namespace analysis } // namespace analysis
......
...@@ -33,9 +33,9 @@ bool PassManager::Initialize(Argument* argument) { ...@@ -33,9 +33,9 @@ bool PassManager::Initialize(Argument* argument) {
void DfgPassManager::RunAll() { void DfgPassManager::RunAll() {
PADDLE_ENFORCE(argument_); PADDLE_ENFORCE(argument_);
LOG(INFO) << "Total " << data_.size() << " passes"; LOG(INFO) << "Total " << data_.size() << " Analysys passes";
for (auto& pass : data_) { for (auto& pass : data_) {
LOG(WARNING) << "Running pass [" << pass->repr() << "]"; LOG(WARNING) << "Running Analysis pass [" << pass->repr() << "]";
pass->Run(argument_->main_dfg.get()); pass->Run(argument_->main_dfg.get());
} }
} }
......
...@@ -46,7 +46,8 @@ function(inference_api_test TARGET_NAME) ...@@ -46,7 +46,8 @@ function(inference_api_test TARGET_NAME)
endif(WITH_TESTING) endif(WITH_TESTING)
endfunction(inference_api_test) endfunction(inference_api_test)
cc_library(paddle_inference_api SRCS api.cc api_impl.cc DEPS lod_tensor) cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor)
cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api)
cc_test(test_paddle_inference_api cc_test(test_paddle_inference_api
SRCS api_tester.cc SRCS api_tester.cc
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/utils/singleton.h"
namespace paddle {
using inference::analysis::Argument;
using inference::Singleton;
using inference::analysis::Analyzer;
using framework::proto::ProgramDesc;
/* This predictor is based on the original native predictor with IR and Analysis
* support. It will optimize IR and Parameters in the runtime.
* TODO(Superjomn) Replace the Navive predictor?
*/
class AnalysisPredictor : public NativePaddlePredictor {
public:
explicit AnalysisPredictor(const NativeConfig& config)
: NativePaddlePredictor(config), config_(config) {}
bool Init(const std::shared_ptr<framework::Scope>& parent_scope) {
VLOG(3) << "Predictor::init()";
if (config_.use_gpu) {
place_ = paddle::platform::CUDAPlace(config_.device);
} else {
place_ = paddle::platform::CPUPlace();
}
PADDLE_ENFORCE(!parent_scope);
if (parent_scope) {
scope_ = parent_scope;
sub_scope_ = &(parent_scope->NewScope());
} else {
paddle::framework::InitDevices(false);
scope_.reset(new paddle::framework::Scope());
}
executor_.reset(new paddle::framework::Executor(place_));
// Initialize the inference program
if (!config_.model_dir.empty()) {
// Parameters are saved in separate files sited in
// the specified `dirname`.
inference_program_ = paddle::inference::Load(
executor_.get(), scope_.get(), config_.model_dir);
} else if (!config_.prog_file.empty() && !config_.param_file.empty()) {
// All parameters are saved in a single file.
// The file names should be consistent with that used
// in Python API `fluid.io.save_inference_model`.
inference_program_ = paddle::inference::Load(
executor_.get(), scope_.get(), config_.prog_file, config_.param_file);
} else {
LOG(ERROR) << "fail to load inference model.";
return false;
}
OptimizeInferenceProgram();
ctx_ = executor_->Prepare(*inference_program_, 0);
VLOG(5) << "to create variables";
PADDLE_ENFORCE(scope_.get());
executor_->CreateVariables(*inference_program_,
sub_scope_ ? sub_scope_ : scope_.get(), 0);
// Get the feed_target_names and fetch_target_names
feed_target_names_ = inference_program_->GetFeedTargetNames();
fetch_target_names_ = inference_program_->GetFetchTargetNames();
return true;
}
bool Run(const std::vector<PaddleTensor>& inputs,
std::vector<PaddleTensor>* output_data,
int batch_size = -1) override {
return NativePaddlePredictor::Run(inputs, output_data, batch_size);
}
void OptimizeInferenceProgram() {
LOG(INFO) << "optimize begin";
FLAGS_IA_enable_ir = true;
FLAGS_IA_enable_tensorrt_subgraph_engine = false;
FLAGS_IA_output_storage_path = ""; // Don't output the model.
// Analyze inference_program
Argument argument;
if (!config_.model_dir.empty()) {
argument.fluid_model_dir.reset(new std::string(config_.model_dir));
} else {
PADDLE_ENFORCE(
!config_.param_file.empty(),
"Either model_dir or (param_file, prog_file) should be set.");
PADDLE_ENFORCE(!config_.prog_file.empty());
argument.fluid_model_program_path.reset(
new std::string(config_.prog_file));
argument.fluid_model_param_path.reset(
new std::string(config_.param_file));
}
argument.origin_program_desc.reset(
new ProgramDesc(*inference_program_->Proto()));
Singleton<Analyzer>::Global().Run(&argument);
CHECK(argument.transformed_program_desc);
VLOG(5) << "to prepare executor";
// LOG(INFO) << "transformed_parogram_desc " <<
// argument.transformed_program_desc->DebugString();
inference_program_.reset(
new framework::ProgramDesc(*argument.transformed_program_desc));
PADDLE_ENFORCE(argument.Has("param_scope"));
// Update scope.
scope_.reset(argument.Release<framework::Scope>("param_scope"));
LOG(INFO) << "optimize end ==";
}
private:
NativeConfig config_;
};
template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
NativeConfig, PaddleEngineKind::kAnalysis>(const NativeConfig& config) {
VLOG(3) << "create NativePredictor";
if (config.use_gpu) {
// 1. GPU memeroy
PADDLE_ENFORCE_GT(
config.fraction_of_gpu_memory, 0.f,
"fraction_of_gpu_memory in the config should be set to range (0., 1.]");
PADDLE_ENFORCE_GE(config.device, 0, "Invalid device id %d", config.device);
std::vector<std::string> flags;
if (config.fraction_of_gpu_memory >= 0.0f ||
config.fraction_of_gpu_memory <= 0.95f) {
flags.push_back("dummpy");
std::string flag = "--fraction_of_gpu_memory_to_use=" +
std::to_string(config.fraction_of_gpu_memory);
flags.push_back(flag);
VLOG(3) << "set flag: " << flag;
framework::InitGflags(flags);
}
}
std::unique_ptr<PaddlePredictor> predictor(new AnalysisPredictor(config));
if (!dynamic_cast<AnalysisPredictor*>(predictor.get())->Init(nullptr)) {
return nullptr;
}
return predictor;
}
} // namespace paddle
USE_PASS(fc_fuse_pass);
USE_PASS(graph_viz_pass);
USE_PASS(infer_clean_graph_pass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/api/helper.h"
namespace paddle {
namespace inference {
template <>
std::string to_string<std::vector<float>>(
const std::vector<std::vector<float>> &vec) {
std::stringstream ss;
for (const auto &piece : vec) {
ss << to_string(piece) << "\n";
}
return ss.str();
}
template <>
std::string to_string<std::vector<std::vector<float>>>(
const std::vector<std::vector<std::vector<float>>> &vec) {
std::stringstream ss;
for (const auto &line : vec) {
for (const auto &rcd : line) {
ss << to_string(rcd) << ";\t";
}
ss << '\n';
}
return ss.str();
}
} // namespace inference
} // namespace paddle
...@@ -44,7 +44,8 @@ class Timer { ...@@ -44,7 +44,8 @@ class Timer {
} }
}; };
void split(const std::string &str, char sep, std::vector<std::string> *pieces) { static void split(const std::string &str, char sep,
std::vector<std::string> *pieces) {
pieces->clear(); pieces->clear();
if (str.empty()) { if (str.empty()) {
return; return;
...@@ -60,7 +61,8 @@ void split(const std::string &str, char sep, std::vector<std::string> *pieces) { ...@@ -60,7 +61,8 @@ void split(const std::string &str, char sep, std::vector<std::string> *pieces) {
pieces->push_back(str.substr(pos)); pieces->push_back(str.substr(pos));
} }
} }
void split_to_float(const std::string &str, char sep, std::vector<float> *fs) { static void split_to_float(const std::string &str, char sep,
std::vector<float> *fs) {
std::vector<std::string> pieces; std::vector<std::string> pieces;
split(str, sep, &pieces); split(str, sep, &pieces);
std::transform(pieces.begin(), pieces.end(), std::back_inserter(*fs), std::transform(pieces.begin(), pieces.end(), std::back_inserter(*fs),
...@@ -76,27 +78,14 @@ std::string to_string(const std::vector<T> &vec) { ...@@ -76,27 +78,14 @@ std::string to_string(const std::vector<T> &vec) {
} }
template <> template <>
std::string to_string<std::vector<float>>( std::string to_string<std::vector<float>>(
const std::vector<std::vector<float>> &vec) { const std::vector<std::vector<float>> &vec);
std::stringstream ss;
for (const auto &piece : vec) {
ss << to_string(piece) << "\n";
}
return ss.str();
}
template <> template <>
std::string to_string<std::vector<std::vector<float>>>( std::string to_string<std::vector<std::vector<float>>>(
const std::vector<std::vector<std::vector<float>>> &vec) { const std::vector<std::vector<std::vector<float>>> &vec);
std::stringstream ss;
for (const auto &line : vec) {
for (const auto &rcd : line) {
ss << to_string(rcd) << ";\t";
}
ss << '\n';
}
return ss.str();
}
// clang-format off // clang-format off
void TensorAssignData(PaddleTensor *tensor, const std::vector<std::vector<float>> &data) { static void TensorAssignData(PaddleTensor *tensor, const std::vector<std::vector<float>> &data) {
// Assign buffer // Assign buffer
int dim = std::accumulate(tensor->shape.begin(), tensor->shape.end(), 1, [](int a, int b) { return a * b; }); int dim = std::accumulate(tensor->shape.begin(), tensor->shape.end(), 1, [](int a, int b) { return a * b; });
tensor->data.Resize(sizeof(float) * dim); tensor->data.Resize(sizeof(float) * dim);
......
...@@ -77,6 +77,7 @@ enum class PaddleEngineKind { ...@@ -77,6 +77,7 @@ enum class PaddleEngineKind {
kNative = 0, // Use the native Fluid facility. kNative = 0, // Use the native Fluid facility.
kAnakin, // Use Anakin for inference. kAnakin, // Use Anakin for inference.
kAutoMixedTensorRT, // Automatically mix Fluid with TensorRT. kAutoMixedTensorRT, // Automatically mix Fluid with TensorRT.
kAnalysis
// TODO(Superjomn) support following engines latter. // TODO(Superjomn) support following engines latter.
// kTensorRT, // Use TensorRT for inference. // kTensorRT, // Use TensorRT for inference.
// kAutoMixedAnakin, // Automatically mix Fluid with Anakin. // kAutoMixedAnakin, // Automatically mix Fluid with Anakin.
......
...@@ -143,5 +143,21 @@ std::unique_ptr<framework::ProgramDesc> Load( ...@@ -143,5 +143,21 @@ std::unique_ptr<framework::ProgramDesc> Load(
return main_program; return main_program;
} }
void SaveVars(const framework::Scope& scope,
const std::vector<std::string>& vars, const std::string& dirname,
bool predicate) {
framework::ProgramDesc prog;
auto* block = prog.MutableBlock(0);
auto* op = block->AppendOp();
op->SetType("save_combine");
op->SetInput("X", vars);
op->SetAttr("file_path", dirname + "/param");
op->CheckAttrs();
platform::CPUPlace place;
framework::Executor exe(place);
exe.Run(prog, const_cast<framework::Scope*>(&scope), 0, true, true);
}
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -41,5 +41,10 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor, ...@@ -41,5 +41,10 @@ std::unique_ptr<framework::ProgramDesc> Load(framework::Executor* executor,
const std::string& prog_filename, const std::string& prog_filename,
const std::string& param_filename); const std::string& param_filename);
// Save the variables from a scope to disk.
void SaveVars(const framework::Scope& scope,
const std::vector<std::string>& vars, const std::string& dirname,
bool predicate = true);
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -56,7 +56,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -56,7 +56,7 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
const int D = w_dims[1] / 4; const int D = w_dims[1] / 4;
PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2."); PADDLE_ENFORCE_EQ(w_dims.size(), 2, "Input(LSTMWeight)'s rank must be 2.");
PADDLE_ENFORCE_EQ(w_dims[0], D + M, PADDLE_ENFORCE_EQ(w_dims[0], D + M,
"LSTMWeight dims should be (%d + %d) * %d.", D + M, 4 * D); "LSTMWeight dims should be (%d + %d) * %d.", D, M, 4 * D);
auto b_dims = ctx->GetInputDim("LSTMBias"); auto b_dims = ctx->GetInputDim("LSTMBias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2."); PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
......
...@@ -49,9 +49,14 @@ void FusionSeqExpandConcatFCOp::InferShape( ...@@ -49,9 +49,14 @@ void FusionSeqExpandConcatFCOp::InferShape(
"FC height should be sum of all inputs width."); "FC height should be sum of all inputs width.");
if (ctx->HasInput("FCBias")) { if (ctx->HasInput("FCBias")) {
auto b_dims = ctx->GetInputDim("FCBias"); auto b_dims = ctx->GetInputDim("FCBias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(FCBias)'s rank must be 2."); PADDLE_ENFORCE(b_dims.size() == 1 || b_dims.size() == 2,
PADDLE_ENFORCE_EQ(b_dims[0], 1, "FCBias shapes must be 1 * %d.", D); "b_dims should be 1 or 2, get %d", b_dims.size());
PADDLE_ENFORCE_EQ(b_dims[1], D, "FCBias shapes must be 1 * %d.", D); if (b_dims.size() == 1) {
PADDLE_ENFORCE_EQ(b_dims[0], D, "FCBias shapes must be %d.", D);
} else {
PADDLE_ENFORCE_EQ(b_dims[0], 1, "FCBias shapes must be 1x%d.", D);
PADDLE_ENFORCE_EQ(b_dims[1], D, "FCBias shapes must be 1x%d.", D);
}
} }
ctx->SetOutputDim("Out", {ins_dims[0][0], D}); ctx->SetOutputDim("Out", {ins_dims[0][0], D});
......
...@@ -85,9 +85,6 @@ void InitDevices(bool init_p2p) { ...@@ -85,9 +85,6 @@ void InitDevices(bool init_p2p) {
} catch (const std::exception &exp) { } catch (const std::exception &exp) {
LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime."; LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime.";
} }
#else
LOG(WARNING)
<< "'CUDA' is not supported, Please re-compile with WITH_GPU option";
#endif #endif
InitDevices(init_p2p, devices); InitDevices(init_p2p, devices);
} }
...@@ -101,9 +98,6 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) { ...@@ -101,9 +98,6 @@ void InitDevices(bool init_p2p, const std::vector<int> devices) {
} catch (const std::exception &exp) { } catch (const std::exception &exp) {
LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime."; LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime.";
} }
#else
LOG(WARNING)
<< "'CUDA' is not supported, Please re-compile with WITH_GPU option";
#endif #endif
for (size_t i = 0; i < devices.size(); ++i) { for (size_t i = 0; i < devices.size(); ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册