“fa2e9907823c0e4b5b11de0bab9fd484c86526a3”上不存在“develop/doc/api/v1/trainer_config_helpers/activations.html”
提交 efe88ab9 编写于 作者: M minqiyang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into disable_prelu_test_local

...@@ -102,7 +102,6 @@ set(COMMON_FLAGS ...@@ -102,7 +102,6 @@ set(COMMON_FLAGS
-fno-omit-frame-pointer -fno-omit-frame-pointer
-Wall -Wall
-Wextra -Wextra
-Werror
-Wnon-virtual-dtor -Wnon-virtual-dtor
-Wdelete-non-virtual-dtor -Wdelete-non-virtual-dtor
-Wno-unused-parameter -Wno-unused-parameter
...@@ -115,6 +114,11 @@ set(COMMON_FLAGS ...@@ -115,6 +114,11 @@ set(COMMON_FLAGS
-Wno-error=terminate # Warning in PADDLE_ENFORCE -Wno-error=terminate # Warning in PADDLE_ENFORCE
) )
# https://github.com/PaddlePaddle/Paddle/issues/12773
if (NOT WIN32)
list(APPEND COMMON_FLAGS -Werror)
endif()
set(GPU_COMMON_FLAGS set(GPU_COMMON_FLAGS
-fPIC -fPIC
-fno-omit-frame-pointer -fno-omit-frame-pointer
......
...@@ -28,7 +28,7 @@ def get_symbol(num_classes=10, **kwargs): ...@@ -28,7 +28,7 @@ def get_symbol(num_classes=10, **kwargs):
Varible here is actually a Symbol. Every basic Symbol will correspond to one Node, and every Node has its own NodeAttr. There is a op field in NodeAttr class, when a Symbol represents Variable(often input data), the op field is null. Varible here is actually a Symbol. Every basic Symbol will correspond to one Node, and every Node has its own AnyAttr. There is a op field in AnyAttr class, when a Symbol represents Variable(often input data), the op field is null.
Symbol contains a data member, std::vector<NodeEntry> outputs, and NodeEntry cantains a poniter to Node. We can follow the Node pointer to get all the Graph. Symbol contains a data member, std::vector<NodeEntry> outputs, and NodeEntry cantains a poniter to Node. We can follow the Node pointer to get all the Graph.
......
...@@ -78,7 +78,7 @@ paddle.fluid.io.load_vars ArgSpec(args=['executor', 'dirname', 'main_program', ' ...@@ -78,7 +78,7 @@ paddle.fluid.io.load_vars ArgSpec(args=['executor', 'dirname', 'main_program', '
paddle.fluid.io.load_params ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.io.load_params ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.io.load_persistables ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.io.load_persistables ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.io.save_inference_model ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename', 'export_for_deployment'], varargs=None, keywords=None, defaults=(None, None, None, True)) paddle.fluid.io.save_inference_model ArgSpec(args=['dirname', 'feeded_var_names', 'target_vars', 'executor', 'main_program', 'model_filename', 'params_filename', 'export_for_deployment'], varargs=None, keywords=None, defaults=(None, None, None, True))
paddle.fluid.io.load_inference_model ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.io.load_inference_model ArgSpec(args=['dirname', 'executor', 'model_filename', 'params_filename', 'pserver_endpoints'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.io.get_inference_program ArgSpec(args=['target_vars', 'main_program'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.io.get_inference_program ArgSpec(args=['target_vars', 'main_program'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.initializer.ConstantInitializer.__init__ ArgSpec(args=['self', 'value', 'force_cpu'], varargs=None, keywords=None, defaults=(0.0, False)) paddle.fluid.initializer.ConstantInitializer.__init__ ArgSpec(args=['self', 'value', 'force_cpu'], varargs=None, keywords=None, defaults=(0.0, False))
paddle.fluid.initializer.UniformInitializer.__init__ ArgSpec(args=['self', 'low', 'high', 'seed'], varargs=None, keywords=None, defaults=(-1.0, 1.0, 0)) paddle.fluid.initializer.UniformInitializer.__init__ ArgSpec(args=['self', 'low', 'high', 'seed'], varargs=None, keywords=None, defaults=(-1.0, 1.0, 0))
......
...@@ -5,8 +5,12 @@ cc_library(pass SRCS pass.cc DEPS graph node graph_helper) ...@@ -5,8 +5,12 @@ cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper) cc_library(graph_viz_pass SRCS graph_viz_pass.cc DEPS graph pass graph_helper)
cc_library(graph_traits SRCS graph_traits.cc DEPS graph) cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits) cc_library(graph_pattern_detecter SRCS graph_pattern_detecter.cc DEPS graph graph_helper graph_traits)
cc_library(fc_fuse_pass SRCS fc_fuse_pass.cc DEPS graph graph_pattern_detecter)
cc_library(infer_clean_graph_pass SRCS infer_clean_graph_pass.cc DEPS graph pass)
cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper)
cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry)
cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry)
cc_test(test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter) cc_test(test_graph_pattern_detecter SRCS graph_pattern_detecter_tester.cc DEPS graph_pattern_detecter)
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass graph_pattern_detecter graph pass graph_traits framework_proto)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
bool VarOutLinksToOp(Node* node, const std::string& op_type) {
for (auto* out : node->outputs) {
if (out->IsOp() && out->Op()->Type() == op_type) {
return true;
}
}
return false;
}
void BuildFCPattern(PDPattern* pattern) {
// make sure the selected MUL op has one input argument is a parameter.
auto* mul_parameter_var = pattern->NewNode(
[](Node* node) {
return node->IsVar() && node->outputs.size() == 1UL &&
node->outputs.front()->Op()->Type() == "mul" && node->Var() &&
node->Var()->Persistable(); // check is a parameter
},
"mul_weight" /*name*/);
auto* mul_tmp_input_var = pattern->NewNode(
[](Node* node) {
bool result =
node->IsVar() && node->outputs.size() >= 1UL && node->Var() &&
!node->Var()->Persistable(); // this input is not an parameter.
if (!result) return false;
// check whether one output is MUL op.
for (auto* op : node->outputs) {
if (op->IsOp() && op->Op()->Type() == "mul") return true;
}
return false;
},
"mul_tmp_var" /*name*/);
// select a MUL op
auto* mul_op = pattern->NewNode(
[](Node* node) {
return node->IsOp() && // start from an Op
node->Op()->Type() == "mul"; // type is mul
// the output should be consumed only by one element_add, that check
// leaves in a Var PDNode.
},
"mul" /*name*/);
// make sure the MUL op's output has only one consumer and links to an
// ELEMENTWISE_ADD op.
auto* mul_out_var = pattern->NewNode(
[](Node* node) {
return node->IsVar() && // starts from a Var
node->outputs.size() == 1UL && // only has one consumer
node->outputs.front()->IsOp() && // check basic logic
node->Var() && // not a ControlDepVar
node->outputs.front()->Op()->Type() ==
"elementwise_add"; // a very strong validation
},
"mul_out");
// this check is not essential, just to make the corresponding variable Node
// retrival easier.
auto* elementwise_add_tmp_var = pattern->NewNode(
[](Node* node) {
return node->IsVar() && node->outputs.size() >= 1UL && node->Var() &&
VarOutLinksToOp(node, "elementwise_add");
},
"elementwise_add_tmpvar");
// select an ELEMENTWISE_ADD op
auto* elementwise_add_op = pattern->NewNode(
[](Node* node) {
return node->IsOp() && node->Op()->Type() == "elementwise_add";
},
"elementwise_add" /*name*/);
// get the ELEMENTWISE_ADD op's output
auto* elementwise_add_out_var = pattern->NewNode(
[](Node* node) {
return node->IsVar() && node->inputs.size() == 1UL && node->Var() &&
node->inputs.front()->Op()->Type() == "elementwise_add";
},
"elementwise_add_out");
pattern->AddEdge(mul_parameter_var, mul_op);
pattern->AddEdge(mul_tmp_input_var, mul_op);
pattern->AddEdge(mul_op, mul_out_var);
pattern->AddEdge(mul_out_var, elementwise_add_op);
pattern->AddEdge(elementwise_add_tmp_var, elementwise_add_op);
pattern->AddEdge(elementwise_add_op, elementwise_add_out_var);
}
// Replace the node `from` in the links to `to`
bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
for (auto*& n : *links) {
if (n == from) {
n = to;
return true;
}
}
return false;
}
std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
std::unordered_set<Node*> nodes2delete;
GraphPatternDetecter gpd;
BuildFCPattern(gpd.mutable_pattern());
#define GET_NODE(id) \
PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetriveNode(#id)), \
"pattern has no Node called %s", #id); \
auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
auto handler = [&](const GraphPatternDetecter::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle FC fuse";
// Currently, there is no FC op available, so I will just simulate the
// scenerio.
// FC's fusion is simple, just op fuse, no need to process the
// parameters.
GET_NODE(mul_tmp_var); // x
GET_NODE(mul_weight); // Y
GET_NODE(elementwise_add_tmpvar); // bias
GET_NODE(elementwise_add_out); // Out
GET_NODE(mul); // MUL op
GET_NODE(elementwise_add); // ELEMENT_ADD op
GET_NODE(mul_out); // tmp
#undef GET_NODE
// Create an FC Node.
OpDesc desc;
std::string fc_x_in = mul_tmp_var->Name();
std::string fc_Y_in = mul_weight->Name();
std::string fc_bias_in = elementwise_add_tmpvar->Name();
std::string fc_out = elementwise_add_out->Name();
desc.SetInput("Input", std::vector<std::string>({fc_x_in}));
desc.SetInput("W", std::vector<std::string>({fc_Y_in}));
desc.SetInput("Bias", std::vector<std::string>({fc_bias_in}));
desc.SetOutput("Out", std::vector<std::string>({fc_out}));
desc.SetType("fc");
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
fc_node->inputs =
std::vector<Node*>({mul_tmp_var, mul_weight, elementwise_add_tmpvar});
fc_node->outputs.push_back(elementwise_add_out);
// Update link relatons
PADDLE_ENFORCE(LinksReplace(&mul_tmp_var->outputs, mul, fc_node));
PADDLE_ENFORCE(LinksReplace(&mul_weight->outputs, mul, fc_node));
PADDLE_ENFORCE(LinksReplace(&elementwise_add_tmpvar->outputs,
elementwise_add, fc_node));
PADDLE_ENFORCE(
LinksReplace(&elementwise_add_out->inputs, elementwise_add, fc_node));
// Drop old nodes
graph->RemoveNode(mul);
graph->RemoveNode(elementwise_add);
graph->RemoveNode(mul_out); // tmp variable
};
gpd(graph.get(), handler);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fc_fuse_pass, paddle::framework::ir::FCFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detecter.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the MUL and ELEMENTWISE_ADD to a FCOp.
*/
class FCFusePass : public Pass {
public:
virtual ~FCFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <gtest/gtest.h>
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc* prog, const std::string& type,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetInput("Xs", inputs);
op->SetOutput("Ys", outputs);
}
// a->OP0->b
// a->OP1->c
// (b, c)->mul->d
// (d, e)->elementwise_add->f
ProgramDesc BuildProgramDesc() {
ProgramDesc prog;
for (auto& v : std::vector<std::string>({"a", "b", "c", "d", "e", "f"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS);
if (v == "c") {
var->SetPersistable(true);
}
}
SetOp(&prog, "OP0", std::vector<std::string>({"a"}),
std::vector<std::string>({"b"}));
SetOp(&prog, "OP1", std::vector<std::string>({"a"}),
std::vector<std::string>({"c"}));
SetOp(&prog, "mul", std::vector<std::string>({"b", "c"}),
std::vector<std::string>({"d"}));
SetOp(&prog, "elementwise_add", std::vector<std::string>({"d", "e"}),
std::vector<std::string>({"f"}));
return prog;
}
TEST(FCFusePass, basic) {
auto prog = BuildProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
auto pass = PassRegistry::Instance().Get("fc_fuse_pass");
int pre_nodes = graph->Nodes().size();
graph = pass->Apply(std::move(graph));
int after_nodes = graph->Nodes().size();
// Remove 3 Nodes: MUL,ELEMENTWISE_ADD, mul_out
// Add 1 Node: FC
EXPECT_EQ(pre_nodes - 2, after_nodes);
// Assert fc op in newly generated graph
int fc_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "fc") {
++fc_count;
}
}
EXPECT_EQ(fc_count, 1);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(fc_fuse_pass);
...@@ -98,11 +98,13 @@ class Graph { ...@@ -98,11 +98,13 @@ class Graph {
// Create a normal variable with non-null VarDesc. // Create a normal variable with non-null VarDesc.
ir::Node *CreateVarNode(VarDesc *var_desc) { ir::Node *CreateVarNode(VarDesc *var_desc) {
PADDLE_ENFORCE(var_desc);
return AddNode(new ir::Node(var_desc)); return AddNode(new ir::Node(var_desc));
} }
// Create a normal runnable operator with OpDesc. // Create a normal runnable operator with OpDesc.
ir::Node *CreateOpNode(OpDesc *op_desc) { ir::Node *CreateOpNode(OpDesc *op_desc) {
PADDLE_ENFORCE(op_desc);
return AddNode(new ir::Node(op_desc)); return AddNode(new ir::Node(op_desc));
} }
...@@ -134,6 +136,14 @@ class Graph { ...@@ -134,6 +136,14 @@ class Graph {
return ret; return ret;
} }
void RemoveNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
node_set_.erase(node);
nodes_.erase(node);
}
const ProgramDesc &program() const { return program_; }
private: private:
// This method takes ownership of `node`. // This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) { ir::Node *AddNode(ir::Node *node) {
...@@ -143,12 +153,6 @@ class Graph { ...@@ -143,12 +153,6 @@ class Graph {
return node; return node;
} }
void RemoveNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) != node_set_.end());
node_set_.erase(node);
nodes_.erase(node);
}
// NOTE: program_ shouldn't be exposed to user. // NOTE: program_ shouldn't be exposed to user.
const ProgramDesc &program_; const ProgramDesc &program_;
std::map<std::string, boost::any> attrs_; std::map<std::string, boost::any> attrs_;
......
...@@ -25,12 +25,30 @@ namespace paddle { ...@@ -25,12 +25,30 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
size_t PDPattern::id_ = 0UL;
PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) { PDNode* PDPattern::NewNode(PDNode::teller_t&& teller, const std::string& name) {
if (!name.empty()) {
PADDLE_ENFORCE_EQ(node_map_.count(name), 0,
"PDNode's name should be unique, get duplicate [%s]",
name);
}
nodes_.emplace_back(new PDNode(std::move(teller), name)); nodes_.emplace_back(new PDNode(std::move(teller), name));
auto* cur = nodes_.back().get(); auto* cur = nodes_.back().get();
node_map_[name] = cur;
return cur; return cur;
} }
PDNode* PDPattern::RetriveNode(const std::string& id) const {
auto it = node_map_.find(id);
if (it == node_map_.end()) {
return nullptr;
}
return it->second;
}
void PDPattern::AddEdge(PDNode* a, PDNode* b) { void PDPattern::AddEdge(PDNode* a, PDNode* b) {
PADDLE_ENFORCE(a); PADDLE_ENFORCE(a);
PADDLE_ENFORCE(b); PADDLE_ENFORCE(b);
...@@ -51,15 +69,18 @@ void GraphPatternDetecter::operator()(Graph* graph, ...@@ -51,15 +69,18 @@ void GraphPatternDetecter::operator()(Graph* graph,
} }
bool GraphPatternDetecter::MarkPDNodesInGraph(const ir::Graph& graph) { bool GraphPatternDetecter::MarkPDNodesInGraph(const ir::Graph& graph) {
VLOG(4) << "mark pdnodes in graph";
if (graph.Nodes().empty()) return false; if (graph.Nodes().empty()) return false;
for (auto& node : GraphTraits::DFS(graph)) { for (auto& node : GraphTraits::DFS(graph)) {
for (const auto& pdnode : pattern_.nodes()) { for (const auto& pdnode : pattern_.nodes()) {
if (pdnode->Tell(&node)) { if (pdnode->Tell(&node)) {
VLOG(4) << "pdnode " << pdnode->name() << " marked";
pdnodes2nodes_[pdnode.get()].insert(&node); pdnodes2nodes_[pdnode.get()].insert(&node);
} }
} }
} }
VLOG(3) << pdnodes2nodes_.size() << " nodes marked";
return !pdnodes2nodes_.empty(); return !pdnodes2nodes_.empty();
} }
...@@ -67,10 +88,20 @@ struct HitGroup { ...@@ -67,10 +88,20 @@ struct HitGroup {
std::unordered_map<PDNode*, Node*> roles; std::unordered_map<PDNode*, Node*> roles;
bool Match(Node* node, PDNode* pat) { bool Match(Node* node, PDNode* pat) {
if (nodes_.count(node)) {
if (!roles.count(pat)) return false;
return roles[pat] == node;
}
return !roles.count(pat) || roles.at(pat) == node; return !roles.count(pat) || roles.at(pat) == node;
} }
void Register(Node* node, PDNode* pat) { roles[pat] = node; } void Register(Node* node, PDNode* pat) {
roles[pat] = node;
nodes_.insert(node);
}
private:
std::unordered_set<Node*> nodes_;
}; };
// Tell whether Node a links to b. // Tell whether Node a links to b.
...@@ -104,6 +135,7 @@ GraphPatternDetecter::DetectPatterns() { ...@@ -104,6 +135,7 @@ GraphPatternDetecter::DetectPatterns() {
// Extend a PDNode to subgraphs by deducing the connection relations defined // Extend a PDNode to subgraphs by deducing the connection relations defined
// in edges of PDNodes. // in edges of PDNodes.
for (const auto& edge : pattern_.edges()) { for (const auto& edge : pattern_.edges()) {
VLOG(4) << "check " << edge.first->name() << " -> " << edge.second->name();
// Each role has two PDNodes, which indicates two roles. // Each role has two PDNodes, which indicates two roles.
// Detect two Nodes that can match these two roles and they are connected. // Detect two Nodes that can match these two roles and they are connected.
auto& pre_groups = bi_records[step % 2]; auto& pre_groups = bi_records[step % 2];
...@@ -127,6 +159,7 @@ GraphPatternDetecter::DetectPatterns() { ...@@ -127,6 +159,7 @@ GraphPatternDetecter::DetectPatterns() {
} }
} }
} }
VLOG(3) << "step " << step << " get records: " << cur_groups.size();
} }
for (auto& group : bi_records[step % 2]) { for (auto& group : bi_records[step % 2]) {
......
...@@ -96,7 +96,8 @@ class PDPattern { ...@@ -96,7 +96,8 @@ class PDPattern {
void AddEdge(PDNode* a, PDNode* b); void AddEdge(PDNode* a, PDNode* b);
PDNode* NewNode(PDNode::teller_t&& teller, const std::string& name = ""); PDNode* NewNode(PDNode::teller_t&& teller, const std::string& name = NewID());
PDNode* RetriveNode(const std::string& id) const;
const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; } const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; }
const std::vector<edge_t>& edges() const { return edges_; } const std::vector<edge_t>& edges() const { return edges_; }
...@@ -107,8 +108,12 @@ class PDPattern { ...@@ -107,8 +108,12 @@ class PDPattern {
FRIEND_TEST(PDPattern, NewNode); FRIEND_TEST(PDPattern, NewNode);
#endif #endif
static std::string NewID() { return "pdnode-" + std::to_string(id_++); }
std::vector<std::unique_ptr<PDNode>> nodes_; std::vector<std::unique_ptr<PDNode>> nodes_;
std::vector<edge_t> edges_; std::vector<edge_t> edges_;
std::unordered_map<std::string, PDNode*> node_map_;
static size_t id_;
}; };
/* /*
......
...@@ -25,6 +25,7 @@ static const char kGraphVizPath[] = "graph_viz_path"; ...@@ -25,6 +25,7 @@ static const char kGraphVizPath[] = "graph_viz_path";
std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl( std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
const std::string graph_viz_path = Get<std::string>(kGraphVizPath); const std::string graph_viz_path = Get<std::string>(kGraphVizPath);
VLOG(3) << "draw IR graph viz to " << graph_viz_path;
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path)); std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
PADDLE_ENFORCE(fout->good()); PADDLE_ENFORCE(fout->good());
std::ostream& sout = *fout; std::ostream& sout = *fout;
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class InferCleanGraphPass : public Pass {
public:
virtual ~InferCleanGraphPass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get());
auto is_valid_node = [](Node* x) {
return x && IsControlDepVar(*x) && x->IsVar() && !x->Var();
};
std::unordered_set<Node*> invalid_nodes;
for (auto* node : graph->Nodes()) {
if (is_valid_node(node)) {
invalid_nodes.insert(node);
}
}
// remove nodes from the graph.
for (auto* node : invalid_nodes) {
graph->RemoveNode(node);
}
// clean edges.
for (auto* node : graph->Nodes()) {
CleanEdges(&node->inputs, invalid_nodes);
CleanEdges(&node->outputs, invalid_nodes);
}
return graph;
}
void CleanEdges(std::vector<Node*>* nodes,
const std::unordered_set<Node*>& to_remove) const {
auto it = std::remove_if(nodes->begin(), nodes->end(),
[&](Node* x) { return to_remove.count(x); });
nodes->erase(it, nodes->end());
}
};
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(infer_clean_graph_pass,
paddle::framework::ir::InferCleanGraphPass);
...@@ -34,14 +34,15 @@ class Node { ...@@ -34,14 +34,15 @@ class Node {
explicit Node(VarDesc* var_desc) explicit Node(VarDesc* var_desc)
: name_(var_desc->Name()), : name_(var_desc->Name()),
var_desc_(var_desc), var_desc_(new VarDesc(*var_desc)),
op_desc_(nullptr), op_desc_(nullptr),
type_(Type::kVariable) {} type_(Type::kVariable) {}
explicit Node(OpDesc* op_desc) explicit Node(OpDesc* op_desc)
: name_(op_desc->Type()), : name_(op_desc->Type()),
var_desc_(nullptr), var_desc_(nullptr),
op_desc_(op_desc), op_desc_(new OpDesc(*op_desc)), // TODO(panyx0718) the pointer in the
// original OpDesc might go out.
type_(Type::kOperation) {} type_(Type::kOperation) {}
Type NodeType() const { return type_; } Type NodeType() const { return type_; }
...@@ -50,12 +51,12 @@ class Node { ...@@ -50,12 +51,12 @@ class Node {
VarDesc* Var() { VarDesc* Var() {
PADDLE_ENFORCE(type_ == Type::kVariable); PADDLE_ENFORCE(type_ == Type::kVariable);
return var_desc_; return var_desc_.get();
} }
OpDesc* Op() { OpDesc* Op() {
PADDLE_ENFORCE(type_ == Type::kOperation); PADDLE_ENFORCE(IsOp());
return op_desc_; return op_desc_.get();
} }
bool IsOp() const { return type_ == Type::kOperation; } bool IsOp() const { return type_ == Type::kOperation; }
...@@ -66,8 +67,8 @@ class Node { ...@@ -66,8 +67,8 @@ class Node {
protected: protected:
const std::string name_; const std::string name_;
VarDesc* var_desc_; std::unique_ptr<VarDesc> var_desc_;
OpDesc* op_desc_; std::unique_ptr<OpDesc> op_desc_;
Type type_; Type type_;
private: private:
......
cc_library(ir_pass_manager SRCS ir_pass_manager.cc DEPS graph pass)
cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc
analyzer.cc
helper.cc
# passes
fluid_to_data_flow_graph_pass.cc fluid_to_data_flow_graph_pass.cc
data_flow_graph_to_fluid_pass.cc data_flow_graph_to_fluid_pass.cc
dfg_graphviz_draw_pass.cc dfg_graphviz_draw_pass.cc
tensorrt_subgraph_pass.cc tensorrt_subgraph_pass.cc
tensorrt_subgraph_node_mark_pass.cc tensorrt_subgraph_node_mark_pass.cc
analyzer.cc fluid_to_ir_pass.cc
helper.cc model_store_pass.cc
model_store_pass.cc DEPS framework_proto proto_desc ir_pass_manager graph pass)
DEPS framework_proto proto_desc)
cc_test(test_node SRCS node_tester.cc DEPS analysis gflags glog gtest) cc_test(test_node SRCS node_tester.cc DEPS analysis)
cc_test(test_dot SRCS dot_tester.cc DEPS analysis) cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis) cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis)
...@@ -27,19 +31,30 @@ function (inference_analysis_test TARGET) ...@@ -27,19 +31,30 @@ function (inference_analysis_test TARGET)
endif() endif()
cc_test(${TARGET} cc_test(${TARGET}
SRCS "${analysis_test_SRCS}" SRCS "${analysis_test_SRCS}"
DEPS analysis DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detecter pass
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt}) ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt})
set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec) set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec)
endif(WITH_TESTING) endif(WITH_TESTING)
endfunction(inference_analysis_test) endfunction(inference_analysis_test)
cc_test(test_analyzer SRCS analyzer_tester.cc DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis
# ir
fc_fuse_pass
graph_viz_pass
infer_clean_graph_pass
graph_pattern_detecter
pass
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model)
#set_tests_properties(test_analyzer PROPERTIES DEPENDS test_word2vec)
#inference_api_test(test_analyzer SRC analyzer_tester.cc ARGS test_word2vec)
inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc) inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc) inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc)
inference_analysis_test(test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc)
inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc) inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc)
inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc) inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc)
inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc) inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc)
inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc) inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc)
inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc) inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc)
inference_analysis_test(test_tensorrt_subgraph_node_mark_pass SRCS tensorrt_subgraph_node_mark_pass_tester.cc) inference_analysis_test(test_tensorrt_subgraph_node_mark_pass SRCS tensorrt_subgraph_node_mark_pass_tester.cc)
inference_analysis_test(test_analyzer SRCS analyzer_tester.cc)
inference_analysis_test(test_model_store_pass SRCS model_store_pass_tester.cc) inference_analysis_test(test_model_store_pass SRCS model_store_pass_tester.cc)
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h" #include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h" #include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" #include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
#include "paddle/fluid/inference/analysis/model_store_pass.h" #include "paddle/fluid/inference/analysis/model_store_pass.h"
#include "paddle/fluid/inference/analysis/pass_manager.h" #include "paddle/fluid/inference/analysis/pass_manager.h"
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h" #include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h"
...@@ -24,14 +25,15 @@ ...@@ -24,14 +25,15 @@
namespace paddle { namespace paddle {
DEFINE_bool(inference_analysis_enable_tensorrt_subgraph_engine, true, DEFINE_bool(IA_enable_tensorrt_subgraph_engine, false,
"Enable subgraph to TensorRT engine for acceleration"); "Enable subgraph to TensorRT engine for acceleration");
DEFINE_string(inference_analysis_graphviz_log_root, "./", DEFINE_bool(IA_enable_ir, false, "Turn on IR support");
DEFINE_string(IA_graphviz_log_root, "./",
"Graphviz debuger for data flow graphs."); "Graphviz debuger for data flow graphs.");
DEFINE_string(inference_analysis_output_storage_path, "", DEFINE_string(IA_output_storage_path, "", "optimized model output path");
"optimized model output path");
namespace inference { namespace inference {
namespace analysis { namespace analysis {
...@@ -40,8 +42,34 @@ class DfgPassManagerImpl final : public DfgPassManager { ...@@ -40,8 +42,34 @@ class DfgPassManagerImpl final : public DfgPassManager {
public: public:
DfgPassManagerImpl() { DfgPassManagerImpl() {
// TODO(Superjomn) set the key with pass reprs. // TODO(Superjomn) set the key with pass reprs.
AddPass("fluid-to-data-flow-graph", new FluidToDataFlowGraphPass); LOG(INFO)
if (FLAGS_inference_analysis_enable_tensorrt_subgraph_engine) { << "-----------------------------------------------------------------";
if (FLAGS_IA_enable_ir) {
AddPass("fluid-to-ir-pass", new FluidToIrPass);
} else {
AddPass("fluid-to-data-flow-graph", new FluidToDataFlowGraphPass);
}
TryAddTensorRtPass();
AddPass("data-flow-graph-to-fluid", new DataFlowGraphToFluidPass);
if (!FLAGS_IA_output_storage_path.empty()) {
AddPass("model-store-pass", new ModelStorePass);
}
LOG(INFO)
<< "-----------------------------------------------------------------";
}
std::string repr() const override { return "dfg-pass-manager"; }
std::string description() const override { return "DFG pass manager."; }
private:
void AddPass(const std::string& name, Pass* pass) {
VLOG(3) << "Adding pass " << name;
Register(name, pass);
AddGraphvizDebugerPass(pass);
}
void TryAddTensorRtPass() {
if (FLAGS_IA_enable_tensorrt_subgraph_engine) {
auto trt_teller = [&](const Node* node) { auto trt_teller = [&](const Node* node) {
std::unordered_set<std::string> teller_set( std::unordered_set<std::string> teller_set(
{"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax"}); {"elementwise_add", "mul", "conv2d", "pool2d", "relu", "softmax"});
...@@ -59,20 +87,6 @@ class DfgPassManagerImpl final : public DfgPassManager { ...@@ -59,20 +87,6 @@ class DfgPassManagerImpl final : public DfgPassManager {
new TensorRTSubgraphNodeMarkPass(trt_teller)); new TensorRTSubgraphNodeMarkPass(trt_teller));
AddPass("tensorrt-subgraph", new TensorRTSubGraphPass(trt_teller)); AddPass("tensorrt-subgraph", new TensorRTSubGraphPass(trt_teller));
} }
AddPass("data-flow-graph-to-fluid", new DataFlowGraphToFluidPass);
if (!FLAGS_inference_analysis_output_storage_path.empty()) {
AddPass("model-store-pass", new ModelStorePass);
}
}
std::string repr() const override { return "dfg-pass-manager"; }
std::string description() const override { return "DFG pass manager."; }
private:
void AddPass(const std::string& name, Pass* pass) {
LOG(INFO) << "Adding pass " << name;
Register(name, pass);
AddGraphvizDebugerPass(pass);
} }
// Add the graphviz debuger pass if the parent pass has one. // Add the graphviz debuger pass if the parent pass has one.
......
...@@ -43,9 +43,10 @@ namespace paddle { ...@@ -43,9 +43,10 @@ namespace paddle {
// TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this // TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this
// flag if not available. // flag if not available.
DECLARE_bool(inference_analysis_enable_tensorrt_subgraph_engine); DECLARE_bool(IA_enable_tensorrt_subgraph_engine);
DECLARE_string(inference_analysis_graphviz_log_root); DECLARE_string(IA_graphviz_log_root);
DECLARE_string(inference_analysis_output_storage_path); DECLARE_string(IA_output_storage_path);
DECLARE_bool(IA_enable_ir);
namespace inference { namespace inference {
namespace analysis { namespace analysis {
......
...@@ -14,14 +14,16 @@ ...@@ -14,14 +14,16 @@
#include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/analysis/analyzer.h"
#include <google/protobuf/text_format.h> #include <google/protobuf/text_format.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/inference/analysis/ut_helper.h" #include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
TEST(Analyzer, analysis_without_tensorrt) { TEST(Analyzer, analysis_without_tensorrt) {
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = false; FLAGS_IA_enable_tensorrt_subgraph_engine = false;
Argument argument; Argument argument;
argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir)); argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
Analyzer analyser; Analyzer analyser;
...@@ -29,13 +31,73 @@ TEST(Analyzer, analysis_without_tensorrt) { ...@@ -29,13 +31,73 @@ TEST(Analyzer, analysis_without_tensorrt) {
} }
TEST(Analyzer, analysis_with_tensorrt) { TEST(Analyzer, analysis_with_tensorrt) {
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = true; FLAGS_IA_enable_tensorrt_subgraph_engine = true;
Argument argument; Argument argument;
argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir)); argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
Analyzer analyser; Analyzer analyser;
analyser.Run(&argument); analyser.Run(&argument);
} }
void TestWord2vecPrediction(const std::string& model_path) {
NativeConfig config;
config.model_dir = model_path;
config.use_gpu = false;
config.device = 0;
auto predictor =
::paddle::CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(
config);
// One single batch
int64_t data[4] = {1, 2, 3, 4};
PaddleTensor tensor;
tensor.shape = std::vector<int>({4, 1});
tensor.data = PaddleBuf(data, sizeof(data));
tensor.dtype = PaddleDType::INT64;
// For simplicity, we set all the slots with the same data.
std::vector<PaddleTensor> slots(4, tensor);
std::vector<PaddleTensor> outputs;
CHECK(predictor->Run(slots, &outputs));
PADDLE_ENFORCE(outputs.size(), 1UL);
// Check the output buffer size and result of each tid.
PADDLE_ENFORCE(outputs.front().data.length(), 33168UL);
float result[5] = {0.00129761, 0.00151112, 0.000423564, 0.00108815,
0.000932706};
const size_t num_elements = outputs.front().data.length() / sizeof(float);
// The outputs' buffers are in CPU memory.
for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
LOG(INFO) << "data: "
<< static_cast<float*>(outputs.front().data.data())[i];
PADDLE_ENFORCE(static_cast<float*>(outputs.front().data.data())[i],
result[i]);
}
}
// Turn on the IR pass supportion, run a real inference and check the result.
TEST(Analyzer, SupportIRPass) {
FLAGS_IA_enable_ir = true;
FLAGS_IA_enable_tensorrt_subgraph_engine = false;
FLAGS_IA_output_storage_path = "./analysis.out";
Argument argument(FLAGS_inference_model_dir);
argument.model_output_store_path.reset(new std::string("./analysis.out"));
Analyzer analyzer;
analyzer.Run(&argument);
// Should get the transformed model stored to ./analysis.out
ASSERT_TRUE(PathExists("./analysis.out"));
// Inference from this path.
TestWord2vecPrediction("./analysis.out");
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
USE_PASS(fc_fuse_pass);
USE_PASS(graph_viz_pass);
USE_PASS(infer_clean_graph_pass);
...@@ -19,14 +19,16 @@ limitations under the License. */ ...@@ -19,14 +19,16 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
using ir_node_t = framework::ir::Node;
using ir_graph_t = framework::ir::Graph;
// It is a better idea that the inputs and outputs of this graph is set manually // It is a better idea that the inputs and outputs of this graph is set manually
// before, but there must be a Pass that helps to prune the unnecessary ops that // before, but there must be a Pass that helps to prune the unnecessary ops that
// do not contribute to the given targets, so in this pass, analysis and get the // do not contribute to the given targets, so in this pass, analysis and get the
// inputs and outputs is OK. // inputs and outputs is OK.
void DataFlowGraph::Build() { void DataFlowGraph::Build() {
inputs.clear(); inputs_.clear();
outputs.clear(); outputs_.clear();
std::unordered_set<Node *> ins; std::unordered_set<Node *> ins;
std::unordered_set<Node *> outs; std::unordered_set<Node *> outs;
for (auto &node : nodes.nodes()) { for (auto &node : nodes.nodes()) {
...@@ -42,18 +44,140 @@ void DataFlowGraph::Build() { ...@@ -42,18 +44,140 @@ void DataFlowGraph::Build() {
// similarly, the nodes that in outs but not in ins is the graphs' outputs // similarly, the nodes that in outs but not in ins is the graphs' outputs
for (auto *in : ins) { for (auto *in : ins) {
if (!outs.count(in)) { if (!outs.count(in)) {
inputs.push_back(in); inputs_.push_back(in);
} }
} }
for (auto *out : outs) { for (auto *out : outs) {
if (!outs.count(out)) { if (!ins.count(out)) {
outputs.push_back(out); outputs_.push_back(out);
} }
} }
Clean(); Clean();
} }
void DataFlowGraph::Build(const framework::proto::ProgramDesc &prog) {
// insert vars
// The `var2id` keeps a map from a variable's name to its Node-id, the Node-id
// will keep updating to its latest alias during the graph-building.
std::unordered_map<std::string, size_t> var2id;
auto &main_block = prog.blocks(framework::kRootBlockIndex);
for (int i = 0; i < main_block.vars_size(); i++) {
const auto &var = main_block.vars(i);
auto *v = nodes.Create(Node::Type::kValue);
v->SetName(var.name());
v->SetPbDesc(const_cast<void *>(static_cast<const void *>(&var)));
v->SetPbMsg(var.SerializeAsString());
var2id[var.name()] = v->id();
}
// The variables in a SSA can only write once, so if a variable is written
// multiple times(quite common in our ProgramDesc design), multiple alias
// Nodes of this variable will be created, and each will just write once.
// An set that keep all the names of the variables(the original, not alias)
// that have been written(as outputs). Once an Op's output variable hit the
// set, it should create a new alias and update the global alias for this
// variable. And that make a Data Flow Graph a SSA.
std::unordered_set<Node *> unique_written_vars;
for (int i = 0; i < main_block.ops_size(); i++) {
const auto &op = main_block.ops(i);
auto *o = nodes.Create(Node::Type::kFunction);
o->SetName(op.type());
static_cast<Function *>(o)->SetFuncType(op.type());
// Link to the original protobuf message's memory, make it easier to
// generate from a data flow graph to fluid ProgramDesc.
o->SetPbDesc(const_cast<void *>(static_cast<const void *>(&op)));
o->SetPbMsg(op.SerializeAsString());
// set inputs and outputs
for (int j = 0; j < op.inputs_size(); j++) {
auto &in_var = op.inputs(j);
for (int k = 0; k < in_var.arguments_size(); k++) {
auto *in = nodes.GetMutable(var2id.at(in_var.arguments(k)));
in->outlinks.push_back(o);
o->inlinks.push_back(in);
}
}
for (int j = 0; j < op.outputs_size(); j++) {
auto &out_var = op.outputs(j);
for (int k = 0; k < out_var.arguments_size(); k++) {
auto *out = nodes.GetMutable(var2id[out_var.arguments(k)]);
if (unique_written_vars.count(out)) {
// Loop found, for example, a = op(a), use SSA, change to a1 = op(a).
auto *out_alias = nodes.Create(Node::Type::kValue);
out_alias->SetName(out->name());
out_alias->SetPbDesc(out->pb_desc());
out_alias->SetPbMsg(out->pb_msg());
var2id[out_alias->name()] =
out_alias->id(); // update variable's alias Node
LOG(INFO) << "loop found in graph, create SSA alias node ["
<< out_alias->repr() << "] for [" << out->repr() << "]";
out = out_alias;
}
out->inlinks.push_back(o);
o->outlinks.push_back(out);
unique_written_vars.insert(out);
}
}
}
// Analysis and extract the inputs and outputs of this graph.
Build();
}
void DataFlowGraph::Build(const framework::ir::Graph &graph) {
// Create nodes
std::unordered_map<ir_node_t *, Node *> ir_node_map;
for (auto *ir_node : graph.Nodes()) {
Node *x{nullptr};
if (ir_node->IsOp()) {
PADDLE_ENFORCE(ir_node->Op());
VLOG(4) << "get op " << ir_node << " " << ir_node->Name();
x = nodes.Create(Node::Type::kFunction);
x->attr("ir_node").Pointer() = ir_node;
PADDLE_ENFORCE(ir_node->Op()->Proto());
x->SetName(ir_node->Op()->Proto()->type());
x->SetPbMsg(ir_node->Op()->Proto()->SerializeAsString());
} else if (ir_node->IsVar()) {
// Not create a Node for IR ControlDepVar, considering Inference currently
// just used in single thread scenerio.
VLOG(4) << "get var " << ir_node->Name();
x = nodes.Create(Node::Type::kValue);
x->attr("ir_node").Pointer() = ir_node;
x->SetName(ir_node->Name());
// x->SetPbMsg(ir_node->Var()->Proto()->SerializeAsString());
} else {
PADDLE_THROW("Failed to create an Node from IR, unknown type");
}
ir_node_map.emplace(ir_node, x);
}
VLOG(4) << "finish creating Nodes";
VLOG(4) << "to create edge";
// Create links
for (auto *ir_node : graph.Nodes()) {
auto it = ir_node_map.find(ir_node);
// Skip ControlDepVar.
if (it == ir_node_map.end()) continue;
auto *node = it->second;
for (auto *x : ir_node->inputs) {
if (!ir_node_map.count(x)) continue;
node->inlinks.push_back(ir_node_map.at(x));
}
for (auto *x : ir_node->outputs) {
if (!ir_node_map.count(x)) continue;
node->outlinks.push_back(ir_node_map.at(x));
}
}
Build();
PADDLE_ENFORCE(!inputs_.empty(),
"Can't deduce any inputs from the graph, Is the graph empty?");
ir_graph = &graph;
VLOG(3) << "finished build from IR";
}
void DataFlowGraph::Clean() { void DataFlowGraph::Clean() {
for (auto &node : nodes.nodes()) { for (auto &node : nodes.nodes()) {
std::unordered_set<Node *> inlinks_set(node->inlinks.begin(), std::unordered_set<Node *> inlinks_set(node->inlinks.begin(),
...@@ -61,11 +185,9 @@ void DataFlowGraph::Clean() { ...@@ -61,11 +185,9 @@ void DataFlowGraph::Clean() {
std::unordered_set<Node *> outlinks_set(node->outlinks.begin(), std::unordered_set<Node *> outlinks_set(node->outlinks.begin(),
node->outlinks.end()); node->outlinks.end());
if (inlinks_set.size() < node->inlinks.size()) { if (inlinks_set.size() < node->inlinks.size()) {
LOG(INFO) << "Clean: node " << node->repr() << " prune duplicate inputs";
node->inlinks.assign(inlinks_set.begin(), inlinks_set.end()); node->inlinks.assign(inlinks_set.begin(), inlinks_set.end());
} }
if (outlinks_set.size() < node->outlinks.size()) { if (outlinks_set.size() < node->outlinks.size()) {
LOG(INFO) << "Clean: node " << node->repr() << " prune duplicate inputs";
node->outlinks.assign(outlinks_set.begin(), outlinks_set.end()); node->outlinks.assign(outlinks_set.begin(), outlinks_set.end());
} }
} }
...@@ -112,10 +234,10 @@ GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator( ...@@ -112,10 +234,10 @@ GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator(
const std::vector<Node *> &source) const std::vector<Node *> &source)
: queue_(source.begin(), source.end()) {} : queue_(source.begin(), source.end()) {}
// GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator( GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator(
// GraphTraits<DataFlowGraph>::NodesBFSIterator &&other) noexcept GraphTraits<DataFlowGraph>::NodesBFSIterator &&other) noexcept
// : queue_(std::move(other.queue_)), : queue_(std::move(other.queue_)),
// visited_(std::move(other.visited_)) {} visited_(std::move(other.visited_)) {}
GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator( GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator(
const GraphTraits<DataFlowGraph>::NodesBFSIterator &other) const GraphTraits<DataFlowGraph>::NodesBFSIterator &other)
...@@ -159,7 +281,7 @@ bool GraphTraits<DataFlowGraph>::NodesBFSIterator::operator==( ...@@ -159,7 +281,7 @@ bool GraphTraits<DataFlowGraph>::NodesBFSIterator::operator==(
if (queue_.empty()) return other.queue_.empty(); if (queue_.empty()) return other.queue_.empty();
if ((!queue_.empty()) && (!other.queue_.empty())) { if ((!queue_.empty()) && (!other.queue_.empty())) {
return queue_.front() == other.queue_.front() && return queue_.front() == other.queue_.front() &&
visited_.size() == other.visited_.size(); // here need to check the visited_.size() == other.visited_.size();
// equality of queue and // equality of queue and
// visited. Just a light but week implementation. // visited. Just a light but week implementation.
} }
...@@ -174,10 +296,10 @@ GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator( ...@@ -174,10 +296,10 @@ GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator(
for (auto *x : source) stack_.push(x); for (auto *x : source) stack_.push(x);
} }
// GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator( GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator(
// GraphTraits<DataFlowGraph>::NodesDFSIterator &&other) noexcept GraphTraits<DataFlowGraph>::NodesDFSIterator &&other) noexcept
// : stack_(std::move(other.stack_)), : stack_(std::move(other.stack_)),
// visited_(std::move(other.visited_)) {} visited_(std::move(other.visited_)) {}
GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator( GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator(
const GraphTraits<DataFlowGraph>::NodesDFSIterator &other) const GraphTraits<DataFlowGraph>::NodesDFSIterator &other)
...@@ -339,7 +461,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT ...@@ -339,7 +461,7 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) { void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
std::vector<Node *> op_nodes; std::vector<Node *> op_nodes;
for (auto &node : GraphTraits<DataFlowGraph>(graph).nodes_in_TS()) { for (auto &node : GraphTraits<DataFlowGraph>(*graph).nodes_in_TS()) {
if (node.type() == Node::Type::kValue || node.deleted()) { if (node.type() == Node::Type::kValue || node.deleted()) {
continue; continue;
} }
......
...@@ -26,6 +26,7 @@ limitations under the License. */ ...@@ -26,6 +26,7 @@ limitations under the License. */
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/inference/analysis/graph_traits.h" #include "paddle/fluid/inference/analysis/graph_traits.h"
#include "paddle/fluid/inference/analysis/node.h" #include "paddle/fluid/inference/analysis/node.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -41,19 +42,43 @@ namespace analysis { ...@@ -41,19 +42,43 @@ namespace analysis {
*/ */
struct DataFlowGraph { struct DataFlowGraph {
NodeMap nodes; NodeMap nodes;
std::vector<Node *> inputs; // inputs and outputs are deduced from the graph.
std::vector<Node *> outputs; // Used to interact with IR.
const framework::ir::Graph *ir_graph{nullptr};
// Extract inputs and outputs of the graph. // Extract inputs and outputs of the graph.
void Build(); void Build();
void Build(const framework::proto::ProgramDesc &prog);
// Build a graph from ir::Graph.
void Build(const framework::ir::Graph &graph);
// Get an attribute.
AnyAttr &Attr(const std::string &key) { return attrs_[key]; }
// Output a DOT graph file for debug. // Output a DOT graph file for debug.
std::string DotString() const; std::string DotString() const;
std::string HumanReadableInfo(bool show_values = true, std::string HumanReadableInfo(bool show_values = true,
bool show_functions = true) const; bool show_functions = true) const;
const std::vector<Node *> &inputs() const {
PADDLE_ENFORCE(!inputs_.empty(),
"No inputs are deduced, need to Build() first.");
return inputs_;
}
const std::vector<Node *> &outputs() const {
PADDLE_ENFORCE(!outputs_.empty(),
"No outputs are deduced, need to Build() first.");
return outputs_;
}
private: private:
mutable std::vector<Node *> inputs_;
mutable std::vector<Node *> outputs_;
std::unordered_map<std::string, AnyAttr> attrs_;
// Remove duplicate edges and so on. // Remove duplicate edges and so on.
void Clean(); void Clean();
}; };
...@@ -70,7 +95,7 @@ struct GraphTraits<DataFlowGraph> { ...@@ -70,7 +95,7 @@ struct GraphTraits<DataFlowGraph> {
: public std::iterator<std::forward_iterator_tag, Node *> { : public std::iterator<std::forward_iterator_tag, Node *> {
NodesBFSIterator() = default; NodesBFSIterator() = default;
explicit NodesBFSIterator(const std::vector<Node *> &source); explicit NodesBFSIterator(const std::vector<Node *> &source);
// NodesBFSIterator(NodesBFSIterator &&other) noexcept; NodesBFSIterator(NodesBFSIterator &&other) noexcept;
// NOTE Heavy to use. // NOTE Heavy to use.
NodesBFSIterator(const NodesBFSIterator &other); NodesBFSIterator(const NodesBFSIterator &other);
...@@ -93,8 +118,8 @@ struct GraphTraits<DataFlowGraph> { ...@@ -93,8 +118,8 @@ struct GraphTraits<DataFlowGraph> {
struct NodesDFSIterator struct NodesDFSIterator
: public std::iterator<std::forward_iterator_tag, Node *> { : public std::iterator<std::forward_iterator_tag, Node *> {
NodesDFSIterator() = default; NodesDFSIterator() = default;
explicit NodesDFSIterator(const std::vector<Node *> &source); NodesDFSIterator(const std::vector<Node *> &source);
// NodesDFSIterator(NodesDFSIterator &&other) noexcept; NodesDFSIterator(NodesDFSIterator &&other) noexcept;
NodesDFSIterator(const NodesDFSIterator &other); NodesDFSIterator(const NodesDFSIterator &other);
Node &operator*(); Node &operator*();
...@@ -116,7 +141,7 @@ struct GraphTraits<DataFlowGraph> { ...@@ -116,7 +141,7 @@ struct GraphTraits<DataFlowGraph> {
struct NodesTSIterator struct NodesTSIterator
: public std::iterator<std::forward_iterator_tag, Node *> { : public std::iterator<std::forward_iterator_tag, Node *> {
NodesTSIterator() = default; NodesTSIterator() = default;
explicit NodesTSIterator(const std::vector<Node *> &source); NodesTSIterator(const std::vector<Node *> &source);
NodesTSIterator(NodesTSIterator &&other) NodesTSIterator(NodesTSIterator &&other)
: sorted_(std::move(other.sorted_)), cursor_(other.cursor_) { : sorted_(std::move(other.sorted_)), cursor_(other.cursor_) {
other.cursor_ = 0; other.cursor_ = 0;
...@@ -138,7 +163,7 @@ struct GraphTraits<DataFlowGraph> { ...@@ -138,7 +163,7 @@ struct GraphTraits<DataFlowGraph> {
size_t cursor_{0}; size_t cursor_{0};
}; };
explicit GraphTraits(DataFlowGraph *graph) : graph_(graph) {} explicit GraphTraits(const DataFlowGraph &graph) : graph_(graph) {}
// default use BFS to visit the nodes. // default use BFS to visit the nodes.
iterator_range<NodesBFSIterator> nodes() { iterator_range<NodesBFSIterator> nodes() {
...@@ -156,20 +181,20 @@ struct GraphTraits<DataFlowGraph> { ...@@ -156,20 +181,20 @@ struct GraphTraits<DataFlowGraph> {
private: private:
NodesBFSIterator nodes_bfs_begin() { NodesBFSIterator nodes_bfs_begin() {
return NodesBFSIterator(graph_->inputs); return NodesBFSIterator(graph_.inputs());
} }
NodesBFSIterator nodes_bfs_end() { return NodesBFSIterator(); } NodesBFSIterator nodes_bfs_end() { return NodesBFSIterator(); }
NodesDFSIterator nodes_dfs_begin() { NodesDFSIterator nodes_dfs_begin() {
return NodesDFSIterator(graph_->inputs); return NodesDFSIterator(graph_.inputs());
} }
NodesDFSIterator nodes_dfs_end() { return NodesDFSIterator(); } NodesDFSIterator nodes_dfs_end() { return NodesDFSIterator(); }
NodesTSIterator nodes_ts_begin() { return NodesTSIterator(graph_->inputs); } NodesTSIterator nodes_ts_begin() { return NodesTSIterator(graph_.inputs()); }
NodesTSIterator nodes_ts_end() { return NodesTSIterator(); } NodesTSIterator nodes_ts_end() { return NodesTSIterator(); }
private: private:
DataFlowGraph *graph_; const DataFlowGraph &graph_;
}; };
// Extract the inputs and outputs of a graph. The inputs and outputs of a // Extract the inputs and outputs of a graph. The inputs and outputs of a
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/ut_helper.h" #include "paddle/fluid/inference/analysis/ut_helper.h"
namespace paddle { namespace paddle {
...@@ -24,20 +25,18 @@ TEST(DataFlowGraph, BFS) { ...@@ -24,20 +25,18 @@ TEST(DataFlowGraph, BFS) {
auto dfg = ProgramDescToDFG(desc); auto dfg = ProgramDescToDFG(desc);
dfg.Build(); dfg.Build();
for (auto *in : dfg.inputs) { for (auto* in : dfg.inputs()) {
LOG(INFO) << "inputs: " << in->name() << " " LOG(INFO) << "inputs: " << in->name() << " "
<< static_cast<int>(in->type()); << static_cast<int>(in->type());
} }
for (auto *out : dfg.outputs) { for (auto* out : dfg.outputs()) {
LOG(INFO) << "outputs: " << out->name() << " " LOG(INFO) << "outputs: " << out->name() << " "
<< static_cast<int>(out->type()); << static_cast<int>(out->type());
} }
GraphTraits<DataFlowGraph> trait(&dfg);
auto nodes = trait.nodes();
size_t count = 0; size_t count = 0;
for (auto it = nodes.begin(); it != nodes.end(); ++it) { for (auto& node : GraphTraits<DataFlowGraph>(dfg).nodes()) {
LOG(INFO) << "visiting " << it->name(); LOG(INFO) << "visiting " << node.name();
++count; ++count;
} }
ASSERT_EQ(count, dfg.nodes.size()); ASSERT_EQ(count, dfg.nodes.size());
...@@ -45,13 +44,11 @@ TEST(DataFlowGraph, BFS) { ...@@ -45,13 +44,11 @@ TEST(DataFlowGraph, BFS) {
TEST(DataFlowGraph, DFS) { TEST(DataFlowGraph, DFS) {
auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__"); auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
auto dfg = ProgramDescToDFG(desc); DataFlowGraph dfg;
dfg.Build(); dfg.Build(desc);
GraphTraits<DataFlowGraph> trait(&dfg);
auto nodes = trait.nodes_in_DFS();
size_t count = 0; size_t count = 0;
for (auto it = nodes.begin(); it != nodes.end(); ++it) { for (auto& node : GraphTraits<DataFlowGraph>(dfg).nodes_in_DFS()) {
LOG(INFO) << "visiting " << it->name(); LOG(INFO) << "visiting " << node.name();
++count; ++count;
} }
ASSERT_EQ(count, dfg.nodes.size()); ASSERT_EQ(count, dfg.nodes.size());
...@@ -74,21 +71,17 @@ TEST(DataFlowGraph, TS) { ...@@ -74,21 +71,17 @@ TEST(DataFlowGraph, TS) {
DataFlowGraph graph; DataFlowGraph graph;
for (int i = 0; i < 8; i++) { for (int i = 0; i < 8; i++) {
auto *node = graph.nodes.Create(Node::Type::kValue); auto* node = graph.nodes.Create(Node::Type::kValue);
node->SetName("node-" + std::to_string(i)); node->SetName("node-" + std::to_string(i));
} }
auto add_link = [&](int i, int j) { auto add_link = [&](int i, int j) {
Node *source = graph.nodes.GetMutable(i); Node* source = graph.nodes.GetMutable(i);
Node *target = graph.nodes.GetMutable(j); Node* target = graph.nodes.GetMutable(j);
target->inlinks.push_back(source); target->inlinks.push_back(source);
source->outlinks.push_back(target); source->outlinks.push_back(target);
}; };
graph.inputs.push_back(graph.nodes.GetMutable(0));
graph.inputs.push_back(graph.nodes.GetMutable(1));
graph.inputs.push_back(graph.nodes.GetMutable(2));
add_link(0, 4); add_link(0, 4);
add_link(0, 5); add_link(0, 5);
add_link(1, 6); add_link(1, 6);
...@@ -97,8 +90,9 @@ TEST(DataFlowGraph, TS) { ...@@ -97,8 +90,9 @@ TEST(DataFlowGraph, TS) {
add_link(4, 7); add_link(4, 7);
add_link(4, 3); add_link(4, 3);
add_link(7, 3); add_link(7, 3);
graph.Build();
auto its = GraphTraits<DataFlowGraph>(&graph).nodes_in_TS(); auto its = GraphTraits<DataFlowGraph>(graph).nodes_in_TS();
std::vector<int> sorted_ids; std::vector<int> sorted_ids;
for (auto it = its.begin(); it != its.end(); ++it) { for (auto it = its.begin(); it != its.end(); ++it) {
LOG(INFO) << it->name(); LOG(INFO) << it->name();
...@@ -122,6 +116,50 @@ TEST(DataFlowGraph, TS) { ...@@ -122,6 +116,50 @@ TEST(DataFlowGraph, TS) {
assert_positive_sequence_pair(4, 7); assert_positive_sequence_pair(4, 7);
} }
TEST(DataFlowGraph, Build_ProgramDesc) {
auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
DataFlowGraph graph;
graph.Build(desc);
ASSERT_EQ(graph.nodes.size(), 38UL);
}
void SetOp(framework::ProgramDesc* prog, const std::string& type,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetInput("Xs", inputs);
op->SetOutput("Xs", outputs);
}
TEST(DataFlowGraph, Build_IR_Graph) {
framework::ProgramDesc prog;
for (auto& v : std::vector<std::string>({"a", "b", "c", "d", "e", "f"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(framework::proto::VarType::SELECTED_ROWS);
if (v == "c") {
var->SetPersistable(true);
}
}
SetOp(&prog, "OP0", std::vector<std::string>({"a"}),
std::vector<std::string>({"b"}));
SetOp(&prog, "OP1", std::vector<std::string>({"a"}),
std::vector<std::string>({"c"}));
SetOp(&prog, "mul", std::vector<std::string>({"b", "c"}),
std::vector<std::string>({"d"}));
SetOp(&prog, "elementwise_add", std::vector<std::string>({"d", "e"}),
std::vector<std::string>({"f"}));
DataFlowGraph graph;
framework::ir::Graph ir_graph(prog);
graph.Build(ir_graph);
ASSERT_EQ(graph.nodes.size(), ir_graph.Nodes().size());
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -52,18 +52,15 @@ bool DataFlowGraphToFluidPass::Initialize(Argument *argument) { ...@@ -52,18 +52,15 @@ bool DataFlowGraphToFluidPass::Initialize(Argument *argument) {
bool DataFlowGraphToFluidPass::Finalize() { return true; } bool DataFlowGraphToFluidPass::Finalize() { return true; }
void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) { void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
LOG(INFO) << "graph.inputs " << graph->inputs.size(); // FilterRedundantOutputOfSubGraph(graph);
for (auto &node : GraphTraits<DataFlowGraph>(graph).nodes_in_TS()) { for (auto &node : GraphTraits<DataFlowGraph>(*graph).nodes_in_TS()) {
if (node.deleted()) continue; if (node.deleted()) continue;
switch (node.type()) { switch (node.type()) {
case Node::Type::kFunction: { case Node::Type::kFunction: {
LOG(INFO) << "add function " << node.repr();
AddFluidOp(&node); AddFluidOp(&node);
} break; } break;
case Node::Type::kFunctionBlock: { case Node::Type::kFunctionBlock: {
LOG(INFO) << "add engine op " << node.repr() << " , "
<< static_cast<FunctionBlock *>(&node)->subgraph.size();
AddEngineOp(&node); AddEngineOp(&node);
} break; } break;
default: default:
...@@ -75,15 +72,27 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) { ...@@ -75,15 +72,27 @@ void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) {
} }
void DataFlowGraphToFluidPass::AddFluidOp(Node *node) { void DataFlowGraphToFluidPass::AddFluidOp(Node *node) {
auto *ori_op = static_cast<framework::proto::OpDesc *>(node->pb_desc()); PADDLE_ENFORCE(node);
PADDLE_ENFORCE(node->IsFunction());
PADDLE_ENFORCE(node->pb_desc() || !node->pb_msg().empty(),
"node has invalid protobuf repr.");
// currently only the main block is analyzed. // currently only the main block is analyzed.
PADDLE_ENFORCE(desc_);
auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex); auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
auto *op = main_block->add_ops(); auto *op = main_block->add_ops();
*op = *ori_op; // copy the attributes, by default, these will not be changed
// by analysis phrase. if (node->pb_desc()) {
// The inputs and outputs of the existing ops are not changed by tensorrt auto *ori_op = static_cast<framework::proto::OpDesc *>(node->pb_desc());
// subgraph pass. *op =
// NOTE It might be changed by other passes in the long run. *ori_op; // copy the attributes, by default, these will not be changed
// by analysis phrase.
// The inputs and outputs of the existing ops are not changed by tensorrt
// subgraph pass.
// NOTE It might be changed by other passes in the long run.
} else {
op->ParseFromString(node->pb_msg());
}
} }
void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph, void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
...@@ -220,10 +229,9 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) { ...@@ -220,10 +229,9 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) {
framework::BlockDesc block_desc(nullptr, &proto); framework::BlockDesc block_desc(nullptr, &proto);
block_desc.Proto()->set_parent_idx(-1); block_desc.Proto()->set_parent_idx(-1);
block_desc.Proto()->set_idx(0); block_desc.Proto()->set_idx(0);
LOG(INFO) << "origin variable size: " VLOG(4) << "origin variable size: "
<< argument_->origin_program_desc->blocks(0).vars().size(); << argument_->origin_program_desc->blocks(0).vars().size();
LOG(INFO) << "transformed variable size: " VLOG(4) << "transformed variable size: " << block_desc.Proto()->vars().size();
<< block_desc.Proto()->vars().size();
// copy ops. // copy ops.
for (auto *node : block_node->subgraph) { for (auto *node : block_node->subgraph) {
...@@ -257,7 +265,7 @@ class DFG_DebuggerPass : public DFG_GraphvizDrawPass { ...@@ -257,7 +265,7 @@ class DFG_DebuggerPass : public DFG_GraphvizDrawPass {
Pass *DataFlowGraphToFluidPass::CreateGraphvizDebugerPass() const { Pass *DataFlowGraphToFluidPass::CreateGraphvizDebugerPass() const {
return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config( return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config(
FLAGS_inference_analysis_graphviz_log_root, FLAGS_IA_graphviz_log_root,
"data_flow_graph_to_fluid_graphviz_debugger")); "data_flow_graph_to_fluid_graphviz_debugger"));
} }
......
...@@ -29,7 +29,7 @@ void DFG_GraphvizDrawPass::Run(DataFlowGraph *graph) { ...@@ -29,7 +29,7 @@ void DFG_GraphvizDrawPass::Run(DataFlowGraph *graph) {
auto png_path = dot_path.substr(0, dot_path.size() - 4) + ".png"; auto png_path = dot_path.substr(0, dot_path.size() - 4) + ".png";
std::string message; std::string message;
LOG(INFO) << "draw to " << png_path; VLOG(3) << "draw to " << png_path;
ExecShellCommand("dot -Tpng " + dot_path + " -o " + png_path, &message); ExecShellCommand("dot -Tpng " + dot_path + " -o " + png_path, &message);
} }
......
...@@ -52,72 +52,7 @@ bool FluidToDataFlowGraphPass::Finalize() { return true; } ...@@ -52,72 +52,7 @@ bool FluidToDataFlowGraphPass::Finalize() { return true; }
void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) { void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
PADDLE_ENFORCE(graph); PADDLE_ENFORCE(graph);
PADDLE_ENFORCE(desc_); PADDLE_ENFORCE(desc_);
// insert vars graph->Build(*desc_);
// The `var2id` keeps a map from a variable's name to its Node-id, the Node-id
// will keep updating to its latest alias during the graph-building.
std::unordered_map<std::string, size_t> var2id;
auto &main_block = desc_->blocks(framework::kRootBlockIndex);
for (int i = 0; i < main_block.vars_size(); i++) {
const auto &var = main_block.vars(i);
auto *v = graph->nodes.Create(Node::Type::kValue);
v->SetName(var.name());
v->SetPbDesc(const_cast<void *>(static_cast<const void *>(&var)));
v->SetPbMsg(var.SerializeAsString());
var2id[var.name()] = v->id();
}
// The variables in a SSA can only write once, so if a variable is written
// multiple times(quite common in our ProgramDesc design), multiple alias
// Nodes of this variable will be created, and each will just write once.
// An set that keep all the names of the variables(the original, not alias)
// that have been written(as outputs). Once an Op's output variable hit the
// set, it should create a new alias and update the global alias for this
// variable. And that make a Data Flow Graph a SSA.
std::unordered_set<Node *> unique_written_vars;
for (int i = 0; i < main_block.ops_size(); i++) {
const auto &op = main_block.ops(i);
auto *o = graph->nodes.Create(Node::Type::kFunction);
o->SetName(op.type());
static_cast<Function *>(o)->SetFuncType(op.type());
// Link to the original protobuf message's memory, make it easier to
// generate from a data flow graph to fluid ProgramDesc.
o->SetPbDesc(const_cast<void *>(static_cast<const void *>(&op)));
o->SetPbMsg(op.SerializeAsString());
// set inputs and outputs
for (int j = 0; j < op.inputs_size(); j++) {
auto &in_var = op.inputs(j);
for (int k = 0; k < in_var.arguments_size(); k++) {
auto *in = graph->nodes.GetMutable(var2id.at(in_var.arguments(k)));
in->outlinks.push_back(o);
o->inlinks.push_back(in);
}
}
for (int j = 0; j < op.outputs_size(); j++) {
auto &out_var = op.outputs(j);
for (int k = 0; k < out_var.arguments_size(); k++) {
auto *out = graph->nodes.GetMutable(var2id[out_var.arguments(k)]);
if (unique_written_vars.count(out)) {
// Loop found, for example, a = op(a), use SSA, change to a1 = op(a).
auto *out_alias = graph->nodes.Create(Node::Type::kValue);
out_alias->SetName(out->name());
out_alias->SetPbDesc(out->pb_desc());
out_alias->SetPbMsg(out->pb_msg());
var2id[out_alias->name()] =
out_alias->id(); // update variable's alias Node
LOG(INFO) << "loop found in graph, create SSA alias node ["
<< out_alias->repr() << "] for [" << out->repr() << "]";
out = out_alias;
}
out->inlinks.push_back(o);
o->outlinks.push_back(out);
unique_written_vars.insert(out);
}
}
}
// Analysis and extract the inputs and outputs of this graph.
graph->Build();
} }
namespace { namespace {
...@@ -133,7 +68,7 @@ class DFG_DebuggerPass : public DFG_GraphvizDrawPass { ...@@ -133,7 +68,7 @@ class DFG_DebuggerPass : public DFG_GraphvizDrawPass {
Pass *FluidToDataFlowGraphPass::CreateGraphvizDebugerPass() const { Pass *FluidToDataFlowGraphPass::CreateGraphvizDebugerPass() const {
return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config( return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config(
FLAGS_inference_analysis_graphviz_log_root, "fluid-to-dfg-debuger")); FLAGS_IA_graphviz_log_root, "fluid-to-dfg-debuger"));
} }
} // namespace analysis } // namespace analysis
......
...@@ -30,7 +30,7 @@ TEST(FluidToDataFlowGraphPass, Test) { ...@@ -30,7 +30,7 @@ TEST(FluidToDataFlowGraphPass, Test) {
ASSERT_EQ(argument.main_dfg->nodes.size(), 38UL); ASSERT_EQ(argument.main_dfg->nodes.size(), 38UL);
pass.Finalize(); pass.Finalize();
ASSERT_FALSE(argument.main_dfg->DotString().empty()); ASSERT_FALSE(argument.main_dfg->DotString().empty());
EXPECT_FALSE(argument.main_dfg->inputs.empty()); EXPECT_FALSE(argument.main_dfg->inputs().empty());
} }
} // namespace analysis } // namespace analysis
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include "paddle/fluid/inference/analysis/pass.h"
namespace paddle {
namespace inference {
namespace analysis {
class FluidToIrPass final : public DataFlowGraphPass {
public:
FluidToIrPass() = default;
bool Initialize(Argument *argument) override {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument);
if (argument->origin_program_desc) {
LOG(WARNING) << "argument's origin_program_desc is already set, might "
"duplicate called";
}
// set fluid model program path
if (!argument->fluid_model_program_path) {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_dir);
argument->fluid_model_program_path.reset(
new std::string(*argument->fluid_model_dir + "/__model__"));
}
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_program_path);
// Load program.
auto program = LoadProgramDesc(*argument->fluid_model_program_path);
argument->origin_program_desc.reset(
new framework::proto::ProgramDesc(program));
// Create main data flow graph.
if (!argument->main_dfg) {
argument->main_dfg.reset(new DataFlowGraph);
}
// Persist the ProgramDesc in graph's attribute. The IR graph just keep the
// address, will segfault if the original ProgramDesc destroys.
auto &ir_program_p = argument->main_dfg->Attr("ir_program_desc").Pointer();
ir_program_p = new framework::ProgramDesc(program);
argument_ = argument;
return true;
}
bool Finalize() override { return true; }
void Run(DataFlowGraph *graph) override {
// Call all the IR Passes
IRPassManager ir_passes(*static_cast<framework::ProgramDesc *>(
argument_->main_dfg->Attr("ir_program_desc").Pointer()));
ir_passes.Apply(std::vector<std::string>(
{// Manual update the passes here.
"graph_viz_pass", "infer_clean_graph_pass", "graph_viz_pass",
"fc_fuse_pass", "graph_viz_pass"}));
PADDLE_ENFORCE(argument_->main_dfg.get());
argument_->main_dfg->Build(ir_passes.graph());
// PADDLE_ENFORCE(argument_->main_dfg->IsFullyConnected());
}
std::string repr() const override { return "fluid-to-ir-pass"; }
private:
Argument *argument_{nullptr};
};
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/inference/analysis/ut_helper.h"
namespace paddle {
namespace inference {
namespace analysis {
TEST(FluidToIrPass, Test) {
FluidToIrPass pass;
Argument argument(FLAGS_inference_model_dir);
pass.Initialize(&argument);
pass.Run(argument.main_dfg.get());
}
} // namespace analysis
} // namespace inference
} // namespace paddle
USE_PASS(fc_fuse_pass);
USE_PASS(graph_viz_pass);
USE_PASS(infer_clean_graph_pass);
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <sys/stat.h>
#include <cstdio> #include <cstdio>
#include <fstream> #include <fstream>
#include <string> #include <string>
...@@ -151,6 +152,23 @@ static framework::proto::ProgramDesc LoadProgramDesc( ...@@ -151,6 +152,23 @@ static framework::proto::ProgramDesc LoadProgramDesc(
return program_desc; return program_desc;
} }
static bool FileExists(const std::string &filepath) {
std::ifstream file(filepath);
bool exists = file.is_open();
file.close();
return exists;
}
static bool PathExists(const std::string &path) {
struct stat statbuf;
if (stat(path.c_str(), &statbuf) != -1) {
if (S_ISDIR(statbuf.st_mode)) {
return true;
}
}
return false;
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include <string>
namespace paddle {
namespace inference {
namespace analysis {
IRPassManager::IRPassManager(const ProgramDesc& program) {
graph_.reset(new framework::ir::Graph(program));
}
void IRPassManager::Apply(const std::vector<std::string>& passes) {
graph_->Set("graph_viz_path", new std::string("./1.dot"));
// Apply all the passes
std::string pre_pass;
for (const std::string& pass_name : passes) {
LOG(WARNING) << "Running IR pass [" << pass_name << "]";
auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
if (pass_name == "graph_viz_pass") {
std::string dot_file_path =
"ir_" + (pre_pass.empty() ? "origin" : pre_pass) + ".dot";
pass->Set("graph_viz_path", new std::string(std::move(dot_file_path)));
}
graph_ = pass->Apply(std::move(graph_));
pre_pass = pass_name;
}
}
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* This file defines IRPassManager, it helps control the passes in IR. Inference
* phrase will load the model program and parameters from disk, that is quite
* different from the training phase.
* This manager will control the Passes and make the passes in IR work smoothly
* for inference.
*/
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace inference {
namespace analysis {
using framework::ProgramDesc;
class IRPassManager final {
public:
IRPassManager(const ProgramDesc& program);
void Apply(const std::vector<std::string>& passes);
framework::ir::Graph& graph() const { return *graph_; }
private:
std::unique_ptr<framework::ir::Graph> graph_;
};
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -35,19 +35,21 @@ void ModelStorePass::Run(DataFlowGraph *x) { ...@@ -35,19 +35,21 @@ void ModelStorePass::Run(DataFlowGraph *x) {
std::stringstream ss; std::stringstream ss;
// NOTE these commands only works on linux. // NOTE these commands only works on linux.
ss << "mkdir -p " << *argument_->model_output_store_path; ss << "mkdir -p " << *argument_->model_output_store_path;
LOG(INFO) << "run command: " << ss.str(); VLOG(3) << "run command: " << ss.str();
PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0); PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0);
ss.str(""); ss.str("");
ss << "cp " << *argument_->fluid_model_dir << "/*" ss << "cp " << *argument_->fluid_model_dir << "/*"
<< " " << *argument_->model_output_store_path; << " " << *argument_->model_output_store_path;
LOG(INFO) << "run command: " << ss.str(); VLOG(3) << "run command: " << ss.str();
PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0); PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0);
// Store program // Store program
PADDLE_ENFORCE_NOT_NULL(argument_->transformed_program_desc, PADDLE_ENFORCE_NOT_NULL(argument_->transformed_program_desc,
"program desc is not transformed, should call " "program desc is not transformed, should call "
"DataFlowGraphToFluidPass first."); "DataFlowGraphToFluidPass first.");
VLOG(3) << "store analyzed program to "
<< *argument_->model_output_store_path;
const std::string program_output_path = const std::string program_output_path =
*argument_->model_output_store_path + "/__model__"; *argument_->model_output_store_path + "/__model__";
std::ofstream file(program_output_path, std::ios::binary); std::ofstream file(program_output_path, std::ios::binary);
...@@ -58,6 +60,8 @@ void ModelStorePass::Run(DataFlowGraph *x) { ...@@ -58,6 +60,8 @@ void ModelStorePass::Run(DataFlowGraph *x) {
file.write(serialized_message.c_str(), serialized_message.size()); file.write(serialized_message.c_str(), serialized_message.size());
} }
bool ModelStorePass::Finalize() { return true; }
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -44,6 +44,8 @@ class ModelStorePass : public DataFlowGraphPass { ...@@ -44,6 +44,8 @@ class ModelStorePass : public DataFlowGraphPass {
model in the disk, and that model can be reloaded for prediction again.)DD"; model in the disk, and that model can be reloaded for prediction again.)DD";
} }
bool Finalize() override;
private: private:
Argument* argument_{nullptr}; Argument* argument_{nullptr};
}; };
......
...@@ -30,7 +30,7 @@ TEST(DFG_StorePass, test) { ...@@ -30,7 +30,7 @@ TEST(DFG_StorePass, test) {
argument.model_output_store_path.reset( argument.model_output_store_path.reset(
new std::string("./_dfg_store_pass_tmp")); new std::string("./_dfg_store_pass_tmp"));
// disable storage in alalyzer // disable storage in alalyzer
FLAGS_inference_analysis_output_storage_path = ""; FLAGS_IA_output_storage_path = "";
analyzer.Run(&argument); analyzer.Run(&argument);
ModelStorePass pass; ModelStorePass pass;
......
...@@ -38,7 +38,7 @@ namespace analysis { ...@@ -38,7 +38,7 @@ namespace analysis {
class NodeMap; class NodeMap;
// A helper class to maintain the status from Pass. // A helper class to maintain the status from Pass.
struct NodeAttr { struct AnyAttr {
using any_t = using any_t =
boost::variant<bool, float, int32_t, int64_t, void *, std::string>; boost::variant<bool, float, int32_t, int64_t, void *, std::string>;
// NOTE T should be a primary type or a struct combined by several primary // NOTE T should be a primary type or a struct combined by several primary
...@@ -54,10 +54,9 @@ struct NodeAttr { ...@@ -54,10 +54,9 @@ struct NodeAttr {
void *&Pointer() { return As<void *>(); } void *&Pointer() { return As<void *>(); }
std::string &String() { return As<std::string>(); } std::string &String() { return As<std::string>(); }
private:
template <typename T> template <typename T>
T &As() { T &As() {
if (type_index_ == typeid(NodeAttr)) { if (type_index_ == typeid(AnyAttr)) {
type_index_ = typeid(T); type_index_ = typeid(T);
any_data_ = T(); any_data_ = T();
} else { } else {
...@@ -68,7 +67,7 @@ struct NodeAttr { ...@@ -68,7 +67,7 @@ struct NodeAttr {
private: private:
any_t any_data_; any_t any_data_;
std::type_index type_index_{typeid(NodeAttr)}; std::type_index type_index_{typeid(AnyAttr)};
}; };
/* /*
...@@ -105,7 +104,7 @@ class Node { ...@@ -105,7 +104,7 @@ class Node {
// Get an additional attribute and convert it to T data type. NOTE this will // Get an additional attribute and convert it to T data type. NOTE this will
// silently create a new attribute if not exists. // silently create a new attribute if not exists.
NodeAttr &attr(const std::string &name) const { return attrs_[name]; } AnyAttr &attr(const std::string &name) const { return attrs_[name]; }
int id() const { return id_; } int id() const { return id_; }
...@@ -150,7 +149,7 @@ class Node { ...@@ -150,7 +149,7 @@ class Node {
Type type_{Type::kNone}; Type type_{Type::kNone};
// Mark this node is deleted by some pass. // Mark this node is deleted by some pass.
bool deleted_{false}; bool deleted_{false};
mutable std::unordered_map<std::string, NodeAttr> attrs_; mutable std::unordered_map<std::string, AnyAttr> attrs_;
}; };
class Function; class Function;
......
...@@ -21,19 +21,19 @@ namespace inference { ...@@ -21,19 +21,19 @@ namespace inference {
namespace analysis { namespace analysis {
TEST(NodeAttr, bool) { TEST(NodeAttr, bool) {
NodeAttr x; AnyAttr x;
x.Bool() = true; x.Bool() = true;
ASSERT_EQ(x.Bool(), true); ASSERT_EQ(x.Bool(), true);
} }
TEST(NodeAttr, int32) { TEST(NodeAttr, int32) {
NodeAttr x; AnyAttr x;
x.Int32() = 32; x.Int32() = 32;
ASSERT_EQ(x.Int32(), 32); ASSERT_EQ(x.Int32(), 32);
} }
TEST(NodeAttr, string) { TEST(NodeAttr, string) {
NodeAttr x; AnyAttr x;
x.String() = "Hello"; x.String() = "Hello";
ASSERT_EQ(x.String(), "Hello"); ASSERT_EQ(x.String(), "Hello");
} }
......
...@@ -63,7 +63,7 @@ class Pass { ...@@ -63,7 +63,7 @@ class Pass {
// Human-readable short representation. // Human-readable short representation.
virtual std::string repr() const = 0; virtual std::string repr() const = 0;
// Human-readable long description. // Human-readable long description.
virtual std::string description() const = 0; virtual std::string description() const { return "No DOC"; }
}; };
// NodePass process on any Node types. // NodePass process on any Node types.
......
...@@ -22,7 +22,7 @@ namespace analysis { ...@@ -22,7 +22,7 @@ namespace analysis {
bool PassManager::Initialize(Argument* argument) { bool PassManager::Initialize(Argument* argument) {
argument_ = argument; argument_ = argument;
for (auto& pass : data_) { for (auto& pass : data_) {
LOG(INFO) << "Initializing pass " << pass->repr(); LOG(WARNING) << "Initializing pass [" << pass->repr() << "]";
if (!pass->Initialize(argument)) { if (!pass->Initialize(argument)) {
LOG(ERROR) << "Failed to initialize pass [" << pass->repr() << "]"; LOG(ERROR) << "Failed to initialize pass [" << pass->repr() << "]";
return false; return false;
...@@ -33,8 +33,9 @@ bool PassManager::Initialize(Argument* argument) { ...@@ -33,8 +33,9 @@ bool PassManager::Initialize(Argument* argument) {
void DfgPassManager::RunAll() { void DfgPassManager::RunAll() {
PADDLE_ENFORCE(argument_); PADDLE_ENFORCE(argument_);
LOG(INFO) << "Total " << data_.size() << " passes";
for (auto& pass : data_) { for (auto& pass : data_) {
VLOG(4) << "Running pass [" << pass->repr() << "]"; LOG(WARNING) << "Running pass [" << pass->repr() << "]";
pass->Run(argument_->main_dfg.get()); pass->Run(argument_->main_dfg.get());
} }
} }
...@@ -42,8 +43,7 @@ void DfgPassManager::RunAll() { ...@@ -42,8 +43,7 @@ void DfgPassManager::RunAll() {
void NodePassManager::RunAll() { void NodePassManager::RunAll() {
PADDLE_ENFORCE(argument_); PADDLE_ENFORCE(argument_);
PADDLE_ENFORCE(argument_->main_dfg.get()); PADDLE_ENFORCE(argument_->main_dfg.get());
auto trait = auto trait = GraphTraits<DataFlowGraph>(*argument_->main_dfg).nodes_in_DFS();
GraphTraits<DataFlowGraph>(argument_->main_dfg.get()).nodes_in_DFS();
for (auto& node : trait) { for (auto& node : trait) {
for (auto& pass : data_) { for (auto& pass : data_) {
pass->Run(&node); pass->Run(&node);
......
...@@ -34,7 +34,7 @@ inline void MarkOutLinksInSubGraph(const Function *func) { ...@@ -34,7 +34,7 @@ inline void MarkOutLinksInSubGraph(const Function *func) {
} }
void SubGraphSplitter::MarkNodesInsideSubGraph() { void SubGraphSplitter::MarkNodesInsideSubGraph() {
for (auto &node : GraphTraits<DataFlowGraph>(graph_).nodes()) { for (auto &node : GraphTraits<DataFlowGraph>(*graph_).nodes()) {
if (node_inside_subgraph_teller_(&node)) { if (node_inside_subgraph_teller_(&node)) {
node.attr(kMarkerAttrName).Bool() = true; node.attr(kMarkerAttrName).Bool() = true;
if (node.type() == Node::Type::kFunction) { if (node.type() == Node::Type::kFunction) {
...@@ -76,7 +76,7 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) { ...@@ -76,7 +76,7 @@ void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) {
std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() { std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
std::vector<Node *> marked_nodes; std::vector<Node *> marked_nodes;
for (auto &node : GraphTraits<DataFlowGraph>(graph_).nodes_in_TS()) { for (auto &node : GraphTraits<DataFlowGraph>(*graph_).nodes_in_TS()) {
if (node.attr(kMarkerAttrName).Bool()) { if (node.attr(kMarkerAttrName).Bool()) {
marked_nodes.push_back(&node); marked_nodes.push_back(&node);
} }
......
...@@ -69,8 +69,8 @@ class DfgDebuggerPass : public DFG_GraphvizDrawPass { ...@@ -69,8 +69,8 @@ class DfgDebuggerPass : public DFG_GraphvizDrawPass {
}; };
Pass *TensorRTSubgraphNodeMarkPass::CreateGraphvizDebugerPass() const { Pass *TensorRTSubgraphNodeMarkPass::CreateGraphvizDebugerPass() const {
DFG_GraphvizDrawPass::Config config( DFG_GraphvizDrawPass::Config config(FLAGS_IA_graphviz_log_root,
FLAGS_inference_analysis_graphviz_log_root, "tensorrt_marked_node"); "tensorrt_marked_node");
return new DfgDebuggerPass(config); return new DfgDebuggerPass(config);
} }
bool TensorRTSubgraphNodeMarkPass::Finalize() { return true; } bool TensorRTSubgraphNodeMarkPass::Finalize() { return true; }
......
...@@ -23,7 +23,7 @@ namespace paddle { ...@@ -23,7 +23,7 @@ namespace paddle {
DEFINE_string(dirname, "", "Directory of the inference model."); DEFINE_string(dirname, "", "Directory of the inference model.");
void CompareTensorRTWithFluid(bool enable_tensorrt) { void CompareTensorRTWithFluid(bool enable_tensorrt) {
FLAGS_inference_analysis_enable_tensorrt_subgraph_engine = enable_tensorrt; FLAGS_IA_enable_tensorrt_subgraph_engine = enable_tensorrt;
//# 1. Create PaddlePredictor with a config. //# 1. Create PaddlePredictor with a config.
NativeConfig config0; NativeConfig config0;
......
...@@ -9,7 +9,6 @@ function(op_library TARGET) ...@@ -9,7 +9,6 @@ function(op_library TARGET)
# op_library is a function to create op library. The interface is same as # op_library is a function to create op library. The interface is same as
# cc_library. But it handle split GPU/CPU code and link some common library # cc_library. But it handle split GPU/CPU code and link some common library
# for ops. # for ops.
set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} PARENT_SCOPE)
set(cc_srcs) set(cc_srcs)
set(cu_srcs) set(cu_srcs)
set(hip_cu_srcs) set(hip_cu_srcs)
...@@ -92,6 +91,7 @@ function(op_library TARGET) ...@@ -92,6 +91,7 @@ function(op_library TARGET)
endif() endif()
endforeach() endforeach()
endif(WIN32) endif(WIN32)
set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} PARENT_SCOPE)
list(LENGTH op_library_DEPS op_library_DEPS_len) list(LENGTH op_library_DEPS op_library_DEPS_len)
if (${op_library_DEPS_len} GREATER 0) if (${op_library_DEPS_len} GREATER 0)
......
...@@ -126,6 +126,15 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { ...@@ -126,6 +126,15 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
pipeline); pipeline);
} }
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_bias_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto user_bias_pd = user_bias_memory_p->get_primitive_desc();
auto bias_pd = conv_pd_->bias_primitive_desc();
return this->AcquireMemory(bias_pd, user_bias_pd, user_bias_memory_p,
"@bias_mem_p", pipeline);
}
std::shared_ptr<mkldnn::convolution_forward> AcquireConvolution( std::shared_ptr<mkldnn::convolution_forward> AcquireConvolution(
std::shared_ptr<mkldnn::memory> src_memory_p, std::shared_ptr<mkldnn::memory> src_memory_p,
std::shared_ptr<mkldnn::memory> weights_memory_p, std::shared_ptr<mkldnn::memory> weights_memory_p,
...@@ -147,6 +156,28 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { ...@@ -147,6 +156,28 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
return conv_p; return conv_p;
} }
std::shared_ptr<mkldnn::convolution_forward> AcquireConvolution(
std::shared_ptr<mkldnn::memory> src_memory_p,
std::shared_ptr<mkldnn::memory> weights_memory_p,
std::shared_ptr<mkldnn::memory> bias_memory_p,
std::shared_ptr<mkldnn::memory> dst_memory_p) {
auto prim_key = key_ + "@conv_p";
auto conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution primitive in device context");
if (conv_p == nullptr) {
conv_p = std::make_shared<mkldnn::convolution_forward>(
*conv_pd_, *(src_memory_p), *(weights_memory_p.get()),
*(bias_memory_p.get()), *(dst_memory_p.get()));
dev_ctx_.SetBlob(prim_key, conv_p);
} else {
is_reusing_ = true;
}
return conv_p;
}
std::shared_ptr<mkldnn::convolution_backward_weights> std::shared_ptr<mkldnn::convolution_backward_weights>
AcquireConvolutionBackwardWeights( AcquireConvolutionBackwardWeights(
std::shared_ptr<mkldnn::memory> src_memory_p, std::shared_ptr<mkldnn::memory> src_memory_p,
...@@ -229,6 +260,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -229,6 +260,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto* input = ctx.Input<Tensor>("Input"); auto* input = ctx.Input<Tensor>("Input");
auto* filter = ctx.Input<Tensor>("Filter"); auto* filter = ctx.Input<Tensor>("Filter");
auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
auto* output = ctx.Output<Tensor>("Output"); auto* output = ctx.Output<Tensor>("Output");
PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN &&
...@@ -237,6 +269,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -237,6 +269,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN && PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
filter->format() != memory::format::format_undef, filter->format() != memory::format::format_undef,
"Wrong layout/format set for Filter tensor"); "Wrong layout/format set for Filter tensor");
PADDLE_ENFORCE(input->dims().size() == 4,
"Input must be with 4 dimensions, i.e. NCHW");
PADDLE_ENFORCE(filter->dims().size() == 4,
"Filter must be with 4 dimensions, i.e. OIHW");
if (bias) {
PADDLE_ENFORCE(bias->layout() == DataLayout::kMKLDNN &&
bias->format() != memory::format::format_undef,
"Wrong layout/format set for Bias tensor");
PADDLE_ENFORCE(bias->dims().size() == 1,
"Bias must only have 1 dimension, i.e. X");
}
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides"); std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
...@@ -253,11 +296,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -253,11 +296,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const T* filter_data = filter->data<T>(); const T* filter_data = filter->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace()); T* output_data = output->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE(input->dims().size() == 4,
"Input must be with 4 dimensions, i.e. NCHW");
PADDLE_ENFORCE(filter->dims().size() == 4,
"Filter must be with 4 dimensions, i.e. OIHW");
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims()); std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> weights_tz = std::vector<int> weights_tz =
paddle::framework::vectorize2int(filter->dims()); paddle::framework::vectorize2int(filter->dims());
...@@ -288,13 +326,23 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -288,13 +326,23 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format); src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc( auto weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format); weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
std::vector<int> bias_tz; // TODO(mgallus): avoid empty vector creation.
// Currently used whenever bias is != nullptr.
auto dst_md = platform::MKLDNNMemDesc( auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format); dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
// create a conv primitive descriptor and save it for usage in backward // create a conv primitive descriptor and save it for usage in backward
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd = std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings, if (bias) {
mkldnn_engine); bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine);
} else {
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
paddings, mkldnn_engine);
}
// Save conv_pd/src_memory/weights_memory for backward pass // Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd); dev_ctx.SetBlob(key_conv_pd, conv_pd);
...@@ -315,8 +363,22 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -315,8 +363,22 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data)); handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
// create convolution op primitive // create convolution op primitive
auto conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p, std::shared_ptr<mkldnn::convolution_forward> conv_p;
dst_memory_p); if (bias) {
const T* bias_data = bias->data<T>();
auto user_bias_md = platform::MKLDNNMemDesc(
{bias_tz}, platform::MKLDNNGetDataType<T>(), memory::format::x);
auto user_bias_memory_p =
handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data));
auto bias_memory_p =
handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
bias_memory_p, dst_memory_p);
} else {
conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
dst_memory_p);
}
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
pipeline.push_back(*conv_p); pipeline.push_back(*conv_p);
...@@ -346,6 +408,27 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -346,6 +408,27 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>( return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd); p_conv_pd);
} }
std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
const memory::desc& bias, const memory::desc& dst,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
auto conv_desc = mkldnn::convolution_forward::desc(
mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
bias, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
auto p_conv_pd =
new mkldnn::convolution_forward::primitive_desc(conv_desc, engine);
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd);
}
}; };
template <typename T> template <typename T>
......
...@@ -37,6 +37,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -37,6 +37,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
auto in_dims = ctx->GetInputDim("Input"); auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter"); auto filter_dims = ctx->GetInputDim("Filter");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides"); std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings"); std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
int groups = ctx->Attrs().Get<int>("groups"); int groups = ctx->Attrs().Get<int>("groups");
...@@ -57,7 +58,6 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -57,7 +58,6 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups, PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups,
"The number of input channels should be equal to filter " "The number of input channels should be equal to filter "
"channels * groups."); "channels * groups.");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
filter_dims[0] % groups, 0, filter_dims[0] % groups, 0,
"The number of output channels should be divided by groups."); "The number of output channels should be divided by groups.");
...@@ -122,6 +122,11 @@ void Conv2DOpMaker::Make() { ...@@ -122,6 +122,11 @@ void Conv2DOpMaker::Make() {
"H is the height of the filter, and W is the width of the filter. " "H is the height of the filter, and W is the width of the filter. "
"If the groups attribute is greater than 1, C equals the number of " "If the groups attribute is greater than 1, C equals the number of "
"input image channels divided by the groups."); "input image channels divided by the groups.");
AddInput("Bias",
"(Tensor) Bias to be added to each output of filter application."
"The format of output tensor is X (one-dimensional) of size equal"
"to the number of output channels. Only used with MKL-DNN.")
.AsDispensable();
AddOutput("Output", AddOutput("Output",
"(Tensor) The output tensor of convolution operator. " "(Tensor) The output tensor of convolution operator. "
"The format of output tensor is also NCHW.") "The format of output tensor is also NCHW.")
......
...@@ -130,12 +130,13 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, ...@@ -130,12 +130,13 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
checkpoint_notify_id != -1, checkpoint_notify_id != -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke."); "when checkpoint_notify_id = -1, there should be no RPC invoke.");
auto* lt_var = scope->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>(); // TODO(tangwei12): find out why scope will be error.
auto* lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
lt_var->clear(); lt_var->clear();
lt_var->append(out_var_name); lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: " VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: "
<< out_var_name; << out_var_name;
executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope); executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope_);
return true; return true;
} }
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -23,6 +24,37 @@ struct MulFunctor { ...@@ -23,6 +24,37 @@ struct MulFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a * b; } inline HOSTDEVICE T operator()(T a, T b) const { return a * b; }
}; };
template <typename DeviceContext, typename T>
void default_elementwise_mul(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
MulFunctor<T>(), z);
}
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_floating_point<T>::value &&
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_mul(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
auto blas = math::GetBlas<DeviceContext, T>(ctx);
blas.VMUL(x->numel(), x->data<T>(), y->data<T>(),
z->mutable_data<T>(ctx.GetPlace()));
}
template <typename DeviceContext, typename T>
typename std::enable_if<
!std::is_floating_point<T>::value ||
!std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_mul(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
default_elementwise_mul<DeviceContext, T>(ctx, x, y, z);
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ElementwiseMulKernel : public framework::OpKernel<T> { class ElementwiseMulKernel : public framework::OpKernel<T> {
public: public:
...@@ -33,9 +65,11 @@ class ElementwiseMulKernel : public framework::OpKernel<T> { ...@@ -33,9 +65,11 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
auto* y = ctx.Input<Tensor>("Y"); auto* y = ctx.Input<Tensor>("Y");
auto* z = ctx.Output<Tensor>("Out"); auto* z = ctx.Output<Tensor>("Out");
z->mutable_data<T>(ctx.GetPlace()); z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis"); if (x->numel() == y->numel()) {
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis, elementwise_mul<DeviceContext, T>(ctx, x, y, z);
MulFunctor<T>(), z); } else {
default_elementwise_mul<DeviceContext, T>(ctx, x, y, z);
}
} }
}; };
......
...@@ -35,9 +35,14 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -35,9 +35,14 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
if (ctx->HasInput("Bias")) { if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias"); auto bias_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(bias_dims[0], 1, "The shape of Bias must be [1, dim]."); if (bias_dims.size() == 2) {
PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1], PADDLE_ENFORCE_EQ(bias_dims[0], 1, "The shape of Bias must be [1, dim].");
"The shape of Bias must be [1, dim]."); PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1],
"The shape of Bias must be [1, dim].");
} else if (bias_dims.size() == 1) {
PADDLE_ENFORCE_EQ(bias_dims[0], w_dims[1],
"The shape of Bias must be [1, dim].");
}
} }
PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4, PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
"Fully Connected input should be 2-D or 4-D tensor."); "Fully Connected input should be 2-D or 4-D tensor.");
......
...@@ -92,6 +92,7 @@ class LoadOp : public framework::OperatorBase { ...@@ -92,6 +92,7 @@ class LoadOp : public framework::OperatorBase {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place); auto &dev_ctx = *pool.Get(place);
framework::DeserializeFromStream(fin, selectedRows, dev_ctx); framework::DeserializeFromStream(fin, selectedRows, dev_ctx);
selectedRows->SyncIndex();
} }
}; };
......
...@@ -134,6 +134,9 @@ class Blas { ...@@ -134,6 +134,9 @@ class Blas {
template <typename T> template <typename T>
void VADD(int n, const T* x, const T* y, T* z) const; void VADD(int n, const T* x, const T* y, T* z) const;
template <typename T>
void VMUL(int n, const T* x, const T* y, T* z) const;
template <typename T> template <typename T>
void VCOPY(int n, const T* x, T* y) const; void VCOPY(int n, const T* x, T* y) const;
...@@ -202,6 +205,11 @@ class BlasT : private Blas<DeviceContext> { ...@@ -202,6 +205,11 @@ class BlasT : private Blas<DeviceContext> {
Base()->template VADD<T>(args...); Base()->template VADD<T>(args...);
} }
template <typename... ARGS>
void VMUL(ARGS... args) const {
Base()->template VMUL<T>(args...);
}
template <typename... ARGS> template <typename... ARGS>
void VCOPY(ARGS... args) const { void VCOPY(ARGS... args) const {
Base()->template VCOPY<T>(args...); Base()->template VCOPY<T>(args...);
......
...@@ -82,6 +82,11 @@ struct CBlas<float> { ...@@ -82,6 +82,11 @@ struct CBlas<float> {
static void VADD(ARGS... args) { static void VADD(ARGS... args) {
platform::dynload::vsAdd(args...); platform::dynload::vsAdd(args...);
} }
template <typename... ARGS>
static void VMUL(ARGS... args) {
platform::dynload::vsMul(args...);
}
}; };
template <> template <>
...@@ -142,6 +147,11 @@ struct CBlas<double> { ...@@ -142,6 +147,11 @@ struct CBlas<double> {
static void VADD(ARGS... args) { static void VADD(ARGS... args) {
platform::dynload::vdAdd(args...); platform::dynload::vdAdd(args...);
} }
template <typename... ARGS>
static void VMUL(ARGS... args) {
platform::dynload::vdMul(args...);
}
}; };
#else #else
...@@ -199,6 +209,7 @@ struct CBlas<platform::float16> { ...@@ -199,6 +209,7 @@ struct CBlas<platform::float16> {
static void SMM_GEMM(...) { static void SMM_GEMM(...) {
PADDLE_THROW("float16 SMM_GEMM not supported on CPU"); PADDLE_THROW("float16 SMM_GEMM not supported on CPU");
} }
static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
static void GEMM_BATCH(...) { static void GEMM_BATCH(...) {
PADDLE_THROW("float16 GEMM_BATCH not supported on CPU"); PADDLE_THROW("float16 GEMM_BATCH not supported on CPU");
...@@ -374,6 +385,20 @@ void Blas<platform::CPUDeviceContext>::VADD(int n, const T *x, const T *y, ...@@ -374,6 +385,20 @@ void Blas<platform::CPUDeviceContext>::VADD(int n, const T *x, const T *y,
#endif #endif
} }
template <>
template <typename T>
void Blas<platform::CPUDeviceContext>::VMUL(int n, const T *x, const T *y,
T *z) const {
#ifdef PADDLE_WITH_MKLML
CBlas<T>::VMUL(n, x, y, z);
#else
// try to find if openblas support vmul
for (int i = 0; i < n; ++i) {
z[i] = x[i] * y[i];
}
#endif
}
template <> template <>
template <typename T> template <typename T>
void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha, void Blas<platform::CPUDeviceContext>::GEMV(bool trans_a, int M, int N, T alpha,
......
...@@ -142,6 +142,8 @@ class SaveOp : public framework::OperatorBase { ...@@ -142,6 +142,8 @@ class SaveOp : public framework::OperatorBase {
std::string filename = lt_var->data(); std::string filename = lt_var->data();
VLOG(4) << "SaveSelectedRows get File name: " << filename; VLOG(4) << "SaveSelectedRows get File name: " << filename;
MkDirRecursively(DirName(filename).c_str());
auto &selectedRows = var->Get<framework::SelectedRows>(); auto &selectedRows = var->Get<framework::SelectedRows>();
// get device context from pool // get device context from pool
......
...@@ -17,10 +17,10 @@ ...@@ -17,10 +17,10 @@
#include <cublasXt.h> #include <cublasXt.h>
#include <cublas_v2.h> #include <cublas_v2.h>
#include <cuda.h> #include <cuda.h>
#include <dlfcn.h>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include <type_traits> #include <type_traits>
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
......
...@@ -15,9 +15,9 @@ limitations under the License. */ ...@@ -15,9 +15,9 @@ limitations under the License. */
#pragma once #pragma once
#include <cudnn.h> #include <cudnn.h>
#include <dlfcn.h>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
......
...@@ -17,10 +17,10 @@ limitations under the License. */ ...@@ -17,10 +17,10 @@ limitations under the License. */
#include <cuda.h> #include <cuda.h>
#include <cupti.h> #include <cupti.h>
#include <dlfcn.h>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
......
...@@ -14,9 +14,9 @@ limitations under the License. */ ...@@ -14,9 +14,9 @@ limitations under the License. */
#pragma once #pragma once
#include <curand.h> #include <curand.h>
#include <dlfcn.h>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h"
......
...@@ -14,10 +14,10 @@ limitations under the License. */ ...@@ -14,10 +14,10 @@ limitations under the License. */
#pragma once #pragma once
#include <dlfcn.h>
#include <mkl.h> #include <mkl.h>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
...@@ -49,25 +49,27 @@ extern void* mklml_dso_handle; ...@@ -49,25 +49,27 @@ extern void* mklml_dso_handle;
#define MKLML_ROUTINE_EACH(__macro) \ #define MKLML_ROUTINE_EACH(__macro) \
__macro(cblas_sgemm); \ __macro(cblas_sgemm); \
__macro(cblas_saxpy); \
__macro(cblas_scopy); \
__macro(cblas_sgemv); \
__macro(cblas_sgemm_batch); \
__macro(cblas_dgemm); \ __macro(cblas_dgemm); \
__macro(cblas_saxpy); \
__macro(cblas_daxpy); \ __macro(cblas_daxpy); \
__macro(cblas_scopy); \
__macro(cblas_dcopy); \ __macro(cblas_dcopy); \
__macro(cblas_sgemv); \
__macro(cblas_dgemv); \ __macro(cblas_dgemv); \
__macro(cblas_dgemm_batch); \
__macro(vsAdd); \
__macro(vdAdd); \
__macro(cblas_sgemm_alloc); \ __macro(cblas_sgemm_alloc); \
__macro(cblas_sgemm_pack); \
__macro(cblas_sgemm_compute); \
__macro(cblas_sgemm_free); \
__macro(cblas_dgemm_alloc); \ __macro(cblas_dgemm_alloc); \
__macro(cblas_sgemm_pack); \
__macro(cblas_dgemm_pack); \ __macro(cblas_dgemm_pack); \
__macro(cblas_sgemm_compute); \
__macro(cblas_dgemm_compute); \ __macro(cblas_dgemm_compute); \
__macro(cblas_sgemm_free); \
__macro(cblas_dgemm_free); \ __macro(cblas_dgemm_free); \
__macro(cblas_sgemm_batch); \
__macro(cblas_dgemm_batch); \
__macro(vsAdd); \
__macro(vdAdd); \
__macro(vsMul); \
__macro(vdMul); \
__macro(MKL_Set_Num_Threads) __macro(MKL_Set_Num_Threads)
MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP); MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);
......
...@@ -13,12 +13,11 @@ See the License for the specific language governing permissions and ...@@ -13,12 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <dlfcn.h>
#include <nccl.h> #include <nccl.h>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
......
...@@ -14,10 +14,9 @@ limitations under the License. */ ...@@ -14,10 +14,9 @@ limitations under the License. */
#pragma once #pragma once
#include <dlfcn.h>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
#include "warpctc/include/ctc.h" #include "warpctc/include/ctc.h"
namespace paddle { namespace paddle {
......
...@@ -14,9 +14,6 @@ limitations under the License. */ ...@@ -14,9 +14,6 @@ limitations under the License. */
#pragma once #pragma once
#include <dlfcn.h> // for dladdr
#include <execinfo.h> // for backtrace
#ifdef __GNUC__ #ifdef __GNUC__
#include <cxxabi.h> // for __cxa_demangle #include <cxxabi.h> // for __cxa_demangle
#endif // __GNUC__ #endif // __GNUC__
...@@ -37,6 +34,7 @@ limitations under the License. */ ...@@ -37,6 +34,7 @@ limitations under the License. */
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/to_string.h" #include "paddle/fluid/string/to_string.h"
...@@ -75,7 +73,7 @@ struct EnforceNotMet : public std::exception { ...@@ -75,7 +73,7 @@ struct EnforceNotMet : public std::exception {
sout << string::Sprintf("%s at [%s:%d]", exp.what(), f, l) << std::endl; sout << string::Sprintf("%s at [%s:%d]", exp.what(), f, l) << std::endl;
sout << "PaddlePaddle Call Stacks: " << std::endl; sout << "PaddlePaddle Call Stacks: " << std::endl;
#if !defined(_WIN32)
void* call_stack[TRACE_STACK_LIMIT]; void* call_stack[TRACE_STACK_LIMIT];
auto size = backtrace(call_stack, TRACE_STACK_LIMIT); auto size = backtrace(call_stack, TRACE_STACK_LIMIT);
auto symbols = backtrace_symbols(call_stack, size); auto symbols = backtrace_symbols(call_stack, size);
...@@ -95,6 +93,9 @@ struct EnforceNotMet : public std::exception { ...@@ -95,6 +93,9 @@ struct EnforceNotMet : public std::exception {
} }
} }
free(symbols); free(symbols);
#else
sout << "Windows not support stack backtrace yet.";
#endif
err_str_ = sout.str(); err_str_ = sout.str();
} }
} }
......
...@@ -125,6 +125,11 @@ class MKLDNNHandler { ...@@ -125,6 +125,11 @@ class MKLDNNHandler {
return this->AcquireMemory(md, ptr, "@user_weights_mem_p"); return this->AcquireMemory(md, ptr, "@user_weights_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireBiasMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_bias_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemory( std::shared_ptr<mkldnn::memory> AcquireDstMemory(
const mkldnn::memory::desc& md, void* ptr) { const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_dst_mem_p"); return this->AcquireMemory(md, ptr, "@user_dst_mem_p");
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdexcept>
#include <string>
#if !defined(_WIN32)
#include <dlfcn.h> // for dladdr
#include <execinfo.h> // for backtrace
#else
#include <Shlwapi.h>
#include <Windows.h>
static void* dlsym(void* handle, const char* symbol_name) {
FARPROC found_symbol;
found_symbol = GetProcAddress((HMODULE)handle, symbol_name);
if (found_symbol == NULL) {
throw std::runtime_error(std::string(symbol_name) + " not found.");
}
return reinterpret_cast<void*>(found_symbol);
}
#endif
set(PYBIND_DEPS pybind python proto_desc memory executor prune profiler feed_fetch_method
)
if(NOT WIN32)
list(APPEND PYBIND_DEPS parallel_executor)
endif()
if(WITH_PYTHON) if(WITH_PYTHON)
if(WITH_AMD_GPU) if(WITH_AMD_GPU)
hip_library(paddle_pybind SHARED hip_library(paddle_pybind SHARED
SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc
DEPS pybind python proto_desc memory executor prune profiler feed_fetch_method DEPS ${PYBIND_DEPS}
parallel_executor
${GLOB_OP_LIB}) ${GLOB_OP_LIB})
else() else()
cc_library(paddle_pybind SHARED cc_library(paddle_pybind SHARED
SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc
DEPS pybind python proto_desc memory executor prune profiler feed_fetch_method DEPS ${PYBIND_DEPS}
parallel_executor
${GLOB_OP_LIB}) ${GLOB_OP_LIB})
if(NOT APPLE AND NOT ANDROID) if(NOT APPLE AND NOT ANDROID AND NOT WIN32)
target_link_libraries(paddle_pybind rt) target_link_libraries(paddle_pybind rt)
endif(NOT APPLE AND NOT ANDROID) endif(NOT APPLE AND NOT ANDROID AND NOT WIN32)
endif(WITH_AMD_GPU) endif(WITH_AMD_GPU)
cc_test(tensor_py_test SRCS tensor_py_test.cc DEPS python) cc_test(tensor_py_test SRCS tensor_py_test.cc DEPS python)
......
...@@ -1363,6 +1363,13 @@ class Program(object): ...@@ -1363,6 +1363,13 @@ class Program(object):
self._current_role = core.op_proto_and_checker_maker.OpRole.Forward self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
self._op_role_var = [] self._op_role_var = []
# for distribute
self._is_distributed = False
self._is_chief = False
self._slice_vars_and_attrs = []
self._endpoints = []
self._distributed_lookup_table = None
@property @property
def op_role(self): def op_role(self):
""" """
......
...@@ -372,6 +372,7 @@ def load_vars(executor, ...@@ -372,6 +372,7 @@ def load_vars(executor,
load_vars( load_vars(
executor, executor,
dirname=dirname, dirname=dirname,
main_program=main_program,
vars=list(filter(predicate, main_program.list_vars())), vars=list(filter(predicate, main_program.list_vars())),
filename=filename) filename=filename)
else: else:
...@@ -403,9 +404,12 @@ def load_vars(executor, ...@@ -403,9 +404,12 @@ def load_vars(executor,
inputs={}, inputs={},
outputs={"Out": load_var_list}, outputs={"Out": load_var_list},
attrs={'file_path': os.path.join(dirname, filename)}) attrs={'file_path': os.path.join(dirname, filename)})
executor.run(load_prog) executor.run(load_prog)
# load slice vars on pserver, if have it.
_load_slice_up_vars(executor, dirname,
main_program._slice_vars_and_attrs)
def load_params(executor, dirname, main_program=None, filename=None): def load_params(executor, dirname, main_program=None, filename=None):
""" """
...@@ -659,11 +663,19 @@ def save_inference_model(dirname, ...@@ -659,11 +663,19 @@ def save_inference_model(dirname,
save_persistables(executor, dirname, inference_program, params_filename) save_persistables(executor, dirname, inference_program, params_filename)
# if there is lookup table, the trainer 0 will notify all pserver to save.
if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
lookup_table_filename = os.path.join(dirname, "__lookup_table__")
_save_lookup_tables_by_notify(executor, lookup_table_filename,
main_program._distributed_lookup_table,
main_program._endpoints)
def load_inference_model(dirname, def load_inference_model(dirname,
executor, executor,
model_filename=None, model_filename=None,
params_filename=None): params_filename=None,
pserver_endpoints=None):
""" """
Load inference model from a directory Load inference model from a directory
...@@ -679,6 +691,10 @@ def load_inference_model(dirname, ...@@ -679,6 +691,10 @@ def load_inference_model(dirname,
parameters were saved in a single binary parameters were saved in a single binary
file. If parameters were saved in separate file. If parameters were saved in separate
files, set it as 'None'. files, set it as 'None'.
pserver_endpoints(list|None): This only need by distributed inference.
When use distributed look up table in training,
We also need it in inference.The parameter is
a list of pserver endpoints.
Returns: Returns:
tuple: The return of this function is a tuple with three elements: tuple: The return of this function is a tuple with three elements:
...@@ -697,12 +713,16 @@ def load_inference_model(dirname, ...@@ -697,12 +713,16 @@ def load_inference_model(dirname,
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
path = "./infer_model" path = "./infer_model"
endpoints = ["127.0.0.1:2023","127.0.0.1:2024"]
[inference_program, feed_target_names, fetch_targets] = [inference_program, feed_target_names, fetch_targets] =
fluid.io.load_inference_model(dirname=path, executor=exe) fluid.io.load_inference_model(dirname=path, executor=exe)
results = exe.run(inference_program, results = exe.run(inference_program,
feed={feed_target_names[0]: tensor_img}, feed={feed_target_names[0]: tensor_img},
fetch_list=fetch_targets) fetch_list=fetch_targets)
# if we need lookup table, we will use:
fluid.io.load_inference_model(dirname=path, executor=exe, pserver_endpoints=endpoints)
# In this exsample, the inference program was saved in the # In this exsample, the inference program was saved in the
# "./infer_model/__model__" and parameters were saved in # "./infer_model/__model__" and parameters were saved in
# separate files in ""./infer_model". # separate files in ""./infer_model".
...@@ -729,6 +749,9 @@ def load_inference_model(dirname, ...@@ -729,6 +749,9 @@ def load_inference_model(dirname,
program = Program.parse_from_string(program_desc_str) program = Program.parse_from_string(program_desc_str)
load_persistables(executor, dirname, program, params_filename) load_persistables(executor, dirname, program, params_filename)
if pserver_endpoints:
program = _endpoints_replacement(program, pserver_endpoints)
feed_target_names = program.desc.get_feed_target_names() feed_target_names = program.desc.get_feed_target_names()
fetch_target_names = program.desc.get_fetch_target_names() fetch_target_names = program.desc.get_fetch_target_names()
fetch_targets = [ fetch_targets = [
...@@ -738,6 +761,61 @@ def load_inference_model(dirname, ...@@ -738,6 +761,61 @@ def load_inference_model(dirname,
return [program, feed_target_names, fetch_targets] return [program, feed_target_names, fetch_targets]
def _save_lookup_tables_by_notify(executor, dirname, lookup_table,
pserver_endpoints):
"""
This function will send checkpoint notify message from Trainer 0
to all the pservers.
The checkpoint notify message contains lookup table name,
the absolute path on pserver to save lookup_table.
Args:
executor(Executor): The executor to run for send checkpoint notify.
dirname(str): The folder where to save.
lookup_table(string): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps_endpoint_list(list): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
distribute arguments.
Return:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
_save_pserver_vars_by_notify(executor=exe,
dirname=param_path, lookup_table=table_name,
pserver_endpoints=ps_endpoints)
"""
pserver_notify_program = Program()
pserver_notify_block = pserver_notify_program.global_block()
attrs = {}
attrs['epmap'] = pserver_endpoints
attrs['dir'] = dirname
attrs['lookup_table'] = lookup_table
pserver_notify_block.append_op(
type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
executor.run(pserver_notify_program)
def _endpoints_replacement(program, endpoints):
ENDPOINT_MAP = "epmap"
for op in program.global_block().ops:
if op.has_attr(ENDPOINT_MAP):
op.set_attr(ENDPOINT_MAP, endpoints)
program._sync_with_cpp()
return program
def get_parameter_value(para, executor): def get_parameter_value(para, executor):
""" """
Get the LoDTensor value of the given parameter. Get the LoDTensor value of the given parameter.
...@@ -799,3 +877,46 @@ def get_parameter_value_by_name(name, executor, program=None): ...@@ -799,3 +877,46 @@ def get_parameter_value_by_name(name, executor, program=None):
program = default_main_program() program = default_main_program()
var = program.global_block().var(name) var = program.global_block().var(name)
return get_parameter_value(var, executor) return get_parameter_value(var, executor)
def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs):
if not slice_vars_and_attrs:
return
load_prog = Program()
load_block = load_prog.global_block()
for var_tuple in slice_vars_and_attrs:
orig_var = var_tuple[0]
start = var_tuple[1]
slice_var = var_tuple[2]
end = start + reduce(lambda x, y: x * y, slice_var.shape)
clone_orig_var = load_block.create_var(
name=orig_var.name,
type=orig_var.type,
shape=orig_var.shape,
dtype=orig_var.dtype,
persistable=True)
clone_slice_var = load_block.create_var(
name=slice_var.name,
type=slice_var.type,
shape=slice_var.shape,
dtype=slice_var.dtype,
persistable=True)
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [clone_orig_var]},
attrs={'file_path': os.path.join(dirname, clone_orig_var.name)})
load_block.append_op(
type="slice",
inputs={'Input': clone_orig_var},
outputs={'Out': clone_slice_var},
attrs={'axes': [0],
'starts': [start],
'ends': [end]})
executor.run(load_prog)
...@@ -47,7 +47,6 @@ class TranspilerTest(unittest.TestCase): ...@@ -47,7 +47,6 @@ class TranspilerTest(unittest.TestCase):
avg_cost = fluid.layers.mean(cost) avg_cost = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1) sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
return
def get_main_program(self): def get_main_program(self):
main = fluid.Program() main = fluid.Program()
...@@ -95,8 +94,9 @@ class TranspilerTest(unittest.TestCase): ...@@ -95,8 +94,9 @@ class TranspilerTest(unittest.TestCase):
def test_transpiler(self): def test_transpiler(self):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.unique_name.guard():
self.transpiler_test_impl() with fluid.program_guard(main, startup):
self.transpiler_test_impl()
class TestBasicModel(TranspilerTest): class TestBasicModel(TranspilerTest):
...@@ -249,7 +249,6 @@ class TestLRDecay(TranspilerTest): ...@@ -249,7 +249,6 @@ class TestLRDecay(TranspilerTest):
decay_rate=0.1, decay_rate=0.1,
staircase=True)) staircase=True))
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
return
def transpiler_test_impl(self): def transpiler_test_impl(self):
pserver, startup = self.get_pserver(self.pserver1_ep) pserver, startup = self.get_pserver(self.pserver1_ep)
...@@ -279,7 +278,6 @@ class TestLRDecayConditional(TranspilerTest): ...@@ -279,7 +278,6 @@ class TestLRDecayConditional(TranspilerTest):
learning_rate=fluid.layers.piecewise_decay([10000, 20000], learning_rate=fluid.layers.piecewise_decay([10000, 20000],
[1.0, 0.5, 1.0])) [1.0, 0.5, 1.0]))
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
return
def transpiler_test_impl(self): def transpiler_test_impl(self):
pserver, startup = self.get_pserver(self.pserver1_ep) pserver, startup = self.get_pserver(self.pserver1_ep)
...@@ -328,7 +326,6 @@ class TestL2Decay(TranspilerTest): ...@@ -328,7 +326,6 @@ class TestL2Decay(TranspilerTest):
avg_cost = fluid.layers.mean(cost) avg_cost = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1) sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
return
def transpiler_test_impl(self): def transpiler_test_impl(self):
pserver, startup = self.get_pserver(self.pserver1_ep) pserver, startup = self.get_pserver(self.pserver1_ep)
...@@ -363,7 +360,6 @@ class TestL2DecayWithPiecewise(TranspilerTest): ...@@ -363,7 +360,6 @@ class TestL2DecayWithPiecewise(TranspilerTest):
momentum=0.9, momentum=0.9,
regularization=fluid.regularizer.L2Decay(1e-4)) regularization=fluid.regularizer.L2Decay(1e-4))
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
return
def transpiler_test_impl(self): def transpiler_test_impl(self):
pserver, startup = self.get_pserver(self.pserver1_ep) pserver, startup = self.get_pserver(self.pserver1_ep)
...@@ -393,13 +389,14 @@ class TestDistLookupTableBase(TranspilerTest): ...@@ -393,13 +389,14 @@ class TestDistLookupTableBase(TranspilerTest):
def network_with_table(self, is_sparse, is_distributed): def network_with_table(self, is_sparse, is_distributed):
self.table_size = 1000 self.table_size = 1000
self.emb_size = 64 self.emb_size = 64
self.lookup_table_name = 'shared_w'
def emb_pool(ids): def emb_pool(ids):
emb = fluid.layers.embedding( emb = fluid.layers.embedding(
input=ids, input=ids,
size=[self.table_size, self.emb_size], size=[self.table_size, self.emb_size],
dtype='float32', dtype='float32',
param_attr='shared_w', # share parameter param_attr=self.lookup_table_name, # share parameter
is_sparse=is_sparse, is_sparse=is_sparse,
is_distributed=is_distributed) is_distributed=is_distributed)
pool = fluid.layers.sequence_pool(input=emb, pool_type='average') pool = fluid.layers.sequence_pool(input=emb, pool_type='average')
...@@ -572,7 +569,7 @@ class TestDistLookupTableSliceSize(TestDistLookupTableBase): ...@@ -572,7 +569,7 @@ class TestDistLookupTableSliceSize(TestDistLookupTableBase):
def transpiler_test_impl(self): def transpiler_test_impl(self):
config = fluid.DistributeTranspilerConfig() config = fluid.DistributeTranspilerConfig()
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config) pserver1, _ = self.get_pserver(self.pserver1_ep, config)
self.assertTrue(self.transpiler.has_distributed_lookup_table) self.assertTrue(self.transpiler.has_distributed_lookup_table)
lookup_table_var = pserver1.global_block().vars[ lookup_table_var = pserver1.global_block().vars[
...@@ -582,6 +579,21 @@ class TestDistLookupTableSliceSize(TestDistLookupTableBase): ...@@ -582,6 +579,21 @@ class TestDistLookupTableSliceSize(TestDistLookupTableBase):
self.assertEqual(row_size, calc_row_size) self.assertEqual(row_size, calc_row_size)
class TestDistArgsInProgram(TestDistLookupTableBase):
def net_conf(self):
self.network_with_table(is_sparse=True, is_distributed=True)
def transpiler_test_impl(self):
trainer, _ = self.get_trainer()
self.assertTrue(trainer._is_distributed)
self.assertTrue(trainer._is_chief)
self.assertEqual(trainer._distributed_lookup_table,
self.lookup_table_name)
self.assertEqual(trainer._endpoints,
[self.pserver1_ep, self.pserver2_ep])
class TestRMSPropOptimizer(TranspilerTest): class TestRMSPropOptimizer(TranspilerTest):
def net_conf(self): def net_conf(self):
x = fluid.layers.data(name='x', shape=[1000], dtype='float32') x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
...@@ -595,7 +607,6 @@ class TestRMSPropOptimizer(TranspilerTest): ...@@ -595,7 +607,6 @@ class TestRMSPropOptimizer(TranspilerTest):
avg_cost = fluid.layers.mean(cost) avg_cost = fluid.layers.mean(cost)
optimizer = fluid.optimizer.RMSProp(learning_rate=0.1) optimizer = fluid.optimizer.RMSProp(learning_rate=0.1)
optimizer.minimize(avg_cost) optimizer.minimize(avg_cost)
return
def transpiler_test_impl(self): def transpiler_test_impl(self):
pserver, startup = self.get_pserver(self.pserver1_ep) pserver, startup = self.get_pserver(self.pserver1_ep)
...@@ -612,5 +623,40 @@ class TestRMSPropOptimizer(TranspilerTest): ...@@ -612,5 +623,40 @@ class TestRMSPropOptimizer(TranspilerTest):
self.assertEqual(moment_var.shape, (500, 1000)) self.assertEqual(moment_var.shape, (500, 1000))
class TestLoadSliceVar(TranspilerTest):
def net_conf(self):
x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
y_predict = fluid.layers.fc(input=x,
size=1000,
act=None,
param_attr=fluid.ParamAttr(name='fc_w'),
bias_attr=fluid.ParamAttr(name='fc_b'))
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
optimizer = fluid.optimizer.RMSProp(learning_rate=0.1)
optimizer.minimize(avg_cost)
def transpiler_test_impl(self):
pserver, _ = self.get_pserver(self.pserver1_ep)
pserver2, _ = self.get_pserver(self.pserver2_ep)
self.assertTrue(pserver._slice_vars_and_attrs)
self.assertTrue(pserver2._slice_vars_and_attrs)
for idx in xrange(len(pserver._slice_vars_and_attrs)):
self.assertEqual(pserver._slice_vars_and_attrs[idx][0],
pserver2._slice_vars_and_attrs[idx][0])
total_numel = reduce(lambda x, y: x * y,
pserver._slice_vars_and_attrs[idx][0].shape)
self.assertEqual(
total_numel,
reduce(lambda x, y: x * y,
pserver._slice_vars_and_attrs[idx][2].shape) + reduce(
lambda x, y: x * y,
pserver2._slice_vars_and_attrs[idx][2].shape))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -215,6 +215,13 @@ class DistributeTranspiler(object): ...@@ -215,6 +215,13 @@ class DistributeTranspiler(object):
for param_var, grad_var in self.params_grads: for param_var, grad_var in self.params_grads:
self.param_name_to_grad_name[param_var.name] = grad_var.name self.param_name_to_grad_name[param_var.name] = grad_var.name
# add distributed attrs to program
self.origin_program._is_distributed = True
self.origin_program._endpoints = self.pserver_endpoints
self.origin_program._is_chief = self.trainer_id == 0
self.origin_program._distributed_lookup_table = self.table_name if self.table_name else None
# split and create vars, then put splited vars in dicts for later use.
# step 1: split and create vars, then put splited vars in dicts for later use. # step 1: split and create vars, then put splited vars in dicts for later use.
self._init_splited_vars() self._init_splited_vars()
...@@ -590,6 +597,8 @@ class DistributeTranspiler(object): ...@@ -590,6 +597,8 @@ class DistributeTranspiler(object):
checkpoint_block_id = self._create_checkpoint_save_block( checkpoint_block_id = self._create_checkpoint_save_block(
pserver_program, table_opt_block.idx) pserver_program, table_opt_block.idx)
pserver_program._distributed_lookup_table = self.table_name
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will # NOTE: if has_distributed_lookup_table is False, then prefetch_block will
# not be executed, so it's safe to use optimize_block to hold the place # not be executed, so it's safe to use optimize_block to hold the place
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
...@@ -616,6 +625,10 @@ class DistributeTranspiler(object): ...@@ -616,6 +625,10 @@ class DistributeTranspiler(object):
outputs={}, outputs={},
attrs=attrs) attrs=attrs)
# add distributed attrs
pserver_program._slice_vars_and_attrs = self._get_slice_vars_and_attrs(
endpoint)
pserver_program._sync_with_cpp() pserver_program._sync_with_cpp()
return pserver_program return pserver_program
...@@ -689,8 +702,31 @@ class DistributeTranspiler(object): ...@@ -689,8 +702,31 @@ class DistributeTranspiler(object):
inputs=new_inputs, inputs=new_inputs,
outputs=new_outputs, outputs=new_outputs,
attrs=op.all_attrs()) attrs=op.all_attrs())
# add slice vars
s_prog._slice_vars_and_attrs = self._get_slice_vars_and_attrs(endpoint)
return s_prog return s_prog
def _get_slice_vars_and_attrs(self, endpoint):
slice_vars_and_attrs = []
block_suffix = "block"
for param in self.param_grad_ep_mapping[endpoint]["params"]:
orig_var_name, block_name, _ = self._get_varname_parts(param.name)
if not block_name:
continue
block_idx = int(block_name.split(block_suffix)[1])
orig_var = self.origin_program.global_block().vars[orig_var_name]
skip_numel = 0
slice_vars = self.param_var_mapping[orig_var_name]
for slice_var in slice_vars[:block_idx]:
skip_numel += reduce(lambda x, y: x * y, slice_var.shape)
slice_vars_and_attrs.append([orig_var, skip_numel, param])
return slice_vars_and_attrs
# ====================== private transpiler functions ===================== # ====================== private transpiler functions =====================
def _has_distributed_lookup_table(self): def _has_distributed_lookup_table(self):
......
...@@ -59,8 +59,12 @@ class InferenceTranspiler(object): ...@@ -59,8 +59,12 @@ class InferenceTranspiler(object):
scope = global_scope() scope = global_scope()
if not isinstance(scope, core.Scope): if not isinstance(scope, core.Scope):
raise TypeError("scope should be as Scope type or None") raise TypeError("scope should be as Scope type or None")
self._fuse_batch_norm(program, place, scope) use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False))
self._fuse_relu_mkldnn(program) if use_mkldnn:
self._fuse_relu_mkldnn(program)
self._fuse_conv_bias_mkldnn(program)
else:
self._fuse_batch_norm(program, place, scope)
def _fuse_relu_mkldnn(self, program): def _fuse_relu_mkldnn(self, program):
''' '''
...@@ -82,10 +86,6 @@ class InferenceTranspiler(object): ...@@ -82,10 +86,6 @@ class InferenceTranspiler(object):
:param program: program to transpile :param program: program to transpile
:type program: Program :type program: Program
''' '''
use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False))
if not use_mkldnn:
return
self.block = program.block(0) self.block = program.block(0)
i = 0 i = 0
...@@ -106,6 +106,69 @@ class InferenceTranspiler(object): ...@@ -106,6 +106,69 @@ class InferenceTranspiler(object):
# And a better solution will be considered later. # And a better solution will be considered later.
program = program.clone() program = program.clone()
def _fuse_conv_bias_mkldnn(self, program):
'''
Transpile the program by fused convolution and elementwise_add.
Replace conv2d and elementwise_add ops with a new conv2d op
based on an old conv2d op and the :math:`Bias` taken from
elementwise_add.
For input :math:`X`:
- Conv process: :math:`X = input * W`
- Elementwise_add process: :math` X = X + bias`
After fuse into one operation:
.. math::
X = input * W + bias
The operator transformation is:
- before:
- conv->elementwise_add->any_other_op
- after:
- conv->any_other_op
The transpile stages are:
1. Extract bias and output variables from elementwise_add.
2. Extract Input, Weight and attributes from conv op.
3. Create a new convolution op based on extracted params.
4. Remove old conv op.
5. Remove elementwise_add.
5. Remove unused variables.
Args:
program (Program): program to transpile
'''
self.block = program.block(0)
i = 0
while i < len(self.block.ops) - 2:
current_op = self.block.ops[i]
next_op = self.block.ops[i + 1]
# conv2d with bias
if current_op.type in ['conv2d'] and \
next_op.type in ['elementwise_add']:
self._fuse_conv_bias(i, current_op, next_op)
self.block._remove_op(i + 1) # Remove old conv
self.block._remove_op(i + 1) # Remove elementwise_add
i = i + 1
i = i + 1
self._remove_unused_var()
# TODO(luotao): use clone() method to flush the program.desc in force,
# since some large program.desc will not be flushed immediately.
# And a better solution will be considered later.
program = program.clone()
def _fuse_batch_norm(self, program, place, scope): def _fuse_batch_norm(self, program, place, scope):
''' '''
Transpile the program by fused batch normalization. Transpile the program by fused batch normalization.
...@@ -185,7 +248,6 @@ class InferenceTranspiler(object): ...@@ -185,7 +248,6 @@ class InferenceTranspiler(object):
self.block._remove_op(i + 2) self.block._remove_op(i + 2)
i = i + 1 i = i + 1
i = i + 1 i = i + 1
self._adjust_input() self._adjust_input()
self._remove_unused_var() self._remove_unused_var()
# TODO(luotao): use clone() method to flush the program.desc in force, # TODO(luotao): use clone() method to flush the program.desc in force,
...@@ -288,6 +350,33 @@ class InferenceTranspiler(object): ...@@ -288,6 +350,33 @@ class InferenceTranspiler(object):
# collect the renamed input # collect the renamed input
self.input_map[bn_op.output("Y")[0]] = bias_op.output("Out")[0] self.input_map[bn_op.output("Y")[0]] = bias_op.output("Out")[0]
def _fuse_conv_bias(self, index, conv_op, elementwise_add_op):
'''
fuse the conv op with elementwise_add
:param index: index of the conv_op in ops list
:type index: Int
:param conv_op: convolution operator
:type conv_op: Operator
:param elementwise_add_op: convolution's bias operator
:type elementwise_add_op: Operator
'''
bias_var = self.block.var(elementwise_add_op.input("Y")[0])
out_var = self.block.var(elementwise_add_op.output("Out")[0])
filter_var = self.block.var(conv_op.input("Filter")[0])
in_var = self.block.var(conv_op.input("Input")[0])
attrs = {name: conv_op.attr(name) for name in conv_op.attr_names}
self.block._insert_op(
index,
type="conv2d",
inputs={"Input": in_var,
"Filter": filter_var,
"Bias": bias_var},
outputs={"Output": out_var},
attrs=attrs)
def _adjust_input(self): def _adjust_input(self):
for i in range(len(self.block.ops)): for i in range(len(self.block.ops)):
current_op = self.block.ops[i] current_op = self.block.ops[i]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册