未验证 提交 896a37b6 编写于 作者: Y Yan Chunwei 提交者: GitHub

fea/link ir to inference analysis and fc fuse support (#12789)

* link IR graph to analysis graph

* add clean code and update

* add infer_clean_pass

* add ir_pass_manager

* support fc fuse executation

* fix ir circle
上级 e23ddf6a
...@@ -28,7 +28,7 @@ def get_symbol(num_classes=10, **kwargs): ...@@ -28,7 +28,7 @@ def get_symbol(num_classes=10, **kwargs):
Varible here is actually a Symbol. Every basic Symbol will correspond to one Node, and every Node has its own NodeAttr. There is a op field in NodeAttr class, when a Symbol represents Variable(often input data), the op field is null. Varible here is actually a Symbol. Every basic Symbol will correspond to one Node, and every Node has its own AnyAttr. There is a op field in AnyAttr class, when a Symbol represents Variable(often input data), the op field is null.
Symbol contains a data member, std::vector<NodeEntry> outputs, and NodeEntry cantains a poniter to Node. We can follow the Node pointer to get all the Graph. Symbol contains a data member, std::vector<NodeEntry> outputs, and NodeEntry cantains a poniter to Node. We can follow the Node pointer to get all the Graph.
......
...@@ -5,8 +5,12 @@ cc_library(pass SRCS pass.cc DEPS graph node graph_helper) ...@@ -5,8 +5,12 @@ cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper) cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper)
cc_library(graph_traits SRCS graph_traits.cc DEPS graph) cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits) cc_library(graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits)
cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detecter)
cc_library(infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass)
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
cc_test(test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter) cc_test(test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter)
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detecter graph pass graph_traits framework_proto)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
bool VarOutLinksToOp(Node* node, const std::string& op_type) {
for (auto* out : node->outputs) {
if (out->IsOp() && out->Op()->Type() == op_type) {
return true;
}
}
return false;
}
void BuildFCPattern(PDPattern* pattern) {
// make sure the selected MUL op has one input argument is a parameter.
auto* mul_parameter_var = pattern->NewNode(
[](Node* node) {
return node->IsVar() && node->outputs.size() == 1UL &&
node->outputs.front()->Op()->Type() == "mul" && node->Var() &&
node->Var()->Persistable(); // check is a parameter
},
"mul_weight" /*name*/);
auto* mul_tmp_input_var = pattern->NewNode(
[](Node* node) {
bool result =
node->IsVar() && node->outputs.size() >= 1UL && node->Var() &&
!node->Var()->Persistable(); // this input is not an parameter.
if (!result) return false;
// check whether one output is MUL op.
for (auto* op : node->outputs) {
if (op->IsOp() && op->Op()->Type() == "mul") return true;
}
return false;
},
"mul_tmp_var" /*name*/);
// select a MUL op
auto* mul_op = pattern->NewNode(
[](Node* node) {
return node->IsOp() && // start from an Op
node->Op()->Type() == "mul"; // type is mul
// the output should be consumed only by one element_add, that check
// leaves in a Var PDNode.
},
"mul" /*name*/);
// make sure the MUL op's output has only one consumer and links to an
// ELEMENTWISE_ADD op.
auto* mul_out_var = pattern->NewNode(
[](Node* node) {
return node->IsVar() && // starts from a Var
node->outputs.size() == 1UL && // only has one consumer
node->outputs.front()->IsOp() && // check basic logic
node->Var() && // not a ControlDepVar
node->outputs.front()->Op()->Type() ==
"elementwise_add"; // a very strong validation
},
"mul_out");
// this check is not essential, just to make the corresponding variable Node
// retrival easier.
auto* elementwise_add_tmp_var = pattern->NewNode(
[](Node* node) {
return node->IsVar() && node->outputs.size() >= 1UL && node->Var() &&
VarOutLinksToOp(node, "elementwise_add");
},
"elementwise_add_tmpvar");
// select an ELEMENTWISE_ADD op
auto* elementwise_add_op = pattern->NewNode(
[](Node* node) {
return node->IsOp() && node->Op()->Type() == "elementwise_add";
},
"elementwise_add" /*name*/);
// get the ELEMENTWISE_ADD op's output
auto* elementwise_add_out_var = pattern->NewNode(
[](Node* node) {
return node->IsVar() && node->inputs.size() == 1UL && node->Var() &&
node->inputs.front()->Op()->Type() == "elementwise_add";
},
"elementwise_add_out");
pattern->AddEdge(mul_parameter_var, mul_op);
pattern->AddEdge(mul_tmp_input_var, mul_op);
pattern->AddEdge(mul_op, mul_out_var);
pattern->AddEdge(mul_out_var, elementwise_add_op);
pattern->AddEdge(elementwise_add_tmp_var, elementwise_add_op);
pattern->AddEdge(elementwise_add_op, elementwise_add_out_var);
}
// Replace the node `from` in the links to `to`
bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
for (auto*& n : *links) {
if (n == from) {
n = to;
return true;
}
}
return false;
}
std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
std::unordered_set<Node*> nodes2delete;
GraphPatternDetecter gpd;
BuildFCPattern(gpd.mutable_pattern());
#define GET_NODE(id) \
PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetriveNode(#id)), \
"pattern has no Node called %s", #id); \
auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
auto handler = [&](const GraphPatternDetecter::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle FC fuse";
// Currently, there is no FC op available, so I will just simulate the
// scenerio.
// FC's fusion is simple, just op fuse, no need to process the
// parameters.
GET_NODE(mul_tmp_var); // x
GET_NODE(mul_weight); // Y
GET_NODE(elementwise_add_tmpvar); // bias
GET_NODE(elementwise_add_out); // Out
GET_NODE(mul); // MUL op
GET_NODE(elementwise_add); // ELEMENT_ADD op
GET_NODE(mul_out); // tmp
#undef GET_NODE
// Create an FC Node.
OpDesc desc;
std::string fc_x_in = mul_tmp_var->Name();
std::string fc_Y_in = mul_weight->Name();
std::string fc_bias_in = elementwise_add_tmpvar->Name();
std::string fc_out = elementwise_add_out->Name();
desc.SetInput("Input", std::vector<std::string>({fc_x_in}));
desc.SetInput("W", std::vector<std::string>({fc_Y_in}));
desc.SetInput("Bias", std::vector<std::string>({fc_bias_in}));
desc.SetOutput("Out", std::vector<std::string>({fc_out}));
desc.SetType("fc");
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
fc_node->inputs =
std::vector<Node*>({mul_tmp_var, mul_weight, elementwise_add_tmpvar});
fc_node->outputs.push_back(elementwise_add_out);
// Update link relatons
PADDLE_ENFORCE(LinksReplace(&mul_tmp_var->outputs, mul, fc_node));
PADDLE_ENFORCE(LinksReplace(&mul_weight->outputs, mul, fc_node));
PADDLE_ENFORCE(LinksReplace(&elementwise_add_tmpvar->outputs,
elementwise_add, fc_node));
PADDLE_ENFORCE(
LinksReplace(&elementwise_add_out->inputs, elementwise_add, fc_node));
// Drop old nodes
graph->RemoveNode(mul);
graph->RemoveNode(elementwise_add);
graph->RemoveNode(mul_out); // tmp variable
};
gpd(graph.get(), handler);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fc_fuse_pass, paddle::framework::ir::FCFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the MUL and ELEMENTWISE_ADD to a FCOp.
*/
class FCFusePass : public Pass {
public:
virtual ~FCFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <gtest/gtest.h>
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetInput("Xs", inputs);
op->SetOutput("Ys", outputs);
}
// a->OP0->b
// a->OP1->c
// (b, c)->mul->d
// (d, e)->elementwise_add->f
ProgramDesc BuildProgramDesc() {
ProgramDesc prog;
for (auto& v : std::vector<std::string>({"a", "b", "c", "d", "e", "f"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS);
if (v == "c") {
var->SetPersistable(true);
}
}
SetOp(&prog, "OP0", std::vector<std::string>({"a"}),
std::vector<std::string>({"b"}));
SetOp(&prog, "OP1", std::vector<std::string>({"a"}),
std::vector<std::string>({"c"}));
SetOp(&prog, "mul", std::vector<std::string>({"b", "c"}),
std::vector<std::string>({"d"}));
SetOp(&prog, "elementwise_add", std::vector<std::string>({"d", "e"}),
std::vector<std::string>({"f"}));
return prog;
}
TEST(FCFusePass, basic) {
auto prog = BuildProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("fc_fuse_pass");
int pre_nodes = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int after_nodes = graph->Nodes().size();
// Remove 3 Nodes: MUL,ELEMENTWISE_ADD, mul_out
// Add 1 Node: FC
EXPECT_EQ(pre_nodes - 2, after_nodes);
// Assert fc op in newly generated graph
int fc_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "fc") {
++fc_count;
}
}
EXPECT_EQ(fc_count, 1);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(fc_fuse_pass);
...@@ -98,11 +98,13 @@ class Graph { ...@@ -98,11 +98,13 @@ class Graph {
// Create a normal variable with non-null VarDesc. // Create a normal variable with non-null VarDesc.
ir::Node *CreateVarNode(VarDesc *var_desc) { ir::Node *CreateVarNode(VarDesc *var_desc) {
PADDLE_ENFORCE(var_desc);
return AddNode(new ir::Node(var_desc)); return AddNode(new ir::Node(var_desc));
} }
// Create a normal runnable operator with OpDesc. // Create a normal runnable operator with OpDesc.
ir::Node *CreateOpNode(OpDesc *op_desc) { ir::Node *CreateOpNode(OpDesc *op_desc) {
PADDLE_ENFORCE(op_desc);
return AddNode(new ir::Node(op_desc)); return AddNode(new ir::Node(op_desc));
} }
...@@ -134,6 +136,14 @@ class Graph { ...@@ -134,6 +136,14 @@ class Graph {
return ret; return ret;
} }
void RemoveNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
node_set_.erase(node);
nodes_.erase(node);
}
const ProgramDesc &program() const { return program_; }
private: private:
// This method takes ownership of `node`. // This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) { ir::Node *AddNode(ir::Node *node) {
...@@ -143,12 +153,6 @@ class Graph { ...@@ -143,12 +153,6 @@ class Graph {
return node; return node;
} }
void RemoveNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
node_set_.erase(node);
nodes_.erase(node);
}
// NOTE: program_ shouldn't be exposed to user. // NOTE: program_ shouldn't be exposed to user.
const ProgramDesc &program_; const ProgramDesc &program_;
std::map<std::string, boost::any> attrs_; std::map<std::string, boost::any> attrs_;
......
...@@ -25,12 +25,30 @@ namespace paddle { ...@@ -25,12 +25,30 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
size_t PDPattern::id_ = 0UL;
PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) { PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
if (!name.empty()) {
PADDLE_ENFORCE_EQ(node_map_.count(name), 0,
"PDNode's name should be unique, get duplicate [%s]",
name);
}
nodes_.emplace_back(new PDNode(std::move(teller), name)); nodes_.emplace_back(new PDNode(std::move(teller), name));
auto* cur = nodes_.back().get(); auto* cur = nodes_.back().get();
node_map_[name] = cur;
return cur; return cur;
} }
PDNode* PDPattern::RetriveNode(const std::string& id) const {
auto it = node_map_.find(id);
if (it == node_map_.end()) {
return nullptr;
}
return it->second;
}
void PDPattern::AddEdge(PDNode* a, PDNode* b) { void PDPattern::AddEdge(PDNode* a, PDNode* b) {
PADDLE_ENFORCE(a); PADDLE_ENFORCE(a);
PADDLE_ENFORCE(b); PADDLE_ENFORCE(b);
...@@ -51,15 +69,18 @@ void GraphPatternDetecter::operator()(Graph* graph, ...@@ -51,15 +69,18 @@ void GraphPatternDetecter::operator()(Graph* graph,
} }
bool GraphPatternDetecter::MarkPDNodesInGraph(const ir::Graph& graph) { bool GraphPatternDetecter::MarkPDNodesInGraph(const ir::Graph& graph) {
VLOG(4) << "mark pdnodes in graph";
if (graph.Nodes().empty()) return false; if (graph.Nodes().empty()) return false;
for (auto& node : GraphTraits::DFS(graph)) { for (auto& node : GraphTraits::DFS(graph)) {
for (const auto& pdnode : pattern_.nodes()) { for (const auto& pdnode : pattern_.nodes()) {
if (pdnode->Tell(&node)) { if (pdnode->Tell(&node)) {
VLOG(4) << "pdnode " << pdnode->name() << " marked";
pdnodes2nodes_[pdnode.get()].insert(&node); pdnodes2nodes_[pdnode.get()].insert(&node);
} }
} }
} }
VLOG(3) << pdnodes2nodes_.size() << " nodes marked";
return !pdnodes2nodes_.empty(); return !pdnodes2nodes_.empty();
} }
...@@ -67,10 +88,20 @@ struct HitGroup { ...@@ -67,10 +88,20 @@ struct HitGroup {
std::unordered_map<PDNode*, Node*> roles; std::unordered_map<PDNode*, Node*> roles;
bool Match(Node* node, PDNode* pat) { bool Match(Node* node, PDNode* pat) {
if (nodes_.count(node)) {
if (!roles.count(pat)) return false;
return roles[pat] == node;
}
return !roles.count(pat) || roles.at(pat) == node; return !roles.count(pat) || roles.at(pat) == node;
} }
void Register(Node* node, PDNode* pat) { roles[pat] = node; } void Register(Node* node, PDNode* pat) {
roles[pat] = node;
nodes_.insert(node);
}
private:
std::unordered_set<Node*> nodes_;
}; };
// Tell whether Node a links to b. // Tell whether Node a links to b.
...@@ -104,6 +135,7 @@ GraphPatternDetecter::DetectPatterns() { ...@@ -104,6 +135,7 @@ GraphPatternDetecter::DetectPatterns() {
// Extend a PDNode to subgraphs by deducing the connection relations defined // Extend a PDNode to subgraphs by deducing the connection relations defined
// in edges of PDNodes. // in edges of PDNodes.
for (const auto& edge : pattern_.edges()) { for (const auto& edge : pattern_.edges()) {
VLOG(4) << "check " << edge.first->name() << " -> " << edge.second->name();
// Each role has two PDNodes, which indicates two roles. // Each role has two PDNodes, which indicates two roles.
// Detect two Nodes that can match these two roles and they are connected. // Detect two Nodes that can match these two roles and they are connected.
auto& pre_groups = bi_records[step % 2]; auto& pre_groups = bi_records[step % 2];
...@@ -127,6 +159,7 @@ GraphPatternDetecter::DetectPatterns() { ...@@ -127,6 +159,7 @@ GraphPatternDetecter::DetectPatterns() {
} }
} }
} }
VLOG(3) << "step " << step << " get records: " << cur_groups.size();
} }
for (auto& group : bi_records[step % 2]) { for (auto& group : bi_records[step % 2]) {
......
...@@ -96,7 +96,8 @@ class PDPattern { ...@@ -96,7 +96,8 @@ class PDPattern {
void AddEdge(PDNode* a, PDNode* b); void AddEdge(PDNode* a, PDNode* b);
PDNode* NewNode(PDNode::teller_t&& teller, const std::string& name = ""); PDNode* NewNode(PDNode::teller_t&& teller, const std::string& name = NewID());
PDNode* RetriveNode(const std::string& id) const;
const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; } const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; }
const std::vector<edge_t>& edges() const { return edges_; } const std::vector<edge_t>& edges() const { return edges_; }
...@@ -107,8 +108,12 @@ class PDPattern { ...@@ -107,8 +108,12 @@ class PDPattern {
FRIEND_TEST(PDPattern, NewNode); FRIEND_TEST(PDPattern, NewNode);
#endif #endif
static std::string NewID() { return "pdnode-" + std::to_string(id_++); }
std::vector<std::unique_ptr<PDNode>> nodes_; std::vector<std::unique_ptr<PDNode>> nodes_;
std::vector<edge_t> edges_; std::vector<edge_t> edges_;
std::unordered_map<std::string, PDNode*> node_map_;
static size_t id_;
}; };
/* /*
......
...@@ -25,6 +25,7 @@ static const char kGraphVizPath[] = "graph_viz_path"; ...@@ -25,6 +25,7 @@ static const char kGraphVizPath[] = "graph_viz_path";
std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl( std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
const std::string graph_viz_path = Get<std::string>(kGraphVizPath); const std::string graph_viz_path = Get<std::string>(kGraphVizPath);
VLOG(3) << "draw IR graph viz to " << graph_viz_path;
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path)); std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
PADDLE_ENFORCE(fout->good()); PADDLE_ENFORCE(fout->good());
std::ostream& sout = *fout; std::ostream& sout = *fout;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class InferCleanGraphPass : public Pass {
public:
virtual ~InferCleanGraphPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
auto is_valid_node = [](Node* x) {
return x && IsControlDepVar(*x) && x->IsVar() && !x->Var();
};
std::unordered_set<Node*> invalid_nodes;
for (auto* node : graph->Nodes()) {
if (is_valid_node(node)) {
invalid_nodes.insert(node);
}
}
// remove nodes from the graph.
for (auto* node : invalid_nodes) {
graph->RemoveNode(node);
}
// clean edges.
for (auto* node : graph->Nodes()) {
CleanEdges(&node->inputs, invalid_nodes);
CleanEdges(&node->outputs, invalid_nodes);
}
return graph;
}
void CleanEdges(std::vector<Node*>* nodes,
const std::unordered_set<Node*>& to_remove) const {
auto it = std::remove_if(nodes->begin(), nodes->end(),
[&](Node* x) { return to_remove.count(x); });
nodes->erase(it, nodes->end());
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(infer_clean_graph_pass,
paddle::framework::ir::InferCleanGraphPass);
...@@ -34,14 +34,15 @@ class Node { ...@@ -34,14 +34,15 @@ class Node {
explicit Node(VarDesc* var_desc) explicit Node(VarDesc* var_desc)
: name_(var_desc->Name()), : name_(var_desc->Name()),
var_desc_(var_desc), var_desc_(new VarDesc(*var_desc)),
op_desc_(nullptr), op_desc_(nullptr),
type_(Type::kVariable) {} type_(Type::kVariable) {}
explicit Node(OpDesc* op_desc) explicit Node(OpDesc* op_desc)
: name_(op_desc->Type()), : name_(op_desc->Type()),
var_desc_(nullptr), var_desc_(nullptr),
op_desc_(op_desc), op_desc_(new OpDesc(*op_desc)), // TODO(panyx0718) the pointer in the
// original OpDesc might go out.
type_(Type::kOperation) {} type_(Type::kOperation) {}
Type NodeType() const { return type_; } Type NodeType() const { return type_; }
...@@ -50,12 +51,12 @@ class Node { ...@@ -50,12 +51,12 @@ class Node {
VarDesc* Var() { VarDesc* Var() {
PADDLE_ENFORCE(type_ == Type::kVariable); PADDLE_ENFORCE(type_ == Type::kVariable);
return var_desc_; return var_desc_.get();
} }
OpDesc* Op() { OpDesc* Op() {
PADDLE_ENFORCE(type_ == Type::kOperation); PADDLE_ENFORCE(IsOp());
return op_desc_; return op_desc_.get();
} }
bool IsOp() const { return type_ == Type::kOperation; } bool IsOp() const { return type_ == Type::kOperation; }
...@@ -66,8 +67,8 @@ class Node { ...@@ -66,8 +67,8 @@ class Node {
protected: protected:
const std::string name_; const std::string name_;
VarDesc* var_desc_; std::unique_ptr<VarDesc> var_desc_;
OpDesc* op_desc_; std::unique_ptr<OpDesc> op_desc_;
Type type_; Type type_;
private: private:
......
cc_library(ir_pass_manager SRCS ir_pass_manager.cc DEPS graph pass)
cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc
analyzer.cc
helper.cc
# passes
fluid_to_data_flow_graph_pass.cc fluid_to_data_flow_graph_pass.cc
data_flow_graph_to_fluid_pass.cc data_flow_graph_to_fluid_pass.cc
dfg_graphviz_draw_pass.cc dfg_graphviz_draw_pass.cc
tensorrt_subgraph_pass.cc tensorrt_subgraph_pass.cc
tensorrt_subgraph_node_mark_pass.cc tensorrt_subgraph_node_mark_pass.cc
analyzer.cc fluid_to_ir_pass.cc
helper.cc
model_store_pass.cc model_store_pass.cc
DEPS framework_proto proto_desc) DEPS framework_proto proto_desc ir_pass_manager graph pass)
cc_test(test_node SRCS node_tester.cc DEPS analysis gflags glog gtest)
cc_test(test_node SRCS node_tester.cc DEPS analysis)
cc_test(test_dot SRCS dot_tester.cc DEPS analysis) cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis) cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis)
...@@ -27,19 +31,30 @@ function (inference_analysis_test TARGET) ...@@ -27,19 +31,30 @@ function (inference_analysis_test TARGET)
endif() endif()
cc_test(${TARGET} cc_test(${TARGET}
SRCS "${analysis_test_SRCS}" SRCS "${analysis_test_SRCS}"
DEPS analysis DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detecter pass
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt}) ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt})
set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec) set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec)
endif(WITH_TESTING) endif(WITH_TESTING)
endfunction(inference_analysis_test) endfunction(inference_analysis_test)
cc_test(test_analyzer SRCS analyzer_tester.cc DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis
# ir
fc_fuse_pass
graph_viz_pass
infer_clean_graph_pass
graph_pattern_detecter
pass
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model)
#set_tests_properties(test_analyzer PROPERTIES DEPENDS test_word2vec)
#inference_api_test(test_analyzer SRC analyzer_tester.cc ARGS test_word2vec)
inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc) inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc) inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc)
inference_analysis_test(test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc)
inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc) inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc)
inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc) inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc)
inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc) inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc)
inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc) inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc)
inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc) inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc)
inference_analysis_test(test_tensorrt_subgraph_node_mark_pass SRCS tensorrt_subgraph_node_mark_pass_tester.cc) inference_analysis_test(test_tensorrt_subgraph_node_mark_pass SRCS tensorrt_subgraph_node_mark_pass_tester.cc)
inference_analysis_test(test_analyzer SRCS analyzer_tester.cc)
inference_analysis_test(test_model_store_pass SRCS model_store_pass_tester.cc) inference_analysis_test(test_model_store_pass SRCS model_store_pass_tester.cc)
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h" #include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h" #include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" #include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
#include "paddle/fluid/inference/analysis/model_store_pass.h" #include "paddle/fluid/inference/analysis/model_store_pass.h"
#include "paddle/fluid/inference/analysis/pass_manager.h" #include "paddle/fluid/inference/analysis/pass_manager.h"
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h" #include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h"
...@@ -24,14 +25,15 @@ ...@@ -24,14 +25,15 @@
namespace paddle { namespace paddle {
DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, true, DEFINE_bool(IA_enable_tensorrt_subgraph_engine, false,
"Enable subgraph to TensorRT engine for acceleration"); "Enable subgraph to TensorRT engine for acceleration");
DEFINE_string(inference_analysis_graphviz_log_root, "./", DEFINE_bool(IA_enable_ir, false, "Turn on IR support");
DEFINE_string(IA_graphviz_log_root, "./",
"Graphviz debuger for data flow graphs."); "Graphviz debuger for data flow graphs.");
DEFINE_string(inference_analysis_output_storage_path, "", DEFINE_string(IA_output_storage_path, "", "optimized model output path");
"optimized model output path");
namespace inference { namespace inference {
namespace analysis { namespace analysis {
...@@ -40,8 +42,34 @@ class DfgPassManagerImpl final : public DfgPassManager { ...@@ -40,8 +42,34 @@ class DfgPassManagerImpl final : public DfgPassManager {
public: public:
DfgPassManagerImpl() { DfgPassManagerImpl() {
// TODO(Superjomn) set the key with pass reprs. // TODO(Superjomn) set the key with pass reprs.
LOG(INFO)
<< "-----------------------------------------------------------------";
if (FLAGS_IA_enable_ir) {
AddPass("fluid-to-ir-pass", new FluidToIrPass);
} else {
AddPass("fluid-to-data-flow-graph", new FluidToDataFlowGraphPass); AddPass("fluid-to-data-flow-graph", new FluidToDataFlowGraphPass);
if (FLAGS_inference_analysis_enable_tensorrt_subgraph_engine) { }
TryAddTensorRtPass();
AddPass("data-flow-graph-to-fluid", new DataFlowGraphToFluidPass);
if (!FLAGS_IA_output_storage_path.empty()) {
AddPass("model-store-pass", new ModelStorePass);
}
LOG(INFO)
<< "-----------------------------------------------------------------";
}
std::string repr() const override { return "dfg-pass-manager"; }
std::string description() const override { return "DFG pass manager."; }
private:
void AddPass(const std::string& name, Pass* pass) {
VLOG(3) << "Adding pass " << name;
Register(name, pass);
AddGraphvizDebugerPass(pass);
}
void TryAddTensorRtPass() {
if (FLAGS_IA_enable_tensorrt_subgraph_engine) {
auto trt_teller = [&](const Node* node) { auto trt_teller = [&](const Node* node) {
std::unordered_set<std::string> teller_set( std::unordered_set<std::string> teller_set(
{"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax"}); {"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax"});
...@@ -59,20 +87,6 @@ class DfgPassManagerImpl final : public DfgPassManager { ...@@ -59,20 +87,6 @@ class DfgPassManagerImpl final : public DfgPassManager {
new TensorRTSubgraphNodeMarkPass(trt_teller)); new TensorRTSubgraphNodeMarkPass(trt_teller));
AddPass("tensorrt-subgraph", new TensorRTSubGraphPass(trt_teller)); AddPass("tensorrt-subgraph", new TensorRTSubGraphPass(trt_teller));
} }
AddPass("data-flow-graph-to-fluid", new DataFlowGraphToFluidPass);
if (!FLAGS_inference_analysis_output_storage_path.empty()) {
AddPass("model-store-pass", new ModelStorePass);
}
}
std::string repr() const override { return "dfg-pass-manager"; }
std::string description() const override { return "DFG pass manager."; }
private:
void AddPass(const std::string& name, Pass* pass) {
LOG(INFO) << "Adding pass " << name;
Register(name, pass);
AddGraphvizDebugerPass(pass);
} }
// Add the graphviz debuger pass if the parent pass has one. // Add the graphviz debuger pass if the parent pass has one.
......
...@@ -43,9 +43,10 @@ namespace paddle { ...@@ -43,9 +43,10 @@ namespace paddle {
// TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this // TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this
// flag if not available. // flag if not available.
DECLARE_bool(inference_analysis_enable_tensorrt_subgraph_engine); DECLARE_bool(IA_enable_tensorrt_subgraph_engine);
DECLARE_string(inference_analysis_graphviz_log_root); DECLARE_string(IA_graphviz_log_root);
DECLARE_string(inference_analysis_output_storage_path); DECLARE_string(IA_output_storage_path);
DECLARE_bool(IA_enable_ir);
namespace inference { namespace inference {
namespace analysis { namespace analysis {
......
...@@ -14,14 +14,16 @@ ...@@ -14,14 +14,16 @@
#include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/analysis/analyzer.h"
#include <google/protobuf/text_format.h> #include <google/protobuf/text_format.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/inference/analysis/ut_helper.h" #include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
TEST(Analyzer, analysis_without_tensorrt) { TEST(Analyzer, analysis_without_tensorrt) {
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = false; FLAGS_IA_enable_tensorrt_subgraph_engine = false;
Argument argument; Argument argument;
argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir)); argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
Analyzer analyser; Analyzer analyser;
...@@ -29,13 +31,73 @@ TEST(Analyzer, analysis_without_tensorrt) { ...@@ -29,13 +31,73 @@ TEST(Analyzer, analysis_without_tensorrt) {
} }
TEST(Analyzer, analysis_with_tensorrt) { TEST(Analyzer, analysis_with_tensorrt) {
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = true; FLAGS_IA_enable_tensorrt_subgraph_engine = true;
Argument argument; Argument argument;
argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir)); argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
Analyzer analyser; Analyzer analyser;
analyser.Run(&argument); analyser.Run(&argument);
} }
void TestWord2vecPrediction(const std::string& model_path) {
NativeConfig config;
config.model_dir = model_path;
config.use_gpu = false;
config.device = 0;
auto predictor =
::paddle::CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(
config);
// One single batch
int64_t data[4] = {1, 2, 3, 4};
PaddleTensor tensor;
tensor.shape = std::vector<int>({4, 1});
tensor.data = PaddleBuf(data, sizeof(data));
tensor.dtype = PaddleDType::INT64;
// For simplicity, we set all the slots with the same data.
std::vector<PaddleTensor> slots(4, tensor);
std::vector<PaddleTensor> outputs;
CHECK(predictor->Run(slots, &outputs));
PADDLE_ENFORCE(outputs.size(), 1UL);
// Check the output buffer size and result of each tid.
PADDLE_ENFORCE(outputs.front().data.length(), 33168UL);
float result[5] = {0.00129761, 0.00151112, 0.000423564, 0.00108815,
0.000932706};
const size_t num_elements = outputs.front().data.length() / sizeof(float);
// The outputs' buffers are in CPU memory.
for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
LOG(INFO) << "data: "
<< static_cast<float*>(outputs.front().data.data())[i];
PADDLE_ENFORCE(static_cast<float*>(outputs.front().data.data())[i],
result[i]);
}
}
// Turn on the IR pass supportion, run a real inference and check the result.
TEST(Analyzer, SupportIRPass) {
FLAGS_IA_enable_ir = true;
FLAGS_IA_enable_tensorrt_subgraph_engine = false;
FLAGS_IA_output_storage_path = "./analysis.out";
Argument argument(FLAGS_inference_model_dir);
argument.model_output_store_path.reset(new std::string("./analysis.out"));
Analyzer analyzer;
analyzer.Run(&argument);
// Should get the transformed model stored to ./analysis.out
ASSERT_TRUE(PathExists("./analysis.out"));
// Inference from this path.
TestWord2vecPrediction("./analysis.out");
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
USE_PASS(fc_fuse_pass);
USE_PASS(graph_viz_pass);
USE_PASS(infer_clean_graph_pass);
...@@ -19,14 +19,16 @@ limitations under the License. */ ...@@ -19,14 +19,16 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
using ir_node_t = framework::ir::Node;
using ir_graph_t = framework::ir::Graph;
// It is a better idea that the inputs and outputs of this graph is set manually // It is a better idea that the inputs and outputs of this graph is set manually
// before, but there must be a Pass that helps to prune the unnecessary ops that // before, but there must be a Pass that helps to prune the unnecessary ops that
// do not contribute to the given targets, so in this pass, analysis and get the // do not contribute to the given targets, so in this pass, analysis and get the
// inputs and outputs is OK. // inputs and outputs is OK.
void DataFlowGraph::Build() { void DataFlowGraph::Build() {
inputs.clear(); inputs_.clear();
outputs.clear(); outputs_.clear();
std::unordered_set<Node *> ins; std::unordered_set<Node *> ins;
std::unordered_set<Node *> outs; std::unordered_set<Node *> outs;
for (auto &node : nodes.nodes()) { for (auto &node : nodes.nodes()) {
...@@ -42,18 +44,140 @@ void DataFlowGraph::Build() { ...@@ -42,18 +44,140 @@ void DataFlowGraph::Build() {
// similarly, the nodes that in outs but not in ins is the graphs' outputs // similarly, the nodes that in outs but not in ins is the graphs' outputs
for (auto *in : ins) { for (auto *in : ins) {
if (!outs.count(in)) { if (!outs.count(in)) {
inputs.push_back(in); inputs_.push_back(in);
} }
} }
for (auto *out : outs) { for (auto *out : outs) {
if (!outs.count(out)) { if (!ins.count(out)) {
outputs.push_back(out); outputs_.push_back(out);
} }
} }
Clean(); Clean();
} }
void DataFlowGraph::Build(const framework::proto::ProgramDesc &prog) {
// insert vars
// The `var2id` keeps a map from a variable's name to its Node-id, the Node-id
// will keep updating to its latest alias during the graph-building.
std::unordered_map<std::string, size_t> var2id;
auto &main_block = prog.blocks(framework::kRootBlockIndex);
for (int i = 0; i < main_block.vars_size(); i++) {
const auto &var = main_block.vars(i);
auto *v = nodes.Create(Node::Type::kValue);
v->SetName(var.name());
v->SetPbDesc(const_cast<void *>(static_cast<const void *>(&var)));
v->SetPbMsg(var.SerializeAsString());
var2id[var.name()] = v->id();
}
// The variables in a SSA can only write once, so if a variable is written
// multiple times(quite common in our ProgramDesc design), multiple alias
// Nodes of this variable will be created, and each will just write once.
// An set that keep all the names of the variables(the original, not alias)
// that have been written(as outputs). Once an Op's output variable hit the
// set, it should create a new alias and update the global alias for this
// variable. And that make a Data Flow Graph a SSA.
std::unordered_set<Node *> unique_written_vars;
for (int i = 0; i < main_block.ops_size(); i++) {
const auto &op = main_block.ops(i);
auto *o = nodes.Create(Node::Type::kFunction);
o->SetName(op.type());
static_cast<Function *>(o)->SetFuncType(op.type());
// Link to the original protobuf message's memory, make it easier to
// generate from a data flow graph to fluid ProgramDesc.
o->SetPbDesc(const_cast<void *>(static_cast<const void *>(&op)));
o->SetPbMsg(op.SerializeAsString());
// set inputs and outputs
for (int j = 0; j < op.inputs_size(); j++) {
auto &in_var = op.inputs(j);
for (int k = 0; k < in_var.arguments_size(); k++) {
auto *in = nodes.GetMutable(var2id.at(in_var.arguments(k)));
in->outlinks.push_back(o);
o->inlinks.push_back(in);
}
}
for (int j = 0; j < op.outputs_size(); j++) {
auto &out_var = op.outputs(j);
for (int k = 0; k < out_var.arguments_size(); k++) {
auto *out = nodes.GetMutable(var2id[out_var.arguments(k)]);
if (unique_written_vars.count(out)) {
// Loop found, for example, a = op(a), use SSA, change to a1 = op(a).
auto *out_alias = nodes.Create(Node::Type::kValue);
out_alias->SetName(out->name());
out_alias->SetPbDesc(out->pb_desc());
out_alias->SetPbMsg(out->pb_msg());
var2id[out_alias->name()] =
out_alias->id(); // update variable's alias Node
LOG(INFO) << "loop found in graph, create SSA alias node ["
<< out_alias->repr() << "] for [" << out->repr() << "]";
out = out_alias;
}
out->inlinks.push_back(o);
o->outlinks.push_back(out);
unique_written_vars.insert(out);
}
}
}
// Analysis and extract the inputs and outputs of this graph.
Build();
}
void DataFlowGraph::Build(const framework::ir::Graph &graph) {
// Create nodes
std::unordered_map<ir_node_t *, Node *> ir_node_map;
for (auto *ir_node : graph.Nodes()) {
Node *x{nullptr};
if (ir_node->IsOp()) {
PADDLE_ENFORCE(ir_node->Op());
VLOG(4) << "get op " << ir_node << " " << ir_node->Name();
x = nodes.Create(Node::Type::kFunction);
x->attr("ir_node").Pointer() = ir_node;
PADDLE_ENFORCE(ir_node->Op()->Proto());
x->SetName(ir_node->Op()->Proto()->type());
x->SetPbMsg(ir_node->Op()->Proto()->SerializeAsString());
} else if (ir_node->IsVar()) {
// Not create a Node for IR ControlDepVar, considering Inference currently
// just used in single thread scenerio.
VLOG(4) << "get var " << ir_node->Name();
x = nodes.Create(Node::Type::kValue);
x->attr("ir_node").Pointer() = ir_node;
x->SetName(ir_node->Name());
// x->SetPbMsg(ir_node->Var()->Proto()->SerializeAsString());
} else {
PADDLE_THROW("Failed to create an Node from IR, unknown type");
}
ir_node_map.emplace(ir_node, x);
}
VLOG(4) << "finish creating Nodes";
VLOG(4) << "to create edge";
// Create links
for (auto *ir_node : graph.Nodes()) {
auto it = ir_node_map.find(ir_node);
// Skip ControlDepVar.
if (it == ir_node_map.end()) continue;
auto *node = it->second;
for (auto *x : ir_node->inputs) {
if (!ir_node_map.count(x)) continue;
node->inlinks.push_back(ir_node_map.at(x));
}
for (auto *x : ir_node->outputs) {
if (!ir_node_map.count(x)) continue;
node->outlinks.push_back(ir_node_map.at(x));
}
}
Build();
PADDLE_ENFORCE(!inputs_.empty(),
"Can't deduce any inputs from the graph, Is the graph empty?");
ir_graph = &graph;
VLOG(3) << "finished build from IR";
}
void DataFlowGraph::Clean() { void DataFlowGraph::Clean() {
for (auto &node : nodes.nodes()) { for (auto &node : nodes.nodes()) {
std::unordered_set<Node *> inlinks_set(node->inlinks.begin(), std::unordered_set<Node *> inlinks_set(node->inlinks.begin(),
...@@ -61,11 +185,9 @@ void DataFlowGraph::Clean() { ...@@ -61,11 +185,9 @@ void DataFlowGraph::Clean() {
std::unordered_set<Node *> outlinks_set(node->outlinks.begin(), std::unordered_set<Node *> outlinks_set(node->outlinks.begin(),
node->outlinks.end()); node->outlinks.end());
if (inlinks_set.size() < node->inlinks.size()) { if (inlinks_set.size() < node->inlinks.size()) {
LOG(INFO) << "Clean: node " << node->repr() << " prune duplicate inputs";
node->inlinks.assign(inlinks_set.begin(), inlinks_set.end()); node->inlinks.assign(inlinks_set.begin(), inlinks_set.end());
} }
if (outlinks_set.size() < node->outlinks.size()) { if (outlinks_set.size() < node->outlinks.size()) {
LOG(INFO) << "Clean: node " << node->repr() << " prune duplicate inputs";
node->outlinks.assign(outlinks_set.begin(), outlinks_set.end()); node->outlinks.assign(outlinks_set.begin(), outlinks_set.end());
} }
} }
...@@ -112,10 +234,10 @@ GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator( ...@@ -112,10 +234,10 @@ GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator(
const std::vector<Node *> &source) const std::vector<Node *> &source)
: queue_(source.begin(), source.end()) {} : queue_(source.begin(), source.end()) {}
// GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator( GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator(
// GraphTraits<DataFlowGraph>::NodesBFSIterator &&other) noexcept GraphTraits<DataFlowGraph>::NodesBFSIterator &&other) noexcept
// : queue_(std::move(other.queue_)), : queue_(std::move(other.queue_)),
// visited_(std::move(other.visited_)) {} visited_(std::move(other.visited_)) {}
GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator( GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator(
const GraphTraits<DataFlowGraph>::NodesBFSIterator &other) const GraphTraits<DataFlowGraph>::NodesBFSIterator &other)
...@@ -159,7 +281,7 @@ bool GraphTraits<DataFlowGraph>::NodesBFSIterator::operator==( ...@@ -159,7 +281,7 @@ bool GraphTraits<DataFlowGraph>::NodesBFSIterator::operator==(
if (queue_.empty()) return other.queue_.empty(); if (queue_.empty()) return other.queue_.empty();
if ((!queue_.empty()) && (!other.queue_.empty())) { if ((!queue_.empty()) && (!other.queue_.empty())) {
return queue_.front() == other.queue_.front() && return queue_.front() == other.queue_.front() &&
visited_.size() == other.visited_.size(); // here need to check the visited_.size() == other.visited_.size();
// equality of queue and // equality of queue and
// visited. Just a light but week implementation. // visited. Just a light but week implementation.
} }
...@@ -174,10 +296,10 @@ GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator( ...@@ -174,10 +296,10 @@ GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator(
for (auto *x : source) stack_.push(x); for (auto *x : source) stack_.push(x);
} }
// GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator( GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator(
// GraphTraits<DataFlowGraph>::NodesDFSIterator &&other) noexcept GraphTraits<DataFlowGraph>::NodesDFSIterator &&other) noexcept
// : stack_(std::move(other.stack_)), : stack_(std::move(other.stack_)),
// visited_(std::move(other.visited_)) {} visited_(std::move(other.visited_)) {}
GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator( GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator(
const GraphTraits<DataFlowGraph>::NodesDFSIterator &other) const GraphTraits<DataFlowGraph>::NodesDFSIterator &other)
...@@ -339,7 +461,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT ...@@ -339,7 +461,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) { void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
std::vector<Node *> op_nodes; std::vector<Node *> op_nodes;
for (auto &node : GraphTraits<DataFlowGraph>(graph).nodes_in_TS()) { for (auto &node : GraphTraits<DataFlowGraph>(*graph).nodes_in_TS()) {
if (node.type() == Node::Type::kValue || node.deleted()) { if (node.type() == Node::Type::kValue || node.deleted()) {
continue; continue;
} }
......
...@@ -26,6 +26,7 @@ limitations under the License. */ ...@@ -26,6 +26,7 @@ limitations under the License. */
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/inference/analysis/graph_traits.h" #include "paddle/fluid/inference/analysis/graph_traits.h"
#include "paddle/fluid/inference/analysis/node.h" #include "paddle/fluid/inference/analysis/node.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -41,19 +42,43 @@ namespace analysis { ...@@ -41,19 +42,43 @@ namespace analysis {
*/ */
struct DataFlowGraph { struct DataFlowGraph {
NodeMap nodes; NodeMap nodes;
std::vector<Node *> inputs; // inputs and outputs are deduced from the graph.
std::vector<Node *> outputs; // Used to interact with IR.
const framework::ir::Graph *ir_graph{nullptr};
// Extract inputs and outputs of the graph. // Extract inputs and outputs of the graph.
void Build(); void Build();
void Build(const framework::proto::ProgramDesc &prog);
// Build a graph from ir::Graph.
void Build(const framework::ir::Graph &graph);
// Get an attribute.
AnyAttr &Attr(const std::string &key) { return attrs_[key]; }
// Output a DOT graph file for debug. // Output a DOT graph file for debug.
std::string DotString() const; std::string DotString() const;
std::string HumanReadableInfo(bool show_values = true, std::string HumanReadableInfo(bool show_values = true,
bool show_functions = true) const; bool show_functions = true) const;
const std::vector<Node *> &inputs() const {
PADDLE_ENFORCE(!inputs_.empty(),
"No inputs are deduced, need to Build() first.");
return inputs_;
}
const std::vector<Node *> &outputs() const {
PADDLE_ENFORCE(!outputs_.empty(),
"No outputs are deduced, need to Build() first.");
return outputs_;
}
private: private:
mutable std::vector<Node *> inputs_;
mutable std::vector<Node *> outputs_;
std::unordered_map<std::string, AnyAttr> attrs_;
// Remove duplicate edges and so on. // Remove duplicate edges and so on.
void Clean(); void Clean();
}; };
...@@ -70,7 +95,7 @@ struct GraphTraits<DataFlowGraph> { ...@@ -70,7 +95,7 @@ struct GraphTraits<DataFlowGraph> {
: public std::iterator<std::forward_iterator_tag, Node *> { : public std::iterator<std::forward_iterator_tag, Node *> {
NodesBFSIterator() = default; NodesBFSIterator() = default;
explicit NodesBFSIterator(const std::vector<Node *> &source); explicit NodesBFSIterator(const std::vector<Node *> &source);
// NodesBFSIterator(NodesBFSIterator &&other) noexcept; NodesBFSIterator(NodesBFSIterator &&other) noexcept;
// NOTE Heavy to use. // NOTE Heavy to use.
NodesBFSIterator(const NodesBFSIterator &other); NodesBFSIterator(const NodesBFSIterator &other);
...@@ -93,8 +118,8 @@ struct GraphTraits<DataFlowGraph> { ...@@ -93,8 +118,8 @@ struct GraphTraits<DataFlowGraph> {
struct NodesDFSIterator struct NodesDFSIterator
: public std::iterator<std::forward_iterator_tag, Node *> { : public std::iterator<std::forward_iterator_tag, Node *> {
NodesDFSIterator() = default; NodesDFSIterator() = default;
explicit NodesDFSIterator(const std::vector<Node *> &source); NodesDFSIterator(const std::vector<Node *> &source);
// NodesDFSIterator(NodesDFSIterator &&other) noexcept; NodesDFSIterator(NodesDFSIterator &&other) noexcept;
NodesDFSIterator(const NodesDFSIterator &other); NodesDFSIterator(const NodesDFSIterator &other);
Node &operator*(); Node &operator*();
...@@ -116,7 +141,7 @@ struct GraphTraits<DataFlowGraph> { ...@@ -116,7 +141,7 @@ struct GraphTraits<DataFlowGraph> {
struct NodesTSIterator struct NodesTSIterator
: public std::iterator<std::forward_iterator_tag, Node *> { : public std::iterator<std::forward_iterator_tag, Node *> {
NodesTSIterator() = default; NodesTSIterator() = default;
explicit NodesTSIterator(const std::vector<Node *> &source); NodesTSIterator(const std::vector<Node *> &source);
NodesTSIterator(NodesTSIterator &&other) NodesTSIterator(NodesTSIterator &&other)
: sorted_(std::move(other.sorted_)), cursor_(other.cursor_) { : sorted_(std::move(other.sorted_)), cursor_(other.cursor_) {
other.cursor_ = 0; other.cursor_ = 0;
...@@ -138,7 +163,7 @@ struct GraphTraits<DataFlowGraph> { ...@@ -138,7 +163,7 @@ struct GraphTraits<DataFlowGraph> {
size_t cursor_{0}; size_t cursor_{0};
}; };
explicit GraphTraits(DataFlowGraph *graph) : graph_(graph) {} explicit GraphTraits(const DataFlowGraph &graph) : graph_(graph) {}
// default use BFS to visit the nodes. // default use BFS to visit the nodes.
iterator_range<NodesBFSIterator> nodes() { iterator_range<NodesBFSIterator> nodes() {
...@@ -156,20 +181,20 @@ struct GraphTraits<DataFlowGraph> { ...@@ -156,20 +181,20 @@ struct GraphTraits<DataFlowGraph> {
private: private:
NodesBFSIterator nodes_bfs_begin() { NodesBFSIterator nodes_bfs_begin() {
return NodesBFSIterator(graph_->inputs); return NodesBFSIterator(graph_.inputs());
} }
NodesBFSIterator nodes_bfs_end() { return NodesBFSIterator(); } NodesBFSIterator nodes_bfs_end() { return NodesBFSIterator(); }
NodesDFSIterator nodes_dfs_begin() { NodesDFSIterator nodes_dfs_begin() {
return NodesDFSIterator(graph_->inputs); return NodesDFSIterator(graph_.inputs());
} }
NodesDFSIterator nodes_dfs_end() { return NodesDFSIterator(); } NodesDFSIterator nodes_dfs_end() { return NodesDFSIterator(); }
NodesTSIterator nodes_ts_begin() { return NodesTSIterator(graph_->inputs); } NodesTSIterator nodes_ts_begin() { return NodesTSIterator(graph_.inputs()); }
NodesTSIterator nodes_ts_end() { return NodesTSIterator(); } NodesTSIterator nodes_ts_end() { return NodesTSIterator(); }
private: private:
DataFlowGraph *graph_; const DataFlowGraph &graph_;
}; };
// Extract the inputs and outputs of a graph. The inputs and outputs of a // Extract the inputs and outputs of a graph. The inputs and outputs of a
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/ut_helper.h" #include "paddle/fluid/inference/analysis/ut_helper.h"
namespace paddle { namespace paddle {
...@@ -24,20 +25,18 @@ TEST(DataFlowGraph, BFS) { ...@@ -24,20 +25,18 @@ TEST(DataFlowGraph, BFS) {
auto dfg = ProgramDescToDFG(desc); auto dfg = ProgramDescToDFG(desc);
dfg.Build(); dfg.Build();
for (auto *in : dfg.inputs) { for (auto* in : dfg.inputs()) {
LOG(INFO) << "inputs: " << in->name() << " " LOG(INFO) << "inputs: " << in->name() << " "
<< static_cast<int>(in->type()); << static_cast<int>(in->type());
} }
for (auto *out : dfg.outputs) { for (auto* out : dfg.outputs()) {
LOG(INFO) << "outputs: " << out->name() << " " LOG(INFO) << "outputs: " << out->name() << " "
<< static_cast<int>(out->type()); << static_cast<int>(out->type());
} }
GraphTraits<DataFlowGraph> trait(&dfg);
auto nodes = trait.nodes();
size_t count = 0; size_t count = 0;
for (auto it = nodes.begin(); it != nodes.end(); ++it) { for (auto& node : GraphTraits<DataFlowGraph>(dfg).nodes()) {
LOG(INFO) << "visiting " << it->name(); LOG(INFO) << "visiting " << node.name();
++count; ++count;
} }
ASSERT_EQ(count, dfg.nodes.size()); ASSERT_EQ(count, dfg.nodes.size());
...@@ -45,13 +44,11 @@ TEST(DataFlowGraph, BFS) { ...@@ -45,13 +44,11 @@ TEST(DataFlowGraph, BFS) {
TEST(DataFlowGraph, DFS) { TEST(DataFlowGraph, DFS) {
auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__"); auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
auto dfg = ProgramDescToDFG(desc); DataFlowGraph dfg;
dfg.Build(); dfg.Build(desc);
GraphTraits<DataFlowGraph> trait(&dfg);
auto nodes = trait.nodes_in_DFS();
size_t count = 0; size_t count = 0;
for (auto it = nodes.begin(); it != nodes.end(); ++it) { for (auto& node : GraphTraits<DataFlowGraph>(dfg).nodes_in_DFS()) {
LOG(INFO) << "visiting " << it->name(); LOG(INFO) << "visiting " << node.name();
++count; ++count;
} }
ASSERT_EQ(count, dfg.nodes.size()); ASSERT_EQ(count, dfg.nodes.size());
...@@ -74,21 +71,17 @@ TEST(DataFlowGraph, TS) { ...@@ -74,21 +71,17 @@ TEST(DataFlowGraph, TS) {
DataFlowGraph graph; DataFlowGraph graph;
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
auto *node = graph.nodes.Create(Node::Type::kValue); auto* node = graph.nodes.Create(Node::Type::kValue);
node->SetName("node-" + std::to_string(i)); node->SetName("node-" + std::to_string(i));
} }
auto add_link = [&](int i, int j) { auto add_link = [&](int i, int j) {
Node *source = graph.nodes.GetMutable(i); Node* source = graph.nodes.GetMutable(i);
Node *target = graph.nodes.GetMutable(j); Node* target = graph.nodes.GetMutable(j);
target->inlinks.push_back(source); target->inlinks.push_back(source);
source->outlinks.push_back(target); source->outlinks.push_back(target);
}; };
graph.inputs.push_back(graph.nodes.GetMutable(0));
graph.inputs.push_back(graph.nodes.GetMutable(1));
graph.inputs.push_back(graph.nodes.GetMutable(2));
add_link(0, 4); add_link(0, 4);
add_link(0, 5); add_link(0, 5);
add_link(1, 6); add_link(1, 6);
...@@ -97,8 +90,9 @@ TEST(DataFlowGraph, TS) { ...@@ -97,8 +90,9 @@ TEST(DataFlowGraph, TS) {
add_link(4, 7); add_link(4, 7);
add_link(4, 3); add_link(4, 3);
add_link(7, 3); add_link(7, 3);
graph.Build();
auto its = GraphTraits<DataFlowGraph>(&graph).nodes_in_TS(); auto its = GraphTraits<DataFlowGraph>(graph).nodes_in_TS();
std::vector<int> sorted_ids; std::vector<int> sorted_ids;
for (auto it = its.begin(); it != its.end(); ++it) { for (auto it = its.begin(); it != its.end(); ++it) {
LOG(INFO) << it->name(); LOG(INFO) << it->name();
...@@ -122,6 +116,50 @@ TEST(DataFlowGraph, TS) { ...@@ -122,6 +116,50 @@ TEST(DataFlowGraph, TS) {
assert_positive_sequence_pair(4, 7); assert_positive_sequence_pair(4, 7);
} }
TEST(DataFlowGraph, Build_ProgramDesc) {
auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
DataFlowGraph graph;
graph.Build(desc);
ASSERT_EQ(graph.nodes.size(), 38UL);
}
void SetOp(framework::ProgramDesc* prog, const std::string& type,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetInput("Xs", inputs);
op->SetOutput("Xs", outputs);
}
TEST(DataFlowGraph, Build_IR_Graph) {
framework::ProgramDesc prog;
for (auto& v : std::vector<std::string>({"a", "b", "c", "d", "e", "f"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(framework::proto::VarType::SELECTED_ROWS);
if (v == "c") {
var->SetPersistable(true);
}
}
SetOp(&prog, "OP0", std::vector<std::string>({"a"}),
std::vector<std::string>({"b"}));
SetOp(&prog, "OP1", std::vector<std::string>({"a"}),
std::vector<std::string>({"c"}));
SetOp(&prog, "mul", std::vector<std::string>({"b", "c"}),
std::vector<std::string>({"d"}));
SetOp(&prog, "elementwise_add", std::vector<std::string>({"d", "e"}),
std::vector<std::string>({"f"}));
DataFlowGraph graph;
framework::ir::Graph ir_graph(prog);
graph.Build(ir_graph);
ASSERT_EQ(graph.nodes.size(), ir_graph.Nodes().size());
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -52,18 +52,15 @@ bool DataFlowGraphToFluidPass::Initialize(Argument *argument) { ...@@ -52,18 +52,15 @@ bool DataFlowGraphToFluidPass::Initialize(Argument *argument) {
bool DataFlowGraphToFluidPass::Finalize() { return true; } bool DataFlowGraphToFluidPass::Finalize() { return true; }
void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) { void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
LOG(INFO) << "graph.inputs " << graph->inputs.size(); // FilterRedundantOutputOfSubGraph(graph);
for (auto &node : GraphTraits<DataFlowGraph>(graph).nodes_in_TS()) { for (auto &node : GraphTraits<DataFlowGraph>(*graph).nodes_in_TS()) {
if (node.deleted()) continue; if (node.deleted()) continue;
switch (node.type()) { switch (node.type()) {
case Node::Type::kFunction: { case Node::Type::kFunction: {
LOG(INFO) << "add function " << node.repr();
AddFluidOp(&node); AddFluidOp(&node);
} break; } break;
case Node::Type::kFunctionBlock: { case Node::Type::kFunctionBlock: {
LOG(INFO) << "add engine op " << node.repr() << " , "
<< static_cast<FunctionBlock *>(&node)->subgraph.size();
AddEngineOp(&node); AddEngineOp(&node);
} break; } break;
default: default:
...@@ -75,15 +72,27 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) { ...@@ -75,15 +72,27 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
} }
void DataFlowGraphToFluidPass::AddFluidOp(Node *node) { void DataFlowGraphToFluidPass::AddFluidOp(Node *node) {
auto *ori_op = static_cast<framework::proto::OpDesc *>(node->pb_desc()); PADDLE_ENFORCE(node);
PADDLE_ENFORCE(node->IsFunction());
PADDLE_ENFORCE(node->pb_desc() || !node->pb_msg().empty(),
"node has invalid protobuf repr.");
// currently only the main block is analyzed. // currently only the main block is analyzed.
PADDLE_ENFORCE(desc_);
auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex); auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
auto *op = main_block->add_ops(); auto *op = main_block->add_ops();
*op = *ori_op; // copy the attributes, by default, these will not be changed
if (node->pb_desc()) {
auto *ori_op = static_cast<framework::proto::OpDesc *>(node->pb_desc());
*op =
*ori_op; // copy the attributes, by default, these will not be changed
// by analysis phrase. // by analysis phrase.
// The inputs and outputs of the existing ops are not changed by tensorrt // The inputs and outputs of the existing ops are not changed by tensorrt
// subgraph pass. // subgraph pass.
// NOTE It might be changed by other passes in the long run. // NOTE It might be changed by other passes in the long run.
} else {
op->ParseFromString(node->pb_msg());
}
} }
void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph, void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
...@@ -220,10 +229,9 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) { ...@@ -220,10 +229,9 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) {
framework::BlockDesc block_desc(nullptr, &proto); framework::BlockDesc block_desc(nullptr, &proto);
block_desc.Proto()->set_parent_idx(-1); block_desc.Proto()->set_parent_idx(-1);
block_desc.Proto()->set_idx(0); block_desc.Proto()->set_idx(0);
LOG(INFO) << "origin variable size: " VLOG(4) << "origin variable size: "
<< argument_->origin_program_desc->blocks(0).vars().size(); << argument_->origin_program_desc->blocks(0).vars().size();
LOG(INFO) << "transformed variable size: " VLOG(4) << "transformed variable size: " << block_desc.Proto()->vars().size();
<< block_desc.Proto()->vars().size();
// copy ops. // copy ops.
for (auto *node : block_node->subgraph) { for (auto *node : block_node->subgraph) {
...@@ -257,7 +265,7 @@ class DFG_DebuggerPass : public DFG_GraphvizDrawPass { ...@@ -257,7 +265,7 @@ class DFG_DebuggerPass : public DFG_GraphvizDrawPass {
Pass *DataFlowGraphToFluidPass::CreateGraphvizDebugerPass() const { Pass *DataFlowGraphToFluidPass::CreateGraphvizDebugerPass() const {
return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config( return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config(
FLAGS_inference_analysis_graphviz_log_root, FLAGS_IA_graphviz_log_root,
"data_flow_graph_to_fluid_graphviz_debugger")); "data_flow_graph_to_fluid_graphviz_debugger"));
} }
......
...@@ -29,7 +29,7 @@ void DFG_GraphvizDrawPass::Run(DataFlowGraph *graph) { ...@@ -29,7 +29,7 @@ void DFG_GraphvizDrawPass::Run(DataFlowGraph *graph) {
auto png_path = dot_path.substr(0, dot_path.size() - 4) + ".png"; auto png_path = dot_path.substr(0, dot_path.size() - 4) + ".png";
std::string message; std::string message;
LOG(INFO) << "draw to " << png_path; VLOG(3) << "draw to " << png_path;
ExecShellCommand("dot -Tpng " + dot_path + " -o " + png_path, &message); ExecShellCommand("dot -Tpng " + dot_path + " -o " + png_path, &message);
} }
......
...@@ -52,72 +52,7 @@ bool FluidToDataFlowGraphPass::Finalize() { return true; } ...@@ -52,72 +52,7 @@ bool FluidToDataFlowGraphPass::Finalize() { return true; }
void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE(graph);
PADDLE_ENFORCE(desc_); PADDLE_ENFORCE(desc_);
// insert vars graph->Build(*desc_);
// The `var2id` keeps a map from a variable's name to its Node-id, the Node-id
// will keep updating to its latest alias during the graph-building.
std::unordered_map<std::string, size_t> var2id;
auto &main_block = desc_->blocks(framework::kRootBlockIndex);
for (int i = 0; i < main_block.vars_size(); i++) {
const auto &var = main_block.vars(i);
auto *v = graph->nodes.Create(Node::Type::kValue);
v->SetName(var.name());
v->SetPbDesc(const_cast<void *>(static_cast<const void *>(&var)));
v->SetPbMsg(var.SerializeAsString());
var2id[var.name()] = v->id();
}
// The variables in a SSA can only write once, so if a variable is written
// multiple times(quite common in our ProgramDesc design), multiple alias
// Nodes of this variable will be created, and each will just write once.
// An set that keep all the names of the variables(the original, not alias)
// that have been written(as outputs). Once an Op's output variable hit the
// set, it should create a new alias and update the global alias for this
// variable. And that make a Data Flow Graph a SSA.
std::unordered_set<Node *> unique_written_vars;
for (int i = 0; i < main_block.ops_size(); i++) {
const auto &op = main_block.ops(i);
auto *o = graph->nodes.Create(Node::Type::kFunction);
o->SetName(op.type());
static_cast<Function *>(o)->SetFuncType(op.type());
// Link to the original protobuf message's memory, make it easier to
// generate from a data flow graph to fluid ProgramDesc.
o->SetPbDesc(const_cast<void *>(static_cast<const void *>(&op)));
o->SetPbMsg(op.SerializeAsString());
// set inputs and outputs
for (int j = 0; j < op.inputs_size(); j++) {
auto &in_var = op.inputs(j);
for (int k = 0; k < in_var.arguments_size(); k++) {
auto *in = graph->nodes.GetMutable(var2id.at(in_var.arguments(k)));
in->outlinks.push_back(o);
o->inlinks.push_back(in);
}
}
for (int j = 0; j < op.outputs_size(); j++) {
auto &out_var = op.outputs(j);
for (int k = 0; k < out_var.arguments_size(); k++) {
auto *out = graph->nodes.GetMutable(var2id[out_var.arguments(k)]);
if (unique_written_vars.count(out)) {
// Loop found, for example, a = op(a), use SSA, change to a1 = op(a).
auto *out_alias = graph->nodes.Create(Node::Type::kValue);
out_alias->SetName(out->name());
out_alias->SetPbDesc(out->pb_desc());
out_alias->SetPbMsg(out->pb_msg());
var2id[out_alias->name()] =
out_alias->id(); // update variable's alias Node
LOG(INFO) << "loop found in graph, create SSA alias node ["
<< out_alias->repr() << "] for [" << out->repr() << "]";
out = out_alias;
}
out->inlinks.push_back(o);
o->outlinks.push_back(out);
unique_written_vars.insert(out);
}
}
}
// Analysis and extract the inputs and outputs of this graph.
graph->Build();
} }
namespace { namespace {
...@@ -133,7 +68,7 @@ class DFG_DebuggerPass : public DFG_GraphvizDrawPass { ...@@ -133,7 +68,7 @@ class DFG_DebuggerPass : public DFG_GraphvizDrawPass {
Pass *FluidToDataFlowGraphPass::CreateGraphvizDebugerPass() const { Pass *FluidToDataFlowGraphPass::CreateGraphvizDebugerPass() const {
return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config( return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config(
FLAGS_inference_analysis_graphviz_log_root, "fluid-to-dfg-debuger")); FLAGS_IA_graphviz_log_root, "fluid-to-dfg-debuger"));
} }
} // namespace analysis } // namespace analysis
......
...@@ -30,7 +30,7 @@ TEST(FluidToDataFlowGraphPass, Test) { ...@@ -30,7 +30,7 @@ TEST(FluidToDataFlowGraphPass, Test) {
ASSERT_EQ(argument.main_dfg->nodes.size(), 38UL); ASSERT_EQ(argument.main_dfg->nodes.size(), 38UL);
pass.Finalize(); pass.Finalize();
ASSERT_FALSE(argument.main_dfg->DotString().empty()); ASSERT_FALSE(argument.main_dfg->DotString().empty());
EXPECT_FALSE(argument.main_dfg->inputs.empty()); EXPECT_FALSE(argument.main_dfg->inputs().empty());
} }
} // namespace analysis } // namespace analysis
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include "paddle/fluid/inference/analysis/pass.h"
namespace paddle {
namespace inference {
namespace analysis {
class FluidToIrPass final : public DataFlowGraphPass {
public:
FluidToIrPass() = default;
bool Initialize(Argument *argument) override {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument);
if (argument->origin_program_desc) {
LOG(WARNING) << "argument's origin_program_desc is already set, might "
"duplicate called";
}
// set fluid model program path
if (!argument->fluid_model_program_path) {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_dir);
argument->fluid_model_program_path.reset(
new std::string(*argument->fluid_model_dir + "/__model__"));
}
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_program_path);
// Load program.
auto program = LoadProgramDesc(*argument->fluid_model_program_path);
argument->origin_program_desc.reset(
new framework::proto::ProgramDesc(program));
// Create main data flow graph.
if (!argument->main_dfg) {
argument->main_dfg.reset(new DataFlowGraph);
}
// Persist the ProgramDesc in graph's attribute. The IR graph just keep the
// address, will segfault if the original ProgramDesc destroys.
auto &ir_program_p = argument->main_dfg->Attr("ir_program_desc").Pointer();
ir_program_p = new framework::ProgramDesc(program);
argument_ = argument;
return true;
}
bool Finalize() override { return true; }
void Run(DataFlowGraph *graph) override {
// Call all the IR Passes
IRPassManager ir_passes(*static_cast<framework::ProgramDesc *>(
argument_->main_dfg->Attr("ir_program_desc").Pointer()));
ir_passes.Apply(std::vector<std::string>(
{// Manual update the passes here.
"graph_viz_pass", "infer_clean_graph_pass", "graph_viz_pass",
"fc_fuse_pass", "graph_viz_pass"}));
PADDLE_ENFORCE(argument_->main_dfg.get());
argument_->main_dfg->Build(ir_passes.graph());
// PADDLE_ENFORCE(argument_->main_dfg->IsFullyConnected());
}
std::string repr() const override { return "fluid-to-ir-pass"; }
private:
Argument *argument_{nullptr};
};
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/inference/analysis/ut_helper.h"
namespace paddle {
namespace inference {
namespace analysis {
TEST(FluidToIrPass, Test) {
FluidToIrPass pass;
Argument argument(FLAGS_inference_model_dir);
pass.Initialize(&argument);
pass.Run(argument.main_dfg.get());
}
} // namespace analysis
} // namespace inference
} // namespace paddle
USE_PASS(fc_fuse_pass);
USE_PASS(graph_viz_pass);
USE_PASS(infer_clean_graph_pass);
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <sys/stat.h>
#include <cstdio> #include <cstdio>
#include <fstream> #include <fstream>
#include <string> #include <string>
...@@ -151,6 +152,23 @@ static framework::proto::ProgramDesc LoadProgramDesc( ...@@ -151,6 +152,23 @@ static framework::proto::ProgramDesc LoadProgramDesc(
return program_desc; return program_desc;
} }
static bool FileExists(const std::string &filepath) {
std::ifstream file(filepath);
bool exists = file.is_open();
file.close();
return exists;
}
static bool PathExists(const std::string &path) {
struct stat statbuf;
if (stat(path.c_str(), &statbuf) != -1) {
if (S_ISDIR(statbuf.st_mode)) {
return true;
}
}
return false;
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include <string>
namespace paddle {
namespace inference {
namespace analysis {
IRPassManager::IRPassManager(const ProgramDesc& program) {
graph_.reset(new framework::ir::Graph(program));
}
void IRPassManager::Apply(const std::vector<std::string>& passes) {
graph_->Set("graph_viz_path", new std::string("./1.dot"));
// Apply all the passes
std::string pre_pass;
for (const std::string& pass_name : passes) {
LOG(WARNING) << "Running IR pass [" << pass_name << "]";
auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
if (pass_name == "graph_viz_pass") {
std::string dot_file_path =
"ir_" + (pre_pass.empty() ? "origin" : pre_pass) + ".dot";
pass->Set("graph_viz_path", new std::string(std::move(dot_file_path)));
}
graph_ = pass->Apply(std::move(graph_));
pre_pass = pass_name;
}
}
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* This file defines IRPassManager, it helps control the passes in IR. Inference
* phrase will load the model program and parameters from disk, that is quite
* different from the training phase.
* This manager will control the Passes and make the passes in IR work smoothly
* for inference.
*/
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace inference {
namespace analysis {
using framework::ProgramDesc;
class IRPassManager final {
public:
IRPassManager(const ProgramDesc& program);
void Apply(const std::vector<std::string>& passes);
framework::ir::Graph& graph() const { return *graph_; }
private:
std::unique_ptr<framework::ir::Graph> graph_;
};
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -35,19 +35,21 @@ void ModelStorePass::Run(DataFlowGraph *x) { ...@@ -35,19 +35,21 @@ void ModelStorePass::Run(DataFlowGraph *x) {
std::stringstream ss; std::stringstream ss;
// NOTE these commands only works on linux. // NOTE these commands only works on linux.
ss << "mkdir -p " << *argument_->model_output_store_path; ss << "mkdir -p " << *argument_->model_output_store_path;
LOG(INFO) << "run command: " << ss.str(); VLOG(3) << "run command: " << ss.str();
PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0); PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0);
ss.str(""); ss.str("");
ss << "cp " << *argument_->fluid_model_dir << "/*" ss << "cp " << *argument_->fluid_model_dir << "/*"
<< " " << *argument_->model_output_store_path; << " " << *argument_->model_output_store_path;
LOG(INFO) << "run command: " << ss.str(); VLOG(3) << "run command: " << ss.str();
PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0); PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0);
// Store program // Store program
PADDLE_ENFORCE_NOT_NULL(argument_->transformed_program_desc, PADDLE_ENFORCE_NOT_NULL(argument_->transformed_program_desc,
"program desc is not transformed, should call " "program desc is not transformed, should call "
"DataFlowGraphToFluidPass first."); "DataFlowGraphToFluidPass first.");
VLOG(3) << "store analyzed program to "
<< *argument_->model_output_store_path;
const std::string program_output_path = const std::string program_output_path =
*argument_->model_output_store_path + "/__model__"; *argument_->model_output_store_path + "/__model__";
std::ofstream file(program_output_path, std::ios::binary); std::ofstream file(program_output_path, std::ios::binary);
...@@ -58,6 +60,8 @@ void ModelStorePass::Run(DataFlowGraph *x) { ...@@ -58,6 +60,8 @@ void ModelStorePass::Run(DataFlowGraph *x) {
file.write(serialized_message.c_str(), serialized_message.size()); file.write(serialized_message.c_str(), serialized_message.size());
} }
bool ModelStorePass::Finalize() { return true; }
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -44,6 +44,8 @@ class ModelStorePass : public DataFlowGraphPass { ...@@ -44,6 +44,8 @@ class ModelStorePass : public DataFlowGraphPass {
model in the disk, and that model can be reloaded for prediction again.)DD"; model in the disk, and that model can be reloaded for prediction again.)DD";
} }
bool Finalize() override;
private: private:
Argument* argument_{nullptr}; Argument* argument_{nullptr};
}; };
......
...@@ -30,7 +30,7 @@ TEST(DFG_StorePass, test) { ...@@ -30,7 +30,7 @@ TEST(DFG_StorePass, test) {
argument.model_output_store_path.reset( argument.model_output_store_path.reset(
new std::string("./_dfg_store_pass_tmp")); new std::string("./_dfg_store_pass_tmp"));
// disable storage in alalyzer // disable storage in alalyzer
FLAGS_inference_analysis_output_storage_path = ""; FLAGS_IA_output_storage_path = "";
analyzer.Run(&argument); analyzer.Run(&argument);
ModelStorePass pass; ModelStorePass pass;
......
...@@ -38,7 +38,7 @@ namespace analysis { ...@@ -38,7 +38,7 @@ namespace analysis {
class NodeMap; class NodeMap;
// A helper class to maintain the status from Pass. // A helper class to maintain the status from Pass.
struct NodeAttr { struct AnyAttr {
using any_t = using any_t =
boost::variant<bool, float, int32_t, int64_t, void *, std::string>; boost::variant<bool, float, int32_t, int64_t, void *, std::string>;
// NOTE T should be a primary type or a struct combined by several primary // NOTE T should be a primary type or a struct combined by several primary
...@@ -54,10 +54,9 @@ struct NodeAttr { ...@@ -54,10 +54,9 @@ struct NodeAttr {
void *&Pointer() { return As<void *>(); } void *&Pointer() { return As<void *>(); }
std::string &String() { return As<std::string>(); } std::string &String() { return As<std::string>(); }
private:
template <typename T> template <typename T>
T &As() { T &As() {
if (type_index_ == typeid(NodeAttr)) { if (type_index_ == typeid(AnyAttr)) {
type_index_ = typeid(T); type_index_ = typeid(T);
any_data_ = T(); any_data_ = T();
} else { } else {
...@@ -68,7 +67,7 @@ struct NodeAttr { ...@@ -68,7 +67,7 @@ struct NodeAttr {
private: private:
any_t any_data_; any_t any_data_;
std::type_index type_index_{typeid(NodeAttr)}; std::type_index type_index_{typeid(AnyAttr)};
}; };
/* /*
...@@ -105,7 +104,7 @@ class Node { ...@@ -105,7 +104,7 @@ class Node {
// Get an additional attribute and convert it to T data type. NOTE this will // Get an additional attribute and convert it to T data type. NOTE this will
// silently create a new attribute if not exists. // silently create a new attribute if not exists.
NodeAttr &attr(const std::string &name) const { return attrs_[name]; } AnyAttr &attr(const std::string &name) const { return attrs_[name]; }
int id() const { return id_; } int id() const { return id_; }
...@@ -150,7 +149,7 @@ class Node { ...@@ -150,7 +149,7 @@ class Node {
Type type_{Type::kNone}; Type type_{Type::kNone};
// Mark this node is deleted by some pass. // Mark this node is deleted by some pass.
bool deleted_{false}; bool deleted_{false};
mutable std::unordered_map<std::string, NodeAttr> attrs_; mutable std::unordered_map<std::string, AnyAttr> attrs_;
}; };
class Function; class Function;
......
...@@ -21,19 +21,19 @@ namespace inference { ...@@ -21,19 +21,19 @@ namespace inference {
namespace analysis { namespace analysis {
TEST(NodeAttr, bool) { TEST(NodeAttr, bool) {
NodeAttr x; AnyAttr x;
x.Bool() = true; x.Bool() = true;
ASSERT_EQ(x.Bool(), true); ASSERT_EQ(x.Bool(), true);
} }
TEST(NodeAttr, int32) { TEST(NodeAttr, int32) {
NodeAttr x; AnyAttr x;
x.Int32() = 32; x.Int32() = 32;
ASSERT_EQ(x.Int32(), 32); ASSERT_EQ(x.Int32(), 32);
} }
TEST(NodeAttr, string) { TEST(NodeAttr, string) {
NodeAttr x; AnyAttr x;
x.String() = "Hello"; x.String() = "Hello";
ASSERT_EQ(x.String(), "Hello"); ASSERT_EQ(x.String(), "Hello");
} }
......
...@@ -63,7 +63,7 @@ class Pass { ...@@ -63,7 +63,7 @@ class Pass {
// Human-readable short representation. // Human-readable short representation.
virtual std::string repr() const = 0; virtual std::string repr() const = 0;
// Human-readable long description. // Human-readable long description.
virtual std::string description() const = 0; virtual std::string description() const { return "No DOC"; }
}; };
// NodePass process on any Node types. // NodePass process on any Node types.
......
...@@ -22,7 +22,7 @@ namespace analysis { ...@@ -22,7 +22,7 @@ namespace analysis {
bool PassManager::Initialize(Argument* argument) { bool PassManager::Initialize(Argument* argument) {
argument_ = argument; argument_ = argument;
for (auto& pass : data_) { for (auto& pass : data_) {
LOG(INFO) << "Initializing pass " << pass->repr(); LOG(WARNING) << "Initializing pass [" << pass->repr() << "]";
if (!pass->Initialize(argument)) { if (!pass->Initialize(argument)) {
LOG(ERROR) << "Failed to initialize pass [" << pass->repr() << "]"; LOG(ERROR) << "Failed to initialize pass [" << pass->repr() << "]";
return false; return false;
...@@ -33,8 +33,9 @@ bool PassManager::Initialize(Argument* argument) { ...@@ -33,8 +33,9 @@ bool PassManager::Initialize(Argument* argument) {
void DfgPassManager::RunAll() { void DfgPassManager::RunAll() {
PADDLE_ENFORCE(argument_); PADDLE_ENFORCE(argument_);
LOG(INFO) << "Total " << data_.size() << " passes";
for (auto& pass : data_) { for (auto& pass : data_) {
VLOG(4) << "Running pass [" << pass->repr() << "]"; LOG(WARNING) << "Running pass [" << pass->repr() << "]";
pass->Run(argument_->main_dfg.get()); pass->Run(argument_->main_dfg.get());
} }
} }
...@@ -42,8 +43,7 @@ void DfgPassManager::RunAll() { ...@@ -42,8 +43,7 @@ void DfgPassManager::RunAll() {
void NodePassManager::RunAll() { void NodePassManager::RunAll() {
PADDLE_ENFORCE(argument_); PADDLE_ENFORCE(argument_);
PADDLE_ENFORCE(argument_->main_dfg.get()); PADDLE_ENFORCE(argument_->main_dfg.get());
auto trait = auto trait = GraphTraits<DataFlowGraph>(*argument_->main_dfg).nodes_in_DFS();
GraphTraits<DataFlowGraph>(argument_->main_dfg.get()).nodes_in_DFS();
for (auto& node : trait) { for (auto& node : trait) {
for (auto& pass : data_) { for (auto& pass : data_) {
pass->Run(&node); pass->Run(&node);
......
...@@ -34,7 +34,7 @@ inline void MarkOutLinksInSubGraph(const Function *func) { ...@@ -34,7 +34,7 @@ inline void MarkOutLinksInSubGraph(const Function *func) {
} }
void SubGraphSplitter::MarkNodesInsideSubGraph() { void SubGraphSplitter::MarkNodesInsideSubGraph() {
for (auto &node : GraphTraits<DataFlowGraph>(graph_).nodes()) { for (auto &node : GraphTraits<DataFlowGraph>(*graph_).nodes()) {
if (node_inside_subgraph_teller_(&node)) { if (node_inside_subgraph_teller_(&node)) {
node.attr(kMarkerAttrName).Bool() = true; node.attr(kMarkerAttrName).Bool() = true;
if (node.type() == Node::Type::kFunction) { if (node.type() == Node::Type::kFunction) {
...@@ -76,7 +76,7 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) { ...@@ -76,7 +76,7 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) {
std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() { std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
std::vector<Node *> marked_nodes; std::vector<Node *> marked_nodes;
for (auto &node : GraphTraits<DataFlowGraph>(graph_).nodes_in_TS()) { for (auto &node : GraphTraits<DataFlowGraph>(*graph_).nodes_in_TS()) {
if (node.attr(kMarkerAttrName).Bool()) { if (node.attr(kMarkerAttrName).Bool()) {
marked_nodes.push_back(&node); marked_nodes.push_back(&node);
} }
......
...@@ -69,8 +69,8 @@ class DfgDebuggerPass : public DFG_GraphvizDrawPass { ...@@ -69,8 +69,8 @@ class DfgDebuggerPass : public DFG_GraphvizDrawPass {
}; };
Pass *TensorRTSubgraphNodeMarkPass::CreateGraphvizDebugerPass() const { Pass *TensorRTSubgraphNodeMarkPass::CreateGraphvizDebugerPass() const {
DFG_GraphvizDrawPass::Config config( DFG_GraphvizDrawPass::Config config(FLAGS_IA_graphviz_log_root,
FLAGS_inference_analysis_graphviz_log_root, "tensorrt_marked_node"); "tensorrt_marked_node");
return new DfgDebuggerPass(config); return new DfgDebuggerPass(config);
} }
bool TensorRTSubgraphNodeMarkPass::Finalize() { return true; } bool TensorRTSubgraphNodeMarkPass::Finalize() { return true; }
......
...@@ -23,7 +23,7 @@ namespace paddle { ...@@ -23,7 +23,7 @@ namespace paddle {
DEFINE_string(dirname, "", "Directory of the inference model."); DEFINE_string(dirname, "", "Directory of the inference model.");
void CompareTensorRTWithFluid(bool enable_tensorrt) { void CompareTensorRTWithFluid(bool enable_tensorrt) {
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = enable_tensorrt; FLAGS_IA_enable_tensorrt_subgraph_engine = enable_tensorrt;
//# 1. Create PaddlePredictor with a config. //# 1. Create PaddlePredictor with a config.
NativeConfig config0; NativeConfig config0;
......
...@@ -35,9 +35,14 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -35,9 +35,14 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
if (ctx->HasInput("Bias")) { if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias"); auto bias_dims = ctx->GetInputDim("Bias");
if (bias_dims.size() == 2) {
PADDLE_ENFORCE_EQ(bias_dims[0], 1, "The shape of Bias must be [1, dim]."); PADDLE_ENFORCE_EQ(bias_dims[0], 1, "The shape of Bias must be [1, dim].");
PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1], PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1],
"The shape of Bias must be [1, dim]."); "The shape of Bias must be [1, dim].");
} else if (bias_dims.size() == 1) {
PADDLE_ENFORCE_EQ(bias_dims[0], w_dims[1],
"The shape of Bias must be [1, dim].");
}
} }
PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4, PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
"Fully Connected input should be 2-D or 4-D tensor."); "Fully Connected input should be 2-D or 4-D tensor.");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册