提交 359e2148 编写于 作者: W wangguibao

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into async_executor

...@@ -43,6 +43,7 @@ ...@@ -43,6 +43,7 @@
| qingqing01 | Qing-Qing Dang | | qingqing01 | Qing-Qing Dang |
| reyoung | Yang Yu | | reyoung | Yang Yu |
| Superjom | Chun-Wei Yan | | Superjom | Chun-Wei Yan |
| tensor-tang | Jian Tang |
| tianbingsz | Tian-Bing Xu | | tianbingsz | Tian-Bing Xu |
| tpatejko | Tomasz Patejko | | tpatejko | Tomasz Patejko |
| typhoonzero | Yi Wu | | typhoonzero | Yi Wu |
......
...@@ -164,7 +164,7 @@ endif() ...@@ -164,7 +164,7 @@ endif()
set(module "inference") set(module "inference")
copy(inference_lib DEPS ${inference_deps} copy(inference_lib DEPS ${inference_deps}
SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.* SRCS ${src_dir}/${module}/*.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/libpaddle_fluid.*
${src_dir}/${module}/api/paddle_inference_api.h ${src_dir}/${module}/api/paddle_*.h
${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h
DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} DSTS ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module} ${dst_dir}/${module}
) )
...@@ -202,10 +202,10 @@ copy(third_party DEPS fluid_lib_dist ...@@ -202,10 +202,10 @@ copy(third_party DEPS fluid_lib_dist
DSTS ${FLUID_INFERENCE_INSTALL_DIR} ${FLUID_INFERENCE_INSTALL_DIR} DSTS ${FLUID_INFERENCE_INSTALL_DIR} ${FLUID_INFERENCE_INSTALL_DIR}
) )
# only need libpaddle_fluid.so/a and paddle_inference_api.h for inference-only library # only need libpaddle_fluid.so/a and paddle_*.h for inference-only library
copy(inference_api_lib DEPS fluid_lib_dist copy(inference_api_lib DEPS fluid_lib_dist
SRCS ${FLUID_INSTALL_DIR}/paddle/fluid/inference/libpaddle_fluid.* SRCS ${FLUID_INSTALL_DIR}/paddle/fluid/inference/libpaddle_fluid.*
${FLUID_INSTALL_DIR}/paddle/fluid/inference/paddle_inference_api.h ${FLUID_INSTALL_DIR}/paddle/fluid/inference/paddle_*.h
DSTS ${FLUID_INFERENCE_INSTALL_DIR}/paddle/lib ${FLUID_INFERENCE_INSTALL_DIR}/paddle/include DSTS ${FLUID_INFERENCE_INSTALL_DIR}/paddle/lib ${FLUID_INFERENCE_INSTALL_DIR}/paddle/include
) )
......
...@@ -34,4 +34,5 @@ if(TENSORRT_FOUND) ...@@ -34,4 +34,5 @@ if(TENSORRT_FOUND)
"Current TensorRT version is v${TENSORRT_MAJOR_VERSION}. ") "Current TensorRT version is v${TENSORRT_MAJOR_VERSION}. ")
include_directories(${TENSORRT_INCLUDE_DIR}) include_directories(${TENSORRT_INCLUDE_DIR})
list(APPEND EXTERNAL_LIBS ${TENSORRT_LIBRARY}) list(APPEND EXTERNAL_LIBS ${TENSORRT_LIBRARY})
add_definitions(-DPADDLE_WITH_TENSORRT)
endif() endif()
...@@ -103,7 +103,7 @@ paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 's ...@@ -103,7 +103,7 @@ paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 's
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.layer_norm ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None)) paddle.fluid.layers.layer_norm ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None))
paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode'], varargs=None, keywords=None, defaults=(False, -100, False)) paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax'], varargs=None, keywords=None, defaults=(False, -100, False, False))
paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1)) paddle.fluid.layers.autoincreased_step_counter ArgSpec(args=['counter_name', 'begin', 'step'], varargs=None, keywords=None, defaults=(None, 1, 1))
...@@ -274,6 +274,7 @@ paddle.fluid.layers.hard_shrink ArgSpec(args=['x', 'threshold'], varargs=None, k ...@@ -274,6 +274,7 @@ paddle.fluid.layers.hard_shrink ArgSpec(args=['x', 'threshold'], varargs=None, k
paddle.fluid.layers.cumsum ArgSpec(args=['x', 'axis', 'exclusive', 'reverse'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.cumsum ArgSpec(args=['x', 'axis', 'exclusive', 'reverse'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.thresholded_relu ArgSpec(args=['x', 'threshold'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.thresholded_relu ArgSpec(args=['x', 'threshold'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.prior_box ArgSpec(args=['input', 'image', 'min_sizes', 'max_sizes', 'aspect_ratios', 'variance', 'flip', 'clip', 'steps', 'offset', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, [1.0], [0.1, 0.1, 0.2, 0.2], False, False, [0.0, 0.0], 0.5, None, False)) paddle.fluid.layers.prior_box ArgSpec(args=['input', 'image', 'min_sizes', 'max_sizes', 'aspect_ratios', 'variance', 'flip', 'clip', 'steps', 'offset', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, [1.0], [0.1, 0.1, 0.2, 0.2], False, False, [0.0, 0.0], 0.5, None, False))
paddle.fluid.layers.density_prior_box ArgSpec(args=['input', 'image', 'densities', 'fixed_sizes', 'fixed_ratios', 'variance', 'clip', 'steps', 'offset', 'name'], varargs=None, keywords=None, defaults=(None, None, None, [0.1, 0.1, 0.2, 0.2], False, [0.0, 0.0], 0.5, None))
paddle.fluid.layers.multi_box_head ArgSpec(args=['inputs', 'image', 'base_size', 'num_classes', 'aspect_ratios', 'min_ratio', 'max_ratio', 'min_sizes', 'max_sizes', 'steps', 'step_w', 'step_h', 'offset', 'variance', 'flip', 'clip', 'kernel_size', 'pad', 'stride', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, 0.5, [0.1, 0.1, 0.2, 0.2], True, False, 1, 0, 1, None, False)) paddle.fluid.layers.multi_box_head ArgSpec(args=['inputs', 'image', 'base_size', 'num_classes', 'aspect_ratios', 'min_ratio', 'max_ratio', 'min_sizes', 'max_sizes', 'steps', 'step_w', 'step_h', 'offset', 'variance', 'flip', 'clip', 'kernel_size', 'pad', 'stride', 'name', 'min_max_aspect_ratios_order'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None, None, 0.5, [0.1, 0.1, 0.2, 0.2], True, False, 1, 0, 1, None, False))
paddle.fluid.layers.bipartite_match ArgSpec(args=['dist_matrix', 'match_type', 'dist_threshold', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.bipartite_match ArgSpec(args=['dist_matrix', 'match_type', 'dist_threshold', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.layers.target_assign ArgSpec(args=['input', 'matched_indices', 'negative_indices', 'mismatch_value', 'name'], varargs=None, keywords=None, defaults=(None, None, None)) paddle.fluid.layers.target_assign ArgSpec(args=['input', 'matched_indices', 'negative_indices', 'mismatch_value', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
......
...@@ -30,8 +30,8 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( ...@@ -30,8 +30,8 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
local_scopes_(local_scopes), local_scopes_(local_scopes),
places_(places), places_(places),
graph_(std::move(graph)), graph_(std::move(graph)),
pool_(strategy.num_threads_ + pool_(strategy.num_threads_),
1), // add one more thread for generate op_deps prepare_pool_(1), // add one more thread for generate op_deps
fetch_ctxs_(places) { fetch_ctxs_(places) {
for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) { for (auto &op : ir::FilterByNodeWrapper<OpHandleBase>(*graph_)) {
int dep = static_cast<int>(op->NotReadyInputSize()); int dep = static_cast<int>(op->NotReadyInputSize());
...@@ -160,7 +160,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync( ...@@ -160,7 +160,7 @@ void FastThreadedSSAGraphExecutor::RunOpAsync(
}); });
} }
void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() { void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
atomic_op_deps_ = pool_.enqueue([&] { atomic_op_deps_ = prepare_pool_.enqueue([&] {
auto *op_deps = new std::unordered_map<OpHandleBase *, std::atomic<int>>; auto *op_deps = new std::unordered_map<OpHandleBase *, std::atomic<int>>;
for (auto &pair : op_deps_) { for (auto &pair : op_deps_) {
(*op_deps)[pair.first] = pair.second; (*op_deps)[pair.first] = pair.second;
......
...@@ -46,6 +46,7 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -46,6 +46,7 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::vector<OpHandleBase *> bootstrap_ops_; std::vector<OpHandleBase *> bootstrap_ops_;
::ThreadPool pool_; ::ThreadPool pool_;
::ThreadPool prepare_pool_;
platform::DeviceContextPool fetch_ctxs_; platform::DeviceContextPool fetch_ctxs_;
std::atomic<int> remaining_; std::atomic<int> remaining_;
......
...@@ -359,6 +359,7 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare( ...@@ -359,6 +359,7 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope, bool create_vars, bool create_local_scope, bool create_vars,
bool keep_kids) { bool keep_kids) {
PADDLE_ENFORCE_NOT_NULL(scope);
Scope* local_scope = scope; Scope* local_scope = scope;
if (create_vars) { if (create_vars) {
if (create_local_scope) { if (create_local_scope) {
......
...@@ -5,6 +5,7 @@ file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n") ...@@ -5,6 +5,7 @@ file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
# Usage: pass_library(target inference) will append to paddle_inference_pass.h # Usage: pass_library(target inference) will append to paddle_inference_pass.h
unset(INFER_IR_PASSES CACHE) # clear the global variable
function(pass_library TARGET DEST) function(pass_library TARGET DEST)
set(options "") set(options "")
set(oneValueArgs "") set(oneValueArgs "")
...@@ -15,10 +16,11 @@ function(pass_library TARGET DEST) ...@@ -15,10 +16,11 @@ function(pass_library TARGET DEST)
if (${DEST} STREQUAL "base" OR ${DEST} STREQUAL "inference") if (${DEST} STREQUAL "base" OR ${DEST} STREQUAL "inference")
message(STATUS "add pass ${TARGET} ${DEST}") message(STATUS "add pass ${TARGET} ${DEST}")
file(APPEND ${pass_file} "USE_PASS(${TARGET});\n") file(APPEND ${pass_file} "USE_PASS(${TARGET});\n")
set(PASS_LIBRARY ${TARGET} ${PASS_LIBRARY} PARENT_SCOPE) set(INFER_IR_PASSES ${INFER_IR_PASSES} ${TARGET} CACHE INTERNAL "")
endif() endif()
endfunction() endfunction()
cc_library(node SRCS node.cc DEPS proto_desc) cc_library(node SRCS node.cc DEPS proto_desc)
cc_library(graph SRCS graph.cc DEPS node pretty_log) cc_library(graph SRCS graph.cc DEPS node pretty_log)
cc_library(graph_helper SRCS graph_helper.cc DEPS graph) cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
......
...@@ -91,10 +91,10 @@ void FindWhileOp(Graph* graph) { ...@@ -91,10 +91,10 @@ void FindWhileOp(Graph* graph) {
#undef OP_SET_IN #undef OP_SET_IN
#undef OP_SET_OUT #undef OP_SET_OUT
auto* X = graph->RetriveNode(34); auto* X = graph->RetrieveNode(34);
auto* LSTMOUT = graph->RetriveNode(81); auto* LSTMOUT = graph->RetrieveNode(81);
auto* cell_init = graph->RetriveNode(6); auto* cell_init = graph->RetrieveNode(6);
auto* hidden_init = graph->RetriveNode(8); auto* hidden_init = graph->RetrieveNode(8);
auto* lstm_op = graph->CreateOpNode(&op_desc); auto* lstm_op = graph->CreateOpNode(&op_desc);
PrepareParameters(graph, param); PrepareParameters(graph, param);
......
...@@ -84,8 +84,6 @@ void CheckProgram(const ProgramDesc &program) { ...@@ -84,8 +84,6 @@ void CheckProgram(const ProgramDesc &program) {
Graph::Graph(const ProgramDesc &program) : program_(program) { Graph::Graph(const ProgramDesc &program) : program_(program) {
CheckProgram(program_); CheckProgram(program_);
// Make the nodes id start from 0.
Node::ResetId();
auto var_nodes = InitFromProgram(program_); auto var_nodes = InitFromProgram(program_);
ResolveHazard(var_nodes); ResolveHazard(var_nodes);
} }
......
...@@ -116,13 +116,17 @@ class Graph { ...@@ -116,13 +116,17 @@ class Graph {
// Create a normal variable with non-null VarDesc. // Create a normal variable with non-null VarDesc.
ir::Node *CreateVarNode(VarDesc *var_desc) { ir::Node *CreateVarNode(VarDesc *var_desc) {
PADDLE_ENFORCE(var_desc); PADDLE_ENFORCE(var_desc);
return AddNode(new ir::Node(var_desc)); auto *x = AddNode(new ir::Node(var_desc));
x->SetId(num_node_created_++);
return x;
} }
// Create a normal runnable operator with OpDesc. // Create a normal runnable operator with OpDesc.
ir::Node *CreateOpNode(OpDesc *op_desc) { ir::Node *CreateOpNode(OpDesc *op_desc) {
PADDLE_ENFORCE(op_desc); PADDLE_ENFORCE(op_desc);
return AddNode(new ir::Node(op_desc)); auto *x = AddNode(new ir::Node(op_desc));
x->SetId(num_node_created_++);
return x;
} }
// Create a control dependency var that connects 2 operations. The // Create a control dependency var that connects 2 operations. The
...@@ -132,13 +136,17 @@ class Graph { ...@@ -132,13 +136,17 @@ class Graph {
// TODO(panyx0718): control var name should be really unique. // TODO(panyx0718): control var name should be really unique.
const std::string name = string::Sprintf( const std::string name = string::Sprintf(
"%s@%llu", ir::Node::kControlDepVarName, node_set_.size()); "%s@%llu", ir::Node::kControlDepVarName, node_set_.size());
return AddNode(new ir::Node(name, ir::Node::Type::kVariable)); auto *x = AddNode(new ir::Node(name, ir::Node::Type::kVariable));
x->SetId(num_node_created_++);
return x;
} }
// A more free style way of creating a graph node. Mostly use for test // A more free style way of creating a graph node. Mostly use for test
// or "copy" from another node. Avoid using it if possible. // or "copy" from another node. Avoid using it if possible.
ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) { ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
return AddNode(new ir::Node(name, type)); auto *x = AddNode(new ir::Node(name, type));
x->SetId(num_node_created_++);
return x;
} }
// Clear all node information of the graph and return the ownership of the // Clear all node information of the graph and return the ownership of the
...@@ -160,7 +168,7 @@ class Graph { ...@@ -160,7 +168,7 @@ class Graph {
} }
// NOTE low performance, but simple and secure. // NOTE low performance, but simple and secure.
Node *RetriveNode(int id) { Node *RetrieveNode(int id) {
for (auto &node : nodes_) { for (auto &node : nodes_) {
if (node.second->id() == id) { if (node.second->id() == id) {
return node.second.get(); return node.second.get();
...@@ -169,6 +177,7 @@ class Graph { ...@@ -169,6 +177,7 @@ class Graph {
return nullptr; return nullptr;
} }
const ProgramDesc &program() const { return program_; }
std::map<std::string, std::vector<ir::Node *>> InitFromProgram( std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
const ProgramDesc &program); const ProgramDesc &program);
...@@ -190,6 +199,7 @@ class Graph { ...@@ -190,6 +199,7 @@ class Graph {
std::map<std::string, std::function<void(void)>> attr_dels_; std::map<std::string, std::function<void(void)>> attr_dels_;
std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_; std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
std::unordered_set<ir::Node *> node_set_; std::unordered_set<ir::Node *> node_set_;
size_t num_node_created_{0}; // help to generate a unique node id.
}; };
bool IsControlDepVar(const ir::Node &var); bool IsControlDepVar(const ir::Node &var);
......
...@@ -167,10 +167,12 @@ struct HitGroup { ...@@ -167,10 +167,12 @@ struct HitGroup {
bool Match(Node *node, PDNode *pat) { bool Match(Node *node, PDNode *pat) {
if (nodes_.count(node)) { if (nodes_.count(node)) {
if (!roles.count(pat)) return false; if (roles.count(pat) && roles[pat] == node) return true;
return roles[pat] == node; return false;
} else {
if (roles.count(pat) && roles[pat] != node) return false;
return true;
} }
return !roles.count(pat) || roles.at(pat) == node;
} }
void Register(Node *node, PDNode *pat) { void Register(Node *node, PDNode *pat) {
...@@ -198,7 +200,6 @@ GraphPatternDetector::DetectPatterns() { ...@@ -198,7 +200,6 @@ GraphPatternDetector::DetectPatterns() {
std::vector<GraphPatternDetector::subgraph_t> result; std::vector<GraphPatternDetector::subgraph_t> result;
std::vector<HitGroup> init_groups; std::vector<HitGroup> init_groups;
std::array<std::vector<HitGroup>, 2> bi_records; std::array<std::vector<HitGroup>, 2> bi_records;
// PADDLE_ENFORCE(!pattern_.edges().empty(), "At least one edge is needed");
auto *first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get() auto *first_pnode = pattern_.edges().empty() ? pattern().nodes().front().get()
: pattern_.edges().front().first; : pattern_.edges().front().first;
if (!pdnodes2nodes_.count(first_pnode)) return result; if (!pdnodes2nodes_.count(first_pnode)) return result;
...@@ -228,11 +229,12 @@ GraphPatternDetector::DetectPatterns() { ...@@ -228,11 +229,12 @@ GraphPatternDetector::DetectPatterns() {
VLOG(80) << "check " << source->id() << " -- " << target->id(); VLOG(80) << "check " << source->id() << " -- " << target->id();
// TODO(Superjomn) add some prune strategies. // TODO(Superjomn) add some prune strategies.
for (const auto &group : pre_groups) { for (const auto &group : pre_groups) {
if (IsNodesLink(source, target)) {
HitGroup new_group = group; HitGroup new_group = group;
if (IsNodesLink(source, target) && bool flag = new_group.Match(source, edge.first) &&
new_group.Match(source, edge.first)) { new_group.Match(target, edge.second);
if (flag) {
new_group.Register(source, edge.first); new_group.Register(source, edge.first);
if (new_group.Match(target, edge.second)) {
new_group.Register(target, edge.second); new_group.Register(target, edge.second);
cur_groups.push_back(new_group); cur_groups.push_back(new_group);
// TODO(Superjomn) need to unique // TODO(Superjomn) need to unique
......
...@@ -310,8 +310,8 @@ void GraphSafeRemoveNodes(Graph* graph, ...@@ -310,8 +310,8 @@ void GraphSafeRemoveNodes(Graph* graph,
const std::unordered_set<const Node*>& nodes); const std::unordered_set<const Node*>& nodes);
// Some pre-defined patterns those can be reused in multiple passes. // Some pre-defined patterns those can be reused in multiple passes.
// The related Fluid Layer or Op should be one pattern here for better reusage // The related Fluid Layer or Op should be one pattern here for better re-usage
// accross different fusion. // across different fusion.
namespace patterns { namespace patterns {
struct KeyCounter { struct KeyCounter {
......
...@@ -35,10 +35,11 @@ std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl( ...@@ -35,10 +35,11 @@ std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
new proto::ProgramDesc(*program.Proto())); new proto::ProgramDesc(*program.Proto()));
auto block = program_pb->mutable_blocks(kRootBlockIndex); auto block = program_pb->mutable_blocks(kRootBlockIndex);
block->set_idx(kRootBlockIndex);
block->clear_vars(); block->clear_vars();
std::unordered_set<std::string> visited_vars; std::unordered_set<std::string> visited_vars;
for (ir::Node* n : graph->Nodes()) { for (ir::Node* n : graph->Nodes()) {
if (n->NodeType() == ir::Node::Type::kVariable) { if (n->IsVar()) {
if (n->Var() && visited_vars.count(n->Var()->Name()) == 0) { if (n->Var() && visited_vars.count(n->Var()->Name()) == 0) {
visited_vars.insert(n->Var()->Name()); visited_vars.insert(n->Var()->Name());
block->add_vars()->MergeFrom(*n->Var()->Proto()); block->add_vars()->MergeFrom(*n->Var()->Proto());
......
...@@ -66,6 +66,76 @@ NodesDFSIterator &NodesDFSIterator::operator=(const NodesDFSIterator &other) { ...@@ -66,6 +66,76 @@ NodesDFSIterator &NodesDFSIterator::operator=(const NodesDFSIterator &other) {
} }
Node *NodesDFSIterator::operator->() { return stack_.top(); } Node *NodesDFSIterator::operator->() { return stack_.top(); }
inline bool CheckNodeIndegreeEquals(const Node &node, size_t n) {
return node.inputs.size() == n;
}
NodesTSIterator::NodesTSIterator(const std::vector<Node *> &source) {
PADDLE_ENFORCE(!source.empty(),
"Start points of topological sorting should not be empty!");
// CHECK all the inputs' in-degree is 0
for (auto *node : source) {
PADDLE_ENFORCE(CheckNodeIndegreeEquals(*node, 0));
}
std::unordered_set<Node *> visited;
std::unordered_set<Node *> to_visit{source.begin(), source.end()};
std::vector<Node *> inlink_visited;
while (!to_visit.empty()) {
std::vector<Node *> queue(to_visit.begin(), to_visit.end());
for (auto *p : queue) {
inlink_visited.clear();
std::copy_if(p->inputs.begin(), p->inputs.end(),
std::back_inserter(inlink_visited),
[&](Node *x) -> bool { return visited.count(x) != 0; });
if (inlink_visited.size() == p->inputs.size()) {
sorted_.push_back(p);
for (auto *_ : p->outputs) {
if (!visited.count(_)) {
to_visit.insert(_);
}
}
to_visit.erase(p);
visited.insert(p);
}
}
}
}
NodesTSIterator::NodesTSIterator(const NodesTSIterator &other)
: sorted_(other.sorted_), cursor_(other.cursor_) {}
Node &NodesTSIterator::operator*() {
PADDLE_ENFORCE_LT(cursor_, sorted_.size());
return *sorted_[cursor_];
}
NodesTSIterator &NodesTSIterator::operator++() {
if (++cursor_ >= sorted_.size()) {
sorted_.clear();
cursor_ = 0;
}
return *this;
}
NodesTSIterator &NodesTSIterator::operator=(const NodesTSIterator &other) {
cursor_ = other.cursor_;
sorted_ = other.sorted_;
return *this;
}
bool NodesTSIterator::operator==(const NodesTSIterator &other) {
return sorted_ == other.sorted_ && cursor_ == other.cursor_;
}
Node *NodesTSIterator::operator->() {
PADDLE_ENFORCE_LT(cursor_, sorted_.size());
return sorted_[cursor_];
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -62,6 +62,32 @@ struct NodesDFSIterator ...@@ -62,6 +62,32 @@ struct NodesDFSIterator
std::unordered_set<Node *> visited_; std::unordered_set<Node *> visited_;
}; };
// Topological sorting iterator on nodes.
struct NodesTSIterator
: public std::iterator<std::forward_iterator_tag, Node *> {
NodesTSIterator() = default;
NodesTSIterator(const std::vector<Node *> &source);
NodesTSIterator(NodesTSIterator &&other)
: sorted_(std::move(other.sorted_)), cursor_(other.cursor_) {
other.cursor_ = 0;
}
NodesTSIterator(const NodesTSIterator &other);
Node &operator*();
NodesTSIterator &operator++();
// TODO(Superjomn) current implementation just compare the first
// element, need to compare the graph and all the elements in the queue and
// set.
NodesTSIterator &operator=(const NodesTSIterator &other);
bool operator==(const NodesTSIterator &other);
bool operator!=(const NodesTSIterator &other) { return !(*this == other); }
Node *operator->();
private:
std::vector<Node *> sorted_;
size_t cursor_{0};
};
/* /*
* GraphTraits contains some graph traversal algorithms. * GraphTraits contains some graph traversal algorithms.
* *
...@@ -76,6 +102,14 @@ struct GraphTraits { ...@@ -76,6 +102,14 @@ struct GraphTraits {
NodesDFSIterator()); NodesDFSIterator());
} }
static iterator_range<NodesTSIterator> TS(const Graph &g) {
auto start_points = ExtractStartPoints(g);
PADDLE_ENFORCE(!start_points.empty());
NodesTSIterator x(start_points);
return iterator_range<NodesTSIterator>(NodesTSIterator(start_points),
NodesTSIterator());
}
private: private:
// The nodes those have no input will be treated as start points. // The nodes those have no input will be treated as start points.
static std::vector<Node *> ExtractStartPoints(const Graph &g) { static std::vector<Node *> ExtractStartPoints(const Graph &g) {
......
...@@ -18,7 +18,6 @@ namespace paddle { ...@@ -18,7 +18,6 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
constexpr char Node::kControlDepVarName[]; constexpr char Node::kControlDepVarName[];
int Node::count_ = 0;
std::unique_ptr<Node> CreateNodeForTest(const std::string& name, std::unique_ptr<Node> CreateNodeForTest(const std::string& name,
Node::Type type) { Node::Type type) {
......
...@@ -115,37 +115,30 @@ class Node { ...@@ -115,37 +115,30 @@ class Node {
int id_; int id_;
private: private:
// ID can only set by a Graph.
void SetId(int id) { id_ = id; }
friend class Graph; friend class Graph;
friend std::unique_ptr<Node> CreateNodeForTest(const std::string& name, friend std::unique_ptr<Node> CreateNodeForTest(const std::string& name,
Node::Type type); Node::Type type);
explicit Node(const std::string& name, Type type) explicit Node(const std::string& name, Type type)
: name_(name), : name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {}
var_desc_(nullptr),
op_desc_(nullptr),
type_(type),
id_(count_++) {}
explicit Node(VarDesc* var_desc) explicit Node(VarDesc* var_desc)
: name_(var_desc->Name()), : name_(var_desc->Name()),
var_desc_(new VarDesc(*var_desc)), var_desc_(new VarDesc(*var_desc)),
op_desc_(nullptr), op_desc_(nullptr),
type_(Type::kVariable), type_(Type::kVariable) {}
id_(count_++) {}
explicit Node(OpDesc* op_desc) explicit Node(OpDesc* op_desc)
: name_(op_desc->Type()), : name_(op_desc->Type()),
var_desc_(nullptr), var_desc_(nullptr),
op_desc_(new OpDesc(*op_desc, op_desc->Block())), op_desc_(new OpDesc(*op_desc, op_desc->Block())),
type_(Type::kOperation), type_(Type::kOperation) {}
id_(count_++) {}
Node() = delete; Node() = delete;
static int count_;
// Please don't use this API or make this public.
static void ResetId() { count_ = 0; }
boost::any wrapper_; boost::any wrapper_;
std::function<void(void)> wrapper_deleter_; std::function<void(void)> wrapper_deleter_;
std::type_index wrapper_type_ = std::type_index(typeid(void)); std::type_index wrapper_type_ = std::type_index(typeid(void));
......
...@@ -93,6 +93,7 @@ class Pass { ...@@ -93,6 +93,7 @@ class Pass {
protected: protected:
virtual std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const { virtual std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const {
LOG(FATAL) << "Calling virtual Pass not implemented."; LOG(FATAL) << "Calling virtual Pass not implemented.";
return graph;
} }
private: private:
......
...@@ -57,59 +57,57 @@ static void InitializeVariable(Variable *var, proto::VarType::Type var_type) { ...@@ -57,59 +57,57 @@ static void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
} }
} }
void NaiveExecutor::Prepare(Scope *parent_scope, void NaiveExecutor::Prepare(Scope *scope, const ProgramDesc &program_desc,
const ProgramDesc &program_desc, int block_id, int block_id, bool with_feed_fetch_ops) {
bool with_feed_fetch_ops) { if (!scope) {
if (!parent_scope) {
scope_ = new framework::Scope; scope_ = new framework::Scope;
} else { } else {
scope_ = &parent_scope->NewScope(); scope_ = scope;
} }
CreateVariables(program_desc, scope_, block_id);
VLOG(3) << "NaiveExecutor init with scope " << scope;
CreateOps(program_desc, block_id, with_feed_fetch_ops); CreateOps(program_desc, block_id, with_feed_fetch_ops);
} }
void NaiveExecutor::Run() { void NaiveExecutor::Run() {
for (auto &op : ops_) { for (auto &op : ops_) {
VLOG(40) << "run " << op->Type(); VLOG(3) << std::this_thread::get_id() << " run " << op->Type()
<< " on scope " << scope_;
op->Run(*scope_, place_); op->Run(*scope_, place_);
} }
} }
void NaiveExecutor::CreateVariables(const ProgramDesc &desc, Scope *scope, void NaiveExecutor::CreateVariables(const ProgramDesc &desc, int block_id,
int block_id) { bool persistable, Scope *scope) {
PADDLE_ENFORCE(scope); PADDLE_ENFORCE_NOT_NULL(scope);
auto &global_block = desc.Block(block_id); auto &global_block = desc.Block(block_id);
const Scope *ancestor_scope = scope; const auto *anc = scope;
while (ancestor_scope->parent()) { PADDLE_ENFORCE(anc->parent() != anc);
ancestor_scope = ancestor_scope->parent(); while (anc->parent()) {
anc = anc->parent();
} }
if (ancestor_scope != scope) {
for (auto &var : global_block.AllVars()) { for (auto &var : global_block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) { if (var->Name() == framework::kEmptyVarName) {
continue; continue;
} }
// Create persistable vars in ancestor scope.
if (var->Persistable()) { if (persistable == var->Persistable()) {
auto *ptr = const_cast<Scope *>(ancestor_scope)->Var(var->Name()); if (persistable) {
InitializeVariable(ptr, var->GetType()); if (!anc->FindVar(var->Name())) {
VLOG(30) << "Create Variable " << var->Name() auto *ptr = const_cast<Scope *>(anc)->Var(var->Name());
<< " global, which pointer is " << ptr; VLOG(3) << scope << " Create persistable variable " << var->Name()
} else { // Create temporary variables in local scope. << ", which pointer is " << ptr;
auto *ptr = scope->Var(var->Name());
InitializeVariable(ptr, var->GetType()); InitializeVariable(ptr, var->GetType());
VLOG(30) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
}
} }
} else { } else {
for (auto &var : global_block.AllVars()) { auto *ptr = const_cast<Scope *>(scope)->Var(var->Name());
auto *ptr = scope->Var(var->Name()); VLOG(3) << scope << " Create variable " << var->Name()
<< ", which pointer is " << ptr;
InitializeVariable(ptr, var->GetType()); InitializeVariable(ptr, var->GetType());
VLOG(30) << "Create variable " << var->Name() << ", which pointer is " }
<< ptr;
} }
} }
} }
......
...@@ -35,8 +35,14 @@ class NaiveExecutor { ...@@ -35,8 +35,14 @@ class NaiveExecutor {
// Create child scope. // Create child scope.
// Create variables. // Create variables.
// @with_feed_fetch_ops: whether to work with the feed and fetch operators. // @with_feed_fetch_ops: whether to work with the feed and fetch operators.
void Prepare(Scope* parent_scope, const ProgramDesc& program_desc, void Prepare(Scope* scope, const ProgramDesc& program_desc, int block_id,
int block_id, bool with_feed_fetch_ops); bool with_feed_fetch_ops);
// Create variables before head.
// Create parameters if persistable is ture, or create the temporary variables
// instead.
void CreateVariables(const ProgramDesc& desc, int block_id, bool persistable,
Scope* scope);
// Run all the operators. // Run all the operators.
void Run(); void Run();
...@@ -49,8 +55,6 @@ class NaiveExecutor { ...@@ -49,8 +55,6 @@ class NaiveExecutor {
void CleanFeedFetchOps(); void CleanFeedFetchOps();
protected: protected:
void CreateVariables(const ProgramDesc& desc, Scope* scope, int block_id);
void CreateOps(const ProgramDesc& desc, int block_id, void CreateOps(const ProgramDesc& desc, int block_id,
bool with_feed_fetch_ops); bool with_feed_fetch_ops);
......
...@@ -39,7 +39,7 @@ TEST(NaiveExecutor, Basic) { ...@@ -39,7 +39,7 @@ TEST(NaiveExecutor, Basic) {
auto place = platform::CPUPlace(); auto place = platform::CPUPlace();
NaiveExecutor exe(place); NaiveExecutor exe(place);
exe.Prepare(nullptr, program, 0, false /*with feed fetch ops*/); exe.Prepare(nullptr, program, 0, false);
auto* a_tensor = exe.FindTensor("a"); auto* a_tensor = exe.FindTensor("a");
auto* b_tensor = exe.FindTensor("b"); auto* b_tensor = exe.FindTensor("b");
auto* c_tensor = exe.FindTensor("c"); auto* c_tensor = exe.FindTensor("c");
......
...@@ -15,7 +15,9 @@ limitations under the License. */ ...@@ -15,7 +15,9 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include <memory> // for unique_ptr #include <memory> // for unique_ptr
#include <queue>
#include <set> #include <set>
#include <unordered_set>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
...@@ -36,6 +38,16 @@ DEFINE_double( ...@@ -36,6 +38,16 @@ DEFINE_double(
"Memory size threshold (GB) when the garbage collector clear tensors." "Memory size threshold (GB) when the garbage collector clear tensors."
"Disabled when this value is less than 0"); "Disabled when this value is less than 0");
// When in inference scenario, the scopes will not be written by two threads in
// a mean time, but a scope may be read by multiple threads concurrently, and
// the mutex will cause serious performance issue.
// So the mutex is disabled when `ON_INFER`.
#ifdef ON_INFER
#define SCOPE_LOCK_GUARD
#else
#define SCOPE_LOCK_GUARD std::lock_guard<std::mutex> lock(mutex_);
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -49,18 +61,18 @@ int64_t GetEagerDeletionThreshold() { ...@@ -49,18 +61,18 @@ int64_t GetEagerDeletionThreshold() {
Scope::~Scope() { DropKids(); } Scope::~Scope() { DropKids(); }
Scope& Scope::NewScope() const { Scope& Scope::NewScope() const {
std::lock_guard<std::mutex> lock(mutex_); SCOPE_LOCK_GUARD
kids_.push_back(new Scope(this)); kids_.push_back(new Scope(this));
return *kids_.back(); return *kids_.back();
} }
Variable* Scope::Var(const std::string& name) { Variable* Scope::Var(const std::string& name) {
std::lock_guard<std::mutex> lock(mutex_); SCOPE_LOCK_GUARD
return VarInternal(name); return VarInternal(name);
} }
Variable* Scope::Var(std::string* name) { Variable* Scope::Var(std::string* name) {
std::lock_guard<std::mutex> lock(mutex_); SCOPE_LOCK_GUARD
auto new_name = string::Sprintf("%p.%d", this, vars_.size()); auto new_name = string::Sprintf("%p.%d", this, vars_.size());
if (name != nullptr) { if (name != nullptr) {
*name = new_name; *name = new_name;
...@@ -69,34 +81,34 @@ Variable* Scope::Var(std::string* name) { ...@@ -69,34 +81,34 @@ Variable* Scope::Var(std::string* name) {
} }
Variable* Scope::FindVar(const std::string& name) const { Variable* Scope::FindVar(const std::string& name) const {
std::lock_guard<std::mutex> lock(mutex_); SCOPE_LOCK_GUARD
return FindVarInternal(name); return FindVarInternal(name);
} }
Variable* Scope::FindLocalVar(const std::string& name) const { Variable* Scope::FindLocalVar(const std::string& name) const {
std::lock_guard<std::mutex> lock(mutex_); SCOPE_LOCK_GUARD
return FindVarLocally(name); return FindVarLocally(name);
} }
const Scope* Scope::FindScope(const Variable* var) const { const Scope* Scope::FindScope(const Variable* var) const {
std::lock_guard<std::mutex> lock(mutex_); SCOPE_LOCK_GUARD
return FindScopeInternal(var); return FindScopeInternal(var);
} }
void Scope::DropKids() { void Scope::DropKids() {
std::lock_guard<std::mutex> lock(mutex_); SCOPE_LOCK_GUARD
for (Scope* s : kids_) delete s; for (Scope* s : kids_) delete s;
kids_.clear(); kids_.clear();
} }
bool Scope::HasKid(const Scope* scope) const { bool Scope::HasKid(const Scope* scope) const {
std::lock_guard<std::mutex> lock(mutex_); SCOPE_LOCK_GUARD
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope); auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
return it != this->kids_.end(); return it != this->kids_.end();
} }
std::vector<std::string> Scope::LocalVarNames() const { std::vector<std::string> Scope::LocalVarNames() const {
std::lock_guard<std::mutex> lock(mutex_); SCOPE_LOCK_GUARD
std::vector<std::string> known_vars; std::vector<std::string> known_vars;
known_vars.reserve(this->vars_.size()); known_vars.reserve(this->vars_.size());
for (auto& p : vars_) { for (auto& p : vars_) {
...@@ -106,9 +118,10 @@ std::vector<std::string> Scope::LocalVarNames() const { ...@@ -106,9 +118,10 @@ std::vector<std::string> Scope::LocalVarNames() const {
} }
void Scope::DeleteScope(Scope* scope) const { void Scope::DeleteScope(Scope* scope) const {
std::lock_guard<std::mutex> lock(mutex_); SCOPE_LOCK_GUARD
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope); auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope); PADDLE_ENFORCE(it != this->kids_.end(), "%p Cannot find %p as kid scope",
this, scope);
this->kids_.erase(it); this->kids_.erase(it);
// When making memory benchmark on Fluid, we have to delete scope sync. // When making memory benchmark on Fluid, we have to delete scope sync.
if (FLAGS_benchmark || FLAGS_eager_delete_scope) { if (FLAGS_benchmark || FLAGS_eager_delete_scope) {
...@@ -119,7 +132,7 @@ void Scope::DeleteScope(Scope* scope) const { ...@@ -119,7 +132,7 @@ void Scope::DeleteScope(Scope* scope) const {
} }
void Scope::EraseVars(const std::vector<std::string>& var_names) { void Scope::EraseVars(const std::vector<std::string>& var_names) {
std::lock_guard<std::mutex> lock(mutex_); SCOPE_LOCK_GUARD
std::set<std::string> var_set(var_names.begin(), var_names.end()); std::set<std::string> var_set(var_names.begin(), var_names.end());
for (auto it = vars_.begin(); it != vars_.end();) { for (auto it = vars_.begin(); it != vars_.end();) {
if (var_set.find(it->first) != var_set.end()) { if (var_set.find(it->first) != var_set.end()) {
...@@ -132,12 +145,12 @@ void Scope::EraseVars(const std::vector<std::string>& var_names) { ...@@ -132,12 +145,12 @@ void Scope::EraseVars(const std::vector<std::string>& var_names) {
void Scope::Rename(const std::string& origin_name, void Scope::Rename(const std::string& origin_name,
const std::string& new_name) const { const std::string& new_name) const {
std::lock_guard<std::mutex> lock(mutex_); SCOPE_LOCK_GUARD
RenameInternal(origin_name, new_name); RenameInternal(origin_name, new_name);
} }
std::string Scope::Rename(const std::string& origin_name) const { std::string Scope::Rename(const std::string& origin_name) const {
std::lock_guard<std::mutex> lock(mutex_); SCOPE_LOCK_GUARD
auto new_name = string::Sprintf("%p.%d", this, vars_.size()); auto new_name = string::Sprintf("%p.%d", this, vars_.size());
RenameInternal(origin_name, new_name); RenameInternal(origin_name, new_name);
return new_name; return new_name;
...@@ -189,5 +202,46 @@ Variable* Scope::FindVarLocally(const std::string& name) const { ...@@ -189,5 +202,46 @@ Variable* Scope::FindVarLocally(const std::string& name) const {
return nullptr; return nullptr;
} }
std::string GenScopeTreeDebugInfo(Scope* root) {
std::stringstream os;
if (!root) return "";
// level traversal
std::queue<Scope*> queue;
queue.push(root);
std::vector<Scope*> scopes;
while (!queue.empty()) {
auto* end = queue.back();
Scope* q = nullptr;
while (q != end) {
q = queue.front();
queue.pop();
os << q << " ";
scopes.push_back(q);
for (auto* c : q->kids()) {
queue.push(c);
}
}
// end of a level
os << "\n------------------------------------------\n";
}
os << "\nDetails:\n\n";
for (Scope* q : scopes) {
os << "====\n";
os << q << ":\n";
for (auto& var : q->LocalVarNames()) {
os << " - " << var << "\n";
}
}
return os.str();
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -78,11 +78,11 @@ class Scope { ...@@ -78,11 +78,11 @@ class Scope {
/// Drop all kids scopes belonged to this scope. /// Drop all kids scopes belonged to this scope.
void DropKids(); void DropKids();
std::list<Scope*>& kids() const { return kids_; }
/// Find if a scope exists in the kid scopes /// Find if a scope exists in the kid scopes
bool HasKid(const Scope* scope) const; bool HasKid(const Scope* scope) const;
const std::list<Scope*>& kids() const { return kids_; }
// enumerate all the variables current contains. // enumerate all the variables current contains.
std::vector<std::string> LocalVarNames() const; std::vector<std::string> LocalVarNames() const;
...@@ -118,12 +118,17 @@ class Scope { ...@@ -118,12 +118,17 @@ class Scope {
// Scope in `kids_` are owned by this class. // Scope in `kids_` are owned by this class.
mutable std::list<Scope*> kids_; mutable std::list<Scope*> kids_;
Scope const* parent_{nullptr}; const Scope* parent_{nullptr};
DISABLE_COPY_AND_ASSIGN(Scope); DISABLE_COPY_AND_ASSIGN(Scope);
private: private:
mutable std::mutex mutex_; mutable std::mutex mutex_;
}; };
// Generate some debug string about the inherience structure of scope, quite
// naive.
std::string GenScopeTreeDebugInfo(Scope*);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -63,6 +63,26 @@ struct TensorCopyVisitor { ...@@ -63,6 +63,26 @@ struct TensorCopyVisitor {
int64_t size_; int64_t size_;
}; };
struct TensorFillVisitor {
TensorFillVisitor(framework::Tensor* dst, int64_t dst_offset, int64_t size,
float value)
: dst_(dst), dst_offset_(dst_offset), size_(size) {}
template <typename T>
void apply() const {
// TODO(qiao): support other place
platform::CPUPlace cpu;
auto* tensor_data = dst_->mutable_data<T>(cpu);
auto* start = tensor_data + dst_offset_;
auto* end = start + size_;
std::fill(start, end, static_cast<T>(0.0));
}
framework::Tensor* dst_;
int64_t dst_offset_;
int64_t size_;
};
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows, void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
const platform::DeviceContext& dev_ctx) { const platform::DeviceContext& dev_ctx) {
{ // the 1st field, uint32_t version { // the 1st field, uint32_t version
...@@ -120,7 +140,17 @@ bool SelectedRows::HasKey(int64_t key) const { ...@@ -120,7 +140,17 @@ bool SelectedRows::HasKey(int64_t key) const {
: true; : true;
} }
int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown) { int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown,
bool is_test) {
if (is_test) {
auto iter = id_to_index_.find(key);
if (iter == id_to_index_.end()) {
return -1;
} else {
return iter->second;
}
}
rwlock_->RDLock(); rwlock_->RDLock();
auto iter = id_to_index_.find(key); auto iter = id_to_index_.find(key);
if (iter == id_to_index_.end()) { if (iter == id_to_index_.end()) {
...@@ -172,7 +202,7 @@ void SelectedRows::SyncIndex() { ...@@ -172,7 +202,7 @@ void SelectedRows::SyncIndex() {
} }
void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value, void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
bool auto_grown) { bool auto_grown, bool is_test) {
PADDLE_ENFORCE(value->IsInitialized(), PADDLE_ENFORCE(value->IsInitialized(),
"The value tensor should be initialized."); "The value tensor should be initialized.");
if (ids.numel() == 0) { if (ids.numel() == 0) {
...@@ -183,13 +213,21 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value, ...@@ -183,13 +213,21 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
"output tensor should have the same shape with table " "output tensor should have the same shape with table "
"except the dims[0]."); "except the dims[0].");
for (int i = 0; i < ids.numel(); ++i) { for (int i = 0; i < ids.numel(); ++i) {
int64_t index = AutoGrownIndex(ids.data<int64_t>()[i], auto_grown); auto id = ids.data<int64_t>()[i];
int64_t index = AutoGrownIndex(id, auto_grown, is_test);
if (index < 0) {
VLOG(5) << "id " << id << " not in the table, return 0";
framework::VisitDataType(
framework::ToDataType(value_->type()),
TensorFillVisitor(value, i * value_width, value_width, 0.0));
} else {
framework::VisitDataType( framework::VisitDataType(
framework::ToDataType(value_->type()), framework::ToDataType(value_->type()),
TensorCopyVisitor(value, i * value_width, *value_.get(), TensorCopyVisitor(value, i * value_width, *value_.get(),
index * value_width, value_width)); index * value_width, value_width));
} }
} }
}
} }
} // namespace framework } // namespace framework
......
...@@ -105,7 +105,7 @@ class SelectedRows { ...@@ -105,7 +105,7 @@ class SelectedRows {
* the value * the value
*/ */
void Get(const framework::Tensor& ids, framework::Tensor* value, void Get(const framework::Tensor& ids, framework::Tensor* value,
bool auto_grown = false); bool auto_grown = false, bool is_test = false);
/* /*
* @brief Get the index of the key from id_to_index_ map. If the key not * @brief Get the index of the key from id_to_index_ map. If the key not
...@@ -118,7 +118,7 @@ class SelectedRows { ...@@ -118,7 +118,7 @@ class SelectedRows {
* *
* @return index of the key. * @return index of the key.
*/ */
int64_t AutoGrownIndex(int64_t key, bool auto_grown); int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false);
void SyncIndex(); void SyncIndex();
......
...@@ -84,10 +84,14 @@ TEST(SelectedRows, SparseTable) { ...@@ -84,10 +84,14 @@ TEST(SelectedRows, SparseTable) {
data[i * embedding_width + j] = static_cast<float>(i); data[i * embedding_width + j] = static_cast<float>(i);
} }
} }
ASSERT_EQ(table.AutoGrownIndex(10, true), 0); ASSERT_EQ(table.AutoGrownIndex(10, true, false), 0);
ASSERT_EQ(table.AutoGrownIndex(8, true), 1); ASSERT_EQ(table.AutoGrownIndex(8, true, false), 1);
ASSERT_EQ(table.AutoGrownIndex(8, true), 1); ASSERT_EQ(table.AutoGrownIndex(8, true, false), 1);
ASSERT_EQ(table.AutoGrownIndex(6, true), 2); ASSERT_EQ(table.AutoGrownIndex(6, true, false), 2);
for (int64_t i = 11; i < 20; i++) {
ASSERT_EQ(table.AutoGrownIndex(i, true, true), -1);
ASSERT_TRUE(!table.HasKey(i));
}
ASSERT_TRUE(table.HasKey(10)); ASSERT_TRUE(table.HasKey(10));
ASSERT_TRUE(table.HasKey(8)); ASSERT_TRUE(table.HasKey(8));
ASSERT_TRUE(table.HasKey(6)); ASSERT_TRUE(table.HasKey(6));
......
...@@ -29,13 +29,9 @@ set(SHARED_INFERENCE_SRCS ...@@ -29,13 +29,9 @@ set(SHARED_INFERENCE_SRCS
io.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api_impl.cc io.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/api_impl.cc
${CMAKE_CURRENT_SOURCE_DIR}/api/analysis_predictor.cc ${CMAKE_CURRENT_SOURCE_DIR}/api/analysis_predictor.cc
${CMAKE_CURRENT_SOURCE_DIR}/api/details/zero_copy_tensor.cc) ${CMAKE_CURRENT_SOURCE_DIR}/api/details/zero_copy_tensor.cc)
if (WITH_GPU AND TENSORRT_FOUND)
set(STATIC_INFERENCE_APIS ${STATIC_INFERENCE_APIS} paddle_inference_tensorrt_subgraph_engine)
set(SHARED_INFERENCE_SRCS ${SHARED_INFERENCE_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/api/api_tensorrt_subgraph_engine.cc)
endif()
# Create static library # Create static library
cc_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor reset_tensor_array) cc_library(paddle_fluid DEPS ${fluid_modules} ${STATIC_INFERENCE_APIS} zero_copy_tensor reset_tensor_array analysis_config paddle_pass_builder)
if(NOT APPLE) if(NOT APPLE)
# TODO(liuyiqu: Temporarily disable the link flag because it is not support on Mac. # TODO(liuyiqu: Temporarily disable the link flag because it is not support on Mac.
...@@ -45,7 +41,7 @@ endif() ...@@ -45,7 +41,7 @@ endif()
# Create shared library # Create shared library
cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS} cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS}
DEPS ${fluid_modules} paddle_fluid_api reset_tensor_array) DEPS ${fluid_modules} paddle_fluid_api reset_tensor_array analysis_config paddle_pass_builder)
set_target_properties(paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid) set_target_properties(paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid)
if(NOT APPLE) if(NOT APPLE)
......
cc_library(ir_pass_manager SRCS ir_pass_manager.cc DEPS graph pass) unset(analysis_deps CACHE)
set(analysis_deps set(analysis_deps # analysis_deps can be extended accross the project
framework_proto proto_desc ir_pass_manager graph pass paddle_fluid_api executor pretty_log) framework_proto proto_desc graph pass paddle_fluid_api executor pretty_log
ir_pass_manager
CACHE INTERNAL "")
cc_library(analysis SRCS pass_manager.cc node.cc data_flow_graph.cc graph_traits.cc subgraph_splitter.cc add_subdirectory(ir_passes)
add_subdirectory(passes)
cc_library(ir_pass_manager SRCS ir_pass_manager.cc DEPS graph pass ${INFER_IR_PASSES})
cc_library(argument SRCS argument.cc DEPS scope proto_desc)
cc_library(analysis_pass SRCS analysis_pass.cc DEPS proto_desc)
cc_library(analysis SRCS
analyzer.cc analyzer.cc
helper.cc helper.cc
# passes analysis_pass
analysis_pass.cc DEPS ${analysis_deps}
fluid_to_data_flow_graph_pass.cc )
data_flow_graph_to_fluid_pass.cc
dfg_graphviz_draw_pass.cc
tensorrt_subgraph_pass.cc
tensorrt_subgraph_node_mark_pass.cc
fluid_to_ir_pass.cc
model_store_pass.cc
DEPS ${analysis_deps})
cc_test(test_node SRCS node_tester.cc DEPS analysis)
cc_test(test_dot SRCS dot_tester.cc DEPS analysis) cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis paddle_fluid)
function(inference_analysis_test TARGET) function(inference_analysis_test TARGET)
if(WITH_TESTING) if(WITH_TESTING)
...@@ -34,13 +35,3 @@ function(inference_analysis_test TARGET) ...@@ -34,13 +35,3 @@ function(inference_analysis_test TARGET)
endfunction(inference_analysis_test) endfunction(inference_analysis_test)
inference_analysis_test(test_analyzer SRCS analyzer_tester.cc EXTRA_DEPS paddle_inference_api) inference_analysis_test(test_analyzer SRCS analyzer_tester.cc EXTRA_DEPS paddle_inference_api)
inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc)
inference_analysis_test(test_fluid_to_ir_pass SRCS fluid_to_ir_pass_tester.cc)
inference_analysis_test(test_fluid_to_data_flow_graph_pass SRCS fluid_to_data_flow_graph_pass_tester.cc)
inference_analysis_test(test_subgraph_splitter SRCS subgraph_splitter_tester.cc)
inference_analysis_test(test_dfg_graphviz_draw_pass SRCS dfg_graphviz_draw_pass_tester.cc)
inference_analysis_test(test_tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass_tester.cc)
inference_analysis_test(test_pass_manager SRCS pass_manager_tester.cc)
inference_analysis_test(test_tensorrt_subgraph_node_mark_pass SRCS tensorrt_subgraph_node_mark_pass_tester.cc)
inference_analysis_test(test_model_store_pass SRCS model_store_pass_tester.cc)
...@@ -19,42 +19,36 @@ limitations under the License. */ ...@@ -19,42 +19,36 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/inference/analysis/argument.h" #include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/node.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
/*
* AnalysisPass is a pass used to control the IR passes.
*/
class AnalysisPass { class AnalysisPass {
public: public:
AnalysisPass() = default; AnalysisPass() = default;
virtual ~AnalysisPass() = default; virtual ~AnalysisPass() = default;
// Mutable Pass.
virtual bool Initialize(Argument *argument) { return false; }
// Readonly Pass.
virtual bool Initialize(const Argument &argument) { return false; }
// Virtual method overriden by subclasses to do any necessary clean up after // Run on a single Graph.
// all passes have run. void Run(Argument* argument) { RunImpl(argument); }
virtual bool Finalize() { return false; }
// Create a debugger Pass that draw the DFG by graphviz toolkit.
virtual AnalysisPass *CreateGraphvizDebugerPass() const { return nullptr; }
// Run on a single DataFlowGraph.
virtual void Run(DataFlowGraph *x) = 0;
// Human-readable short representation. // Human-readable short representation.
virtual std::string repr() const = 0; virtual std::string repr() const = 0;
// Human-readable long description. // Human-readable long description.
virtual std::string description() const { return "No DOC"; } virtual std::string description() const { return "No DOC"; }
};
// GraphPass processes on any GraphType. protected:
class DataFlowGraphPass : public AnalysisPass {}; // User should implement these.
virtual void RunImpl(Argument* argument) = 0;
Argument* argument_{nullptr};
};
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
......
...@@ -15,138 +15,23 @@ ...@@ -15,138 +15,23 @@
#include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/analysis/analyzer.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.h"
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h" #include "paddle/fluid/inference/analysis/passes/passes.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h"
#include "paddle/fluid/inference/analysis/model_store_pass.h"
#include "paddle/fluid/inference/analysis/pass_manager.h"
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_node_mark_pass.h"
#include "paddle/fluid/inference/analysis/tensorrt_subgraph_pass.h"
DEFINE_bool(IA_enable_tensorrt_subgraph_engine, false,
"Enable subgraph to TensorRT engine for acceleration");
DEFINE_bool(IA_enable_ir, false, "Turn on IR support");
DEFINE_string(IA_graphviz_log_root, "./",
"Graphviz debuger for data flow graphs.");
DEFINE_string(IA_output_storage_path, "", "optimized model output path");
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
class DfgPassManagerImpl final : public DfgPassManager { Analyzer::Analyzer() {}
public:
DfgPassManagerImpl() {
// TODO(Superjomn) set the key with pass reprs.
if (!FLAGS_IA_enable_ir) {
AddPass("fluid-to-data-flow-graph", new FluidToDataFlowGraphPass);
} else {
AddPass("fluid-to-ir-pass", new FluidToIrPass);
}
TryAddTensorRtPass();
AddPass("data-flow-graph-to-fluid", new DataFlowGraphToFluidPass);
if (!FLAGS_IA_output_storage_path.empty()) {
AddPass("model-store-pass", new ModelStorePass);
}
}
std::string repr() const override { return "dfg-pass-manager"; }
std::string description() const override { return "DFG pass manager."; }
private:
void AddPass(const std::string& name, AnalysisPass* pass) {
VLOG(30) << "Adding pass " << name;
Register(name, pass);
AddGraphvizDebugerPass(pass);
}
void TryAddTensorRtPass() {
if (FLAGS_IA_enable_tensorrt_subgraph_engine) {
auto trt_teller = [&](const Node* node) {
std::unordered_set<std::string> teller_set(
{"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
"depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
"elementwise_add", "dropout"});
if (!node->IsFunction()) return false;
const auto* func = static_cast<const Function*>(node); void Analyzer::Run(Argument *argument) { RunIrAnalysis(argument); }
if (teller_set.count(func->func_type())) {
return true;
} else {
return false;
}
};
AddPass("tensorrt-subgraph-marker", void Analyzer::RunIrAnalysis(Argument *argument) {
new TensorRTSubgraphNodeMarkPass(trt_teller)); std::vector<std::string> passes({"ir_analysis_compose_pass"});
AddPass("tensorrt-subgraph", new TensorRTSubGraphPass(trt_teller));
}
}
// Add the graphviz debuger pass if the parent pass has one. for (auto &pass : passes) {
void AddGraphvizDebugerPass(AnalysisPass* pass) { PassRegistry::Global().Retreive(pass)->Run(argument);
auto* debuger_pass = pass->CreateGraphvizDebugerPass();
if (debuger_pass) {
Register(debuger_pass->repr(), debuger_pass);
}
} }
};
Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); }
void Analyzer::Run(Argument* argument) {
std::vector<std::string> passes;
passes.push_back("graph_viz_pass"); // add graphviz for debug.
#ifdef PADDLE_WITH_MKLDNN
if (use_mkldnn_) {
VLOG(30) << "Adding MKL-DNN placement pass";
passes.push_back("mkldnn_placement_pass");
}
#endif
// infer_clean_graph_pass should be the first default pass
// after mkldnn_placement_pass.
passes.push_back("infer_clean_graph_pass");
passes.push_back("graph_viz_pass"); // add graphviz for debug.
for (auto& pass : ir_passes_) {
// skip mkldnn pass when use_mkldnn_ = false;
bool skip_pass = (!use_mkldnn_) && pass.find("mkldnn") != std::string::npos;
if (!disabled_ir_passes_.count(pass) && !skip_pass) {
passes.push_back(pass);
passes.push_back("graph_viz_pass"); // add graphviz for debug.
}
}
argument->Set(kFluidToIrPassesAttr, new std::vector<std::string>(passes));
for (auto& x : data_) {
PADDLE_ENFORCE(x->Initialize(argument));
x->RunAll();
PADDLE_ENFORCE(x->Finalize());
}
}
Analyzer& Analyzer::IncludeAllIrPasses() {
ir_passes_ = all_ir_passes_;
return *this;
}
Analyzer& Analyzer::DisableIrPasses(const std::vector<std::string>& passes) {
disabled_ir_passes_.insert(passes.begin(), passes.end());
return *this;
}
Analyzer& Analyzer::IncludeIrPasses(const std::vector<std::string>& passes) {
ir_passes_ = passes;
return *this;
}
Analyzer& Analyzer::SetUseMkldnn(bool use_mkldnn) {
use_mkldnn_ = use_mkldnn;
return *this;
} }
} // namespace analysis } // namespace analysis
......
...@@ -40,56 +40,21 @@ limitations under the License. */ ...@@ -40,56 +40,21 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/inference/analysis/analysis_pass.h" #include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/flags.h" #include "paddle/fluid/inference/analysis/flags.h"
#include "paddle/fluid/inference/analysis/pass_manager.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
class Analyzer : public OrderedRegistry<PassManager> { class Analyzer final {
public: public:
// Register all the pass-managers.
Analyzer(); Analyzer();
void Run(Argument* argument); void Run(Argument* argument);
Analyzer& DisableIrPasses(const std::vector<std::string>& passes);
Analyzer& IncludeIrPasses(const std::vector<std::string>& passes);
Analyzer& IncludeAllIrPasses();
Analyzer& SetUseMkldnn(bool use_mkldnn);
DISABLE_COPY_AND_ASSIGN(Analyzer); DISABLE_COPY_AND_ASSIGN(Analyzer);
private: protected:
// All avaiable IR passes. void RunIrAnalysis(Argument* argument);
// The bigger fuse comes first, so that the small operators prefer to be
// merged in a larger fuse op. The small fusion will not break the pattern of
// larger fusion.
const std::vector<std::string> all_ir_passes_{{
// Manual update the passes here.
"attention_lstm_fuse_pass", //
"seqconv_eltadd_relu_fuse_pass", //
"embedding_fc_lstm_fuse_pass", //
"fc_lstm_fuse_pass", //
"mul_lstm_fuse_pass", //
"fc_gru_fuse_pass", //
"mul_gru_fuse_pass", //
"seq_concat_fc_fuse_pass", //
"fc_fuse_pass", //
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
#ifdef PADDLE_WITH_MKLDNN
"depthwise_conv_mkldnn_pass", //
"conv_bias_mkldnn_fuse_pass", //
"conv_relu_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass", //
#endif
}};
std::unordered_set<std::string> disabled_ir_passes_;
// Ir passes to run
std::vector<std::string> ir_passes_;
bool use_mkldnn_;
}; };
} // namespace analysis } // namespace analysis
......
...@@ -27,21 +27,21 @@ namespace analysis { ...@@ -27,21 +27,21 @@ namespace analysis {
using namespace framework; // NOLINT using namespace framework; // NOLINT
TEST(Analyzer, analysis_without_tensorrt) { TEST(Analyzer, analysis_without_tensorrt) {
FLAGS_IA_enable_tensorrt_subgraph_engine = false;
Argument argument; Argument argument;
argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir)); argument.SetModelDir(FLAGS_inference_model_dir);
argument.SetIrAnalysisPasses({"infer_clean_graph_pass"});
Analyzer analyser; Analyzer analyser;
analyser.Run(&argument); analyser.Run(&argument);
} }
TEST(Analyzer, analysis_with_tensorrt) { TEST(Analyzer, analysis_with_tensorrt) {
FLAGS_IA_enable_tensorrt_subgraph_engine = true;
Argument argument; Argument argument;
argument.Set<int>("minimum_subgraph_size", new int(0)); argument.SetTensorRtMaxBatchSize(3);
argument.Set<int>("max_batch_size", new int(3)); argument.SetTensorRtWorkspaceSize(1 << 20);
argument.Set<int>("workspace_size", new int(1 << 20)); argument.SetModelDir(FLAGS_inference_model_dir);
argument.Set<std::string>("precision_mode", new std::string("FP32")); argument.SetIrAnalysisPasses({"infer_clean_graph_pass"});
argument.fluid_model_dir.reset(new std::string(FLAGS_inference_model_dir));
Analyzer analyser; Analyzer analyser;
analyser.Run(&argument); analyser.Run(&argument);
} }
......
...@@ -24,13 +24,16 @@ ...@@ -24,13 +24,16 @@
#pragma once #pragma once
#include <string> #include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/variant.h" #include "paddle/fluid/platform/variant.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
using framework::ir::Graph;
/* /*
* The argument definition of both Pass and PassManagers. * The argument definition of both Pass and PassManagers.
...@@ -39,75 +42,99 @@ namespace analysis { ...@@ -39,75 +42,99 @@ namespace analysis {
*/ */
struct Argument { struct Argument {
Argument() = default; Argument() = default;
explicit Argument(const std::string& fluid_model_dir) explicit Argument(const std::string& model_dir) { SetModelDir(model_dir); }
: fluid_model_dir(new std::string(fluid_model_dir)) {}
// The directory of the trained model. using unique_ptr_t = std::unique_ptr<void, std::function<void(void*)>>;
std::unique_ptr<std::string> fluid_model_dir; using fusion_statis_t = std::unordered_map<std::string, int>;
// The path of `__model__` and `param`, this is used when the file name of
// model and param is changed. bool Has(const std::string& key) const { return valid_fields_.count(key); }
std::unique_ptr<std::string> fluid_model_program_path;
std::unique_ptr<std::string> fluid_model_param_path; #define DECL_ARGUMENT_FIELD(field__, Field, type__) \
public: \
// The graph that process by the Passes or PassManagers. type__& field__() { \
std::unique_ptr<DataFlowGraph> main_dfg; PADDLE_ENFORCE(Has(#field__)); \
return field__##_; \
// The original program desc. } \
std::unique_ptr<framework::proto::ProgramDesc> origin_program_desc; void Set##Field(const type__& x) { \
field__##_ = x; \
// The processed program desc. valid_fields_.insert(#field__); \
std::unique_ptr<framework::proto::ProgramDesc> transformed_program_desc; } \
DECL_ARGUMENT_FIELD_VALID(field__); \
// The output storage path of ModelStorePass. type__* field__##_ptr() { return &field__##_; } \
std::unique_ptr<std::string> model_output_store_path; \
private: \
// Support for any other attributes. type__ field__##_;
template <typename T>
void Set(const std::string& key, T* data) { #define DECL_ARGUMENT_FIELD_VALID(field__) \
PADDLE_ENFORCE_NOT_NULL(data); bool field__##_valid() { return Has(#field__); }
PADDLE_ENFORCE(!attrs_.count(key), "Duplicate set Argument's attr [%s]",
key); #define DECL_ARGUMENT_UNIQUE_FIELD(field__, Field, type__) \
attrs_[key] = data; public: \
attr_deleters_[key] = [data, key]() { type__& field__() { \
VLOG(30) << "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; PADDLE_ENFORCE_NOT_NULL(field__##_); \
VLOG(30) << "argument delete attr: " << key; PADDLE_ENFORCE(Has(#field__)); \
delete data; return *static_cast<type__*>(field__##_.get()); \
}; } \
} void Set##Field(type__* x) { \
field__##_ = \
bool Has(const std::string& name) const { return attrs_.count(name); } unique_ptr_t(x, [](void* x) { delete static_cast<type__*>(x); }); \
valid_fields_.insert(#field__); \
template <typename T> } \
T* Release(const std::string& key) { void Set##Field##NotOwned(type__* x) { \
PADDLE_ENFORCE(attrs_.count(key)); valid_fields_.insert(#field__); \
auto* res = boost::any_cast<T*>(attrs_.at(key)); field__##_ = unique_ptr_t(x, [](void* x) {}); \
attrs_.erase(key); } \
attr_deleters_.erase(key); DECL_ARGUMENT_FIELD_VALID(field__); \
return res; type__* field__##_ptr() { \
} PADDLE_ENFORCE(Has(#field__)); \
return static_cast<type__*>(field__##_.get()); \
template <typename T> } \
T& Get(const std::string& key) { type__* Release##Field() { \
PADDLE_ENFORCE(Has(key)); PADDLE_ENFORCE(Has(#field__)); \
return *boost::any_cast<T*>(attrs_.at(key)); valid_fields_.erase(#field__); \
} return static_cast<type__*>(field__##_.release()); \
} \
~Argument() { \
for (auto& item : attr_deleters_) { private: \
item.second(); unique_ptr_t field__##_;
}
} // Model path
DECL_ARGUMENT_FIELD(model_dir, ModelDir, std::string);
// Model specified with program and parameters files.
DECL_ARGUMENT_FIELD(model_program_path, ModelProgramPath, std::string);
DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
// The overall graph to work on.
DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph);
// The overall Scope to work on.
DECL_ARGUMENT_UNIQUE_FIELD(scope, Scope, framework::Scope);
DECL_ARGUMENT_UNIQUE_FIELD(main_program, MainProgram, framework::ProgramDesc);
// The ir passes to perform in analysis phase.
DECL_ARGUMENT_FIELD(ir_analysis_passes, IrAnalysisPasses,
std::vector<std::string>);
DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);
DECL_ARGUMENT_FIELD(tensorrt_node_teller, TensorRtNodeTeller,
std::function<bool(const framework::ir::Node*)>);
DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int);
// The program transformed by IR analysis phase.
DECL_ARGUMENT_UNIQUE_FIELD(ir_analyzed_program, IrAnalyzedProgram,
framework::proto::ProgramDesc);
DECL_ARGUMENT_FIELD(fusion_statis, FusionStatis, fusion_statis_t);
private: private:
std::unordered_map<std::string, boost::any> attrs_; std::unordered_set<std::string> valid_fields_;
std::unordered_map<std::string, std::function<void()>> attr_deleters_;
}; };
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0) #define ARGUMENT_CHECK_FIELD(argument__, fieldname__) \
#define ANALYSIS_ARGUMENT_CHECK_FIELD(field__) \ PADDLE_ENFORCE(argument__->Has(#fieldname__), \
if (UNLIKELY(!(field__))) { \ "the argument field [%s] should be set", #fieldname__);
LOG(ERROR) << "field " << #field__ << " should be set."; \
return false; \
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/inference/analysis/node.h"
namespace paddle {
namespace inference {
namespace analysis {
using ir_node_t = framework::ir::Node;
using ir_graph_t = framework::ir::Graph;
// It is a better idea that the inputs and outputs of this graph is set manually
// before, but there must be a Pass that helps to prune the unnecessary ops that
// do not contribute to the given targets, so in this pass, analysis and get the
// inputs and outputs is OK.
void DataFlowGraph::Build() {
inputs_.clear();
outputs_.clear();
std::unordered_set<Node *> ins;
std::unordered_set<Node *> outs;
for (auto &node : nodes.nodes()) {
for (auto *in : node->inlinks) {
ins.insert(in);
}
for (auto *out : node->outlinks) {
outs.insert(out);
}
}
// The nodes that in ins but not in outs is the graph's inputs
// similarly, the nodes that in outs but not in ins is the graphs' outputs
for (auto *in : ins) {
if (!outs.count(in)) {
inputs_.push_back(in);
}
}
for (auto *out : outs) {
if (!ins.count(out)) {
outputs_.push_back(out);
}
}
Clean();
}
void DataFlowGraph::Build(const framework::proto::ProgramDesc &prog) {
// insert vars
// The `var2id` keeps a map from a variable's name to its Node-id, the Node-id
// will keep updating to its latest alias during the graph-building.
std::unordered_map<std::string, size_t> var2id;
auto &main_block = prog.blocks(framework::kRootBlockIndex);
for (int i = 0; i < main_block.vars_size(); i++) {
const auto &var = main_block.vars(i);
auto *v = nodes.Create(Node::Type::kValue);
v->SetName(var.name());
v->SetPbDesc(const_cast<void *>(static_cast<const void *>(&var)));
v->SetPbMsg(var.SerializeAsString());
var2id[var.name()] = v->id();
}
// The variables in a SSA can only write once, so if a variable is written
// multiple times(quite common in our ProgramDesc design), multiple alias
// Nodes of this variable will be created, and each will just write once.
// An set that keep all the names of the variables(the original, not alias)
// that have been written(as outputs). Once an Op's output variable hit the
// set, it should create a new alias and update the global alias for this
// variable. And that make a Data Flow Graph a SSA.
std::unordered_set<Node *> unique_written_vars;
for (int i = 0; i < main_block.ops_size(); i++) {
const auto &op = main_block.ops(i);
auto *o = nodes.Create(Node::Type::kFunction);
o->SetName(op.type());
static_cast<Function *>(o)->SetFuncType(op.type());
// Link to the original protobuf message's memory, make it easier to
// generate from a data flow graph to fluid ProgramDesc.
o->SetPbDesc(const_cast<void *>(static_cast<const void *>(&op)));
o->SetPbMsg(op.SerializeAsString());
// set inputs and outputs
for (int j = 0; j < op.inputs_size(); j++) {
auto &in_var = op.inputs(j);
for (int k = 0; k < in_var.arguments_size(); k++) {
auto *in = nodes.GetMutable(var2id.at(in_var.arguments(k)));
in->outlinks.push_back(o);
o->inlinks.push_back(in);
unique_written_vars.insert(in);
}
}
for (int j = 0; j < op.outputs_size(); j++) {
auto &out_var = op.outputs(j);
for (int k = 0; k < out_var.arguments_size(); k++) {
auto *out = nodes.GetMutable(var2id[out_var.arguments(k)]);
if (unique_written_vars.count(out)) {
// Loop found, for example, a = op(a), use SSA, change to a1 = op(a).
auto *out_alias = nodes.Create(Node::Type::kValue);
out_alias->SetName(out->name());
out_alias->SetPbDesc(out->pb_desc());
out_alias->SetPbMsg(out->pb_msg());
var2id[out_alias->name()] =
out_alias->id(); // update variable's alias Node
VLOG(40) << "loop found in graph, create SSA alias node ["
<< out_alias->repr() << "] for [" << out->repr() << "]";
out = out_alias;
}
out->inlinks.push_back(o);
o->outlinks.push_back(out);
}
}
}
// Analysis and extract the inputs and outputs of this graph.
Build();
}
void DataFlowGraph::Build(const framework::ir::Graph &graph) {
// Create nodes
std::unordered_map<ir_node_t *, Node *> ir_node_map;
for (auto *ir_node : graph.Nodes()) {
Node *x{nullptr};
if (ir_node->IsOp()) {
PADDLE_ENFORCE(ir_node->Op());
VLOG(40) << "get op " << ir_node << " " << ir_node->Name();
x = nodes.Create(Node::Type::kFunction);
x->attr("ir_node").Pointer() = ir_node;
PADDLE_ENFORCE(ir_node->Op()->Proto());
x->SetName(ir_node->Op()->Proto()->type());
x->SetPbMsg(ir_node->Op()->Proto()->SerializeAsString());
} else if (ir_node->IsVar()) {
// Not create a Node for IR ControlDepVar, considering Inference currently
// just used in single thread scenerio.
VLOG(40) << "get var " << ir_node->Name();
x = nodes.Create(Node::Type::kValue);
x->attr("ir_node").Pointer() = ir_node;
x->SetName(ir_node->Name());
// x->SetPbMsg(ir_node->Var()->Proto()->SerializeAsString());
} else {
PADDLE_THROW("Failed to create an Node from IR, unknown type");
}
ir_node_map.emplace(ir_node, x);
}
VLOG(40) << "finish creating Nodes";
VLOG(40) << "to create edge";
// Create links
for (auto *ir_node : graph.Nodes()) {
auto it = ir_node_map.find(ir_node);
// Skip ControlDepVar.
if (it == ir_node_map.end()) continue;
auto *node = it->second;
for (auto *x : ir_node->inputs) {
if (!ir_node_map.count(x)) continue;
node->inlinks.push_back(ir_node_map.at(x));
}
for (auto *x : ir_node->outputs) {
if (!ir_node_map.count(x)) continue;
node->outlinks.push_back(ir_node_map.at(x));
}
}
Build();
PADDLE_ENFORCE(!inputs_.empty(),
"Can't deduce any inputs from the graph, Is the graph empty?");
ir_graph = &graph;
VLOG(30) << "finished build from IR";
}
void DataFlowGraph::Clean() {
for (auto &node : nodes.nodes()) {
std::unordered_set<Node *> inlinks_set(node->inlinks.begin(),
node->inlinks.end());
std::unordered_set<Node *> outlinks_set(node->outlinks.begin(),
node->outlinks.end());
if (inlinks_set.size() < node->inlinks.size()) {
node->inlinks.assign(inlinks_set.begin(), inlinks_set.end());
}
if (outlinks_set.size() < node->outlinks.size()) {
node->outlinks.assign(outlinks_set.begin(), outlinks_set.end());
}
}
}
std::string DataFlowGraph::DotString() const {
Dot dot;
// Add nodes
for (size_t i = 0; i < nodes.size(); i++) {
const Node &node = nodes.Get(i);
dot.AddNode(node.repr(), node.dot_attrs());
}
// Add edges
for (size_t i = 0; i < nodes.size(); i++) {
const Node &node = nodes.Get(i);
for (auto &in : node.inlinks) {
dot.AddEdge(in->repr(), node.repr(), {});
}
}
return dot.Build();
}
std::string DataFlowGraph::HumanReadableInfo(bool show_values,
bool show_functions) const {
std::stringstream values, functions;
for (auto &n : nodes.nodes()) {
if (show_values && n->IsValue()) {
values << n->repr() << "\n";
}
if (show_functions && n->IsFunction()) {
functions << n->repr() << "\n";
}
}
return "Values:\n" + values.str() + "\n\n" + "Functions:\n" + functions.str();
}
//
// NodesBFSIterator
//
GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator(
const std::vector<Node *> &source)
: queue_(source.begin(), source.end()) {}
GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator(
GraphTraits<DataFlowGraph>::NodesBFSIterator &&other) noexcept
: queue_(std::move(other.queue_)),
visited_(std::move(other.visited_)) {}
GraphTraits<DataFlowGraph>::NodesBFSIterator::NodesBFSIterator(
const GraphTraits<DataFlowGraph>::NodesBFSIterator &other)
: queue_(other.queue_), visited_(other.visited_) {}
Node &GraphTraits<DataFlowGraph>::NodesBFSIterator::operator*() {
PADDLE_ENFORCE(!queue_.empty());
return *queue_.front();
}
Node *GraphTraits<DataFlowGraph>::NodesBFSIterator::operator->() {
PADDLE_ENFORCE(!queue_.empty());
return queue_.front();
}
GraphTraits<DataFlowGraph>::NodesBFSIterator &
GraphTraits<DataFlowGraph>::NodesBFSIterator::operator=(
const GraphTraits<DataFlowGraph>::NodesBFSIterator &other) {
queue_ = other.queue_;
visited_ = other.visited_;
return *this;
}
GraphTraits<DataFlowGraph>::NodesBFSIterator
&GraphTraits<DataFlowGraph>::NodesBFSIterator::operator++() {
PADDLE_ENFORCE(!queue_.empty());
auto *cur = queue_.front();
visited_.insert(cur);
queue_.pop_front();
for (auto *output : cur->outlinks) {
if (!visited_.count(output)) {
queue_.push_back(output);
visited_.insert(output);
}
}
return *this;
}
bool GraphTraits<DataFlowGraph>::NodesBFSIterator::operator==(
const GraphTraits<DataFlowGraph>::NodesBFSIterator &other) {
if (queue_.empty()) return other.queue_.empty();
if ((!queue_.empty()) && (!other.queue_.empty())) {
return queue_.front() == other.queue_.front() &&
visited_.size() == other.visited_.size();
// equality of queue and
// visited. Just a light but week implementation.
}
return false;
}
//
// NodesDFSIterator
//
GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator(
const std::vector<Node *> &source) {
for (auto *x : source) stack_.push(x);
}
GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator(
GraphTraits<DataFlowGraph>::NodesDFSIterator &&other) noexcept
: stack_(std::move(other.stack_)),
visited_(std::move(other.visited_)) {}
GraphTraits<DataFlowGraph>::NodesDFSIterator::NodesDFSIterator(
const GraphTraits<DataFlowGraph>::NodesDFSIterator &other)
: stack_(other.stack_), visited_(other.visited_) {}
Node &GraphTraits<DataFlowGraph>::NodesDFSIterator::operator*() {
PADDLE_ENFORCE(!stack_.empty());
return *stack_.top();
}
GraphTraits<DataFlowGraph>::NodesDFSIterator
&GraphTraits<DataFlowGraph>::NodesDFSIterator::operator++() {
if (stack_.empty()) return *this;
visited_.insert(stack_.top());
auto *cur = stack_.top();
stack_.pop();
for (auto *x : cur->outlinks) {
if (!visited_.count(x)) {
stack_.push(x);
visited_.insert(x);
}
}
return *this;
}
bool GraphTraits<DataFlowGraph>::NodesDFSIterator::operator==(
const GraphTraits<DataFlowGraph>::NodesDFSIterator &other) {
if (stack_.empty()) return other.stack_.empty();
if ((!stack_.empty()) && (!other.stack_.empty())) {
return stack_.top() == other.stack_.top();
}
return false;
}
GraphTraits<DataFlowGraph>::NodesDFSIterator &
GraphTraits<DataFlowGraph>::NodesDFSIterator::operator=(
const GraphTraits<DataFlowGraph>::NodesDFSIterator &other) {
stack_ = other.stack_;
visited_ = other.visited_;
return *this;
}
Node *GraphTraits<DataFlowGraph>::NodesDFSIterator::operator->() {
return stack_.top();
}
inline bool CheckNodeIndegreeEquals(const Node &node, size_t n) {
return node.inlinks.size() == n;
}
GraphTraits<DataFlowGraph>::NodesTSIterator::NodesTSIterator(
const std::vector<Node *> &source) {
PADDLE_ENFORCE(!source.empty(),
"Start points of topological sorting should not be empty!");
// CHECK all the inputs' in-degree is 0
for (auto *node : source) {
PADDLE_ENFORCE(CheckNodeIndegreeEquals(*node, 0));
}
std::unordered_set<Node *> visited;
std::unordered_set<Node *> to_visit{source.begin(), source.end()};
std::vector<Node *> inlink_visited;
while (!to_visit.empty()) {
std::vector<Node *> queue(to_visit.begin(), to_visit.end());
for (auto *p : queue) {
if (p->deleted()) {
visited.insert(p);
to_visit.erase(p);
continue;
}
inlink_visited.clear();
std::copy_if(p->inlinks.begin(), p->inlinks.end(),
std::back_inserter(inlink_visited),
[&](Node *x) { return visited.count(x); });
if (inlink_visited.size() == p->inlinks.size()) {
sorted_.push_back(p);
for (auto *_ : p->outlinks) {
if (!visited.count(_)) {
to_visit.insert(_);
}
}
to_visit.erase(p);
visited.insert(p);
}
}
}
}
GraphTraits<DataFlowGraph>::NodesTSIterator::NodesTSIterator(
const paddle::inference::analysis::GraphTraits<
DataFlowGraph>::NodesTSIterator &other)
: sorted_(other.sorted_), cursor_(other.cursor_) {}
Node &GraphTraits<DataFlowGraph>::NodesTSIterator::operator*() {
PADDLE_ENFORCE_LT(cursor_, sorted_.size());
return *sorted_[cursor_];
}
paddle::inference::analysis::GraphTraits<DataFlowGraph>::NodesTSIterator
&GraphTraits<DataFlowGraph>::NodesTSIterator::operator++() {
if (++cursor_ >= sorted_.size()) {
sorted_.clear();
cursor_ = 0;
}
return *this;
}
paddle::inference::analysis::GraphTraits<DataFlowGraph>::NodesTSIterator &
GraphTraits<DataFlowGraph>::NodesTSIterator::operator=(
const paddle::inference::analysis::GraphTraits<
DataFlowGraph>::NodesTSIterator &other) {
cursor_ = other.cursor_;
sorted_ = other.sorted_;
return *this;
}
bool GraphTraits<DataFlowGraph>::NodesTSIterator::operator==(
const paddle::inference::analysis::GraphTraits<
DataFlowGraph>::NodesTSIterator &other) {
return sorted_ == other.sorted_ && cursor_ == other.cursor_;
}
Node *GraphTraits<DataFlowGraph>::NodesTSIterator::operator->() {
PADDLE_ENFORCE_LT(cursor_, sorted_.size());
return sorted_[cursor_];
}
std::pair<std::vector<Node *>, std::vector<Node *>>
ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
std::unordered_set<Node *> nodes(graph.begin(), graph.end());
std::unordered_set<Node *> inputs;
std::unordered_set<Node *> outputs;
// Input a Value, check whether its inlink is in the subgraph.
auto inlink_in_subgraph = [&](Node *n) {
for (auto *in : n->inlinks) {
if (nodes.count(in)) return true;
}
return false;
};
for (auto &node : graph) {
for (auto *in : node->inlinks) {
// The Value that is written by nodes inside a sub-graph shouldn't be the
// input of the sub-graph.
if (!nodes.count(in) && in->type() == Node::Type::kValue &&
!inlink_in_subgraph(in)) {
inputs.insert(in);
}
}
for (auto *out : node->outlinks) {
if (!nodes.count(out) && out->type() == Node::Type::kValue) {
outputs.insert(out);
}
}
}
return std::make_pair(std::vector<Node *>(inputs.begin(), inputs.end()),
std::vector<Node *>(outputs.begin(), outputs.end()));
}
// Filter the Intermediate results of the subgraph node.
void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
std::vector<Node *> op_nodes;
for (auto &node : GraphTraits<DataFlowGraph>(*graph).nodes_in_TS()) {
if (node.type() == Node::Type::kValue || node.deleted()) {
continue;
}
op_nodes.push_back(&node);
}
size_t op_num = op_nodes.size();
for (size_t i = 0; i < op_num; i++) {
if (op_nodes[i]->type() == Node::Type::kFunction) continue;
std::unordered_set<std::string> follow_up_input_names;
for (size_t j = i + 1; j < op_num; j++) {
for (auto *in : op_nodes[j]->inlinks) {
follow_up_input_names.insert(in->name());
}
}
std::vector<Node *> filtered_subgraph_outlinks;
for (auto *out : op_nodes[i]->outlinks) {
if (follow_up_input_names.count(out->name())) {
filtered_subgraph_outlinks.push_back(out);
} else {
out->SetDeleted();
}
}
// The filtered_subgraph_outlinks may be empty.
op_nodes[i]->outlinks = filtered_subgraph_outlinks;
}
}
} // namespace analysis
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* Data flow graph is an pass that build the basic graph. It contains a graph
* and the iterators that enable the iteration over the graph.
*/
#pragma once
#include <deque>
#include <stack>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/inference/analysis/graph_traits.h"
#include "paddle/fluid/inference/analysis/node.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace analysis {
/*
* DataFlowGraph - A container of Value and Function Nodes.
*
* This is the base graph for any other type of graphs, such as SSA or CFG.
*/
struct DataFlowGraph {
NodeMap nodes;
// inputs and outputs are deduced from the graph.
// Used to interact with IR.
const framework::ir::Graph *ir_graph{nullptr};
// Extract inputs and outputs of the graph.
void Build();
void Build(const framework::proto::ProgramDesc &prog);
// Build a graph from ir::Graph.
void Build(const framework::ir::Graph &graph);
// Get an attribute.
AnyAttr &Attr(const std::string &key) { return attrs_[key]; }
// Output a DOT graph file for debug.
std::string DotString() const;
std::string HumanReadableInfo(bool show_values = true,
bool show_functions = true) const;
const std::vector<Node *> &inputs() const {
PADDLE_ENFORCE(!inputs_.empty(),
"No inputs are deduced, need to Build() first.");
return inputs_;
}
const std::vector<Node *> &outputs() const {
PADDLE_ENFORCE(!outputs_.empty(),
"No outputs are deduced, need to Build() first.");
return outputs_;
}
private:
mutable std::vector<Node *> inputs_;
mutable std::vector<Node *> outputs_;
std::unordered_map<std::string, AnyAttr> attrs_;
// Remove duplicate edges and so on.
void Clean();
};
/*
* An graph trait help to traverse the graph using BFS.
* The BFS start from a graph's inputs, the graph should be fully-connected, so
* that the iterator can reach the end.
*/
template <>
struct GraphTraits<DataFlowGraph> {
// BFS iterator on nodes.
struct NodesBFSIterator
: public std::iterator<std::forward_iterator_tag, Node *> {
NodesBFSIterator() = default;
explicit NodesBFSIterator(const std::vector<Node *> &source);
NodesBFSIterator(NodesBFSIterator &&other) noexcept;
// NOTE Heavy to use.
NodesBFSIterator(const NodesBFSIterator &other);
Node &operator*();
NodesBFSIterator &operator++();
Node *operator->();
// TODO(Superjomn) current implementation just compare the first
// element, need to compare the graph and all the elements in the queue and
// set.
NodesBFSIterator &operator=(const NodesBFSIterator &other);
bool operator==(const NodesBFSIterator &other);
bool operator!=(const NodesBFSIterator &other) { return !(*this == other); }
private:
std::deque<Node *> queue_;
std::unordered_set<Node *> visited_;
};
// DFS iterator on nodes.
struct NodesDFSIterator
: public std::iterator<std::forward_iterator_tag, Node *> {
NodesDFSIterator() = default;
NodesDFSIterator(const std::vector<Node *> &source);
NodesDFSIterator(NodesDFSIterator &&other) noexcept;
NodesDFSIterator(const NodesDFSIterator &other);
Node &operator*();
NodesDFSIterator &operator++();
// TODO(Superjomn) current implementation just compare the first
// element, need to compare the graph and all the elements in the queue and
// set.
NodesDFSIterator &operator=(const NodesDFSIterator &other);
bool operator==(const NodesDFSIterator &other);
bool operator!=(const NodesDFSIterator &other) { return !(*this == other); }
Node *operator->();
private:
std::stack<Node *> stack_;
std::unordered_set<Node *> visited_;
};
// Topological sorting iterator on nodes.
struct NodesTSIterator
: public std::iterator<std::forward_iterator_tag, Node *> {
NodesTSIterator() = default;
NodesTSIterator(const std::vector<Node *> &source);
NodesTSIterator(NodesTSIterator &&other)
: sorted_(std::move(other.sorted_)), cursor_(other.cursor_) {
other.cursor_ = 0;
}
NodesTSIterator(const NodesTSIterator &other);
Node &operator*();
NodesTSIterator &operator++();
// TODO(Superjomn) current implementation just compare the first
// element, need to compare the graph and all the elements in the queue and
// set.
NodesTSIterator &operator=(const NodesTSIterator &other);
bool operator==(const NodesTSIterator &other);
bool operator!=(const NodesTSIterator &other) { return !(*this == other); }
Node *operator->();
private:
std::vector<Node *> sorted_;
size_t cursor_{0};
};
explicit GraphTraits(const DataFlowGraph &graph) : graph_(graph) {}
// default use BFS to visit the nodes.
iterator_range<NodesBFSIterator> nodes() {
return iterator_range<NodesBFSIterator>(nodes_bfs_begin(), nodes_bfs_end());
}
iterator_range<NodesBFSIterator> nodes_in_BFS() {
return iterator_range<NodesBFSIterator>(nodes_bfs_begin(), nodes_bfs_end());
}
iterator_range<NodesDFSIterator> nodes_in_DFS() {
return iterator_range<NodesDFSIterator>(nodes_dfs_begin(), nodes_dfs_end());
}
iterator_range<NodesTSIterator> nodes_in_TS() {
return iterator_range<NodesTSIterator>(nodes_ts_begin(), nodes_ts_end());
}
private:
NodesBFSIterator nodes_bfs_begin() {
return NodesBFSIterator(graph_.inputs());
}
NodesBFSIterator nodes_bfs_end() { return NodesBFSIterator(); }
NodesDFSIterator nodes_dfs_begin() {
return NodesDFSIterator(graph_.inputs());
}
NodesDFSIterator nodes_dfs_end() { return NodesDFSIterator(); }
NodesTSIterator nodes_ts_begin() { return NodesTSIterator(graph_.inputs()); }
NodesTSIterator nodes_ts_end() { return NodesTSIterator(); }
private:
const DataFlowGraph &graph_;
};
// Extract the inputs and outputs of a graph. The inputs and outputs of a
// sub-graph is the inputs nodes and output nodes that doesn't inside the
// sub-graph.
std::pair<std::vector<Node *>, std::vector<Node *>>
ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph); // NOLINT
void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph);
} // namespace analysis
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
namespace paddle {
namespace inference {
namespace analysis {
TEST(DataFlowGraph, BFS) {
auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
auto dfg = ProgramDescToDFG(desc);
dfg.Build();
for (auto* in : dfg.inputs()) {
LOG(INFO) << "inputs: " << in->name() << " "
<< static_cast<int>(in->type());
}
for (auto* out : dfg.outputs()) {
LOG(INFO) << "outputs: " << out->name() << " "
<< static_cast<int>(out->type());
}
size_t count = 0;
for (auto& node : GraphTraits<DataFlowGraph>(dfg).nodes()) {
LOG(INFO) << "visiting " << node.name();
++count;
}
ASSERT_EQ(count, dfg.nodes.size());
}
TEST(DataFlowGraph, DFS) {
auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
DataFlowGraph dfg;
dfg.Build(desc);
size_t count = 0;
for (auto& node : GraphTraits<DataFlowGraph>(dfg).nodes_in_DFS()) {
LOG(INFO) << "visiting " << node.name();
++count;
}
ASSERT_EQ(count, dfg.nodes.size());
}
// Topological sorting.
/*
* Graph topology
* inputs: 0, 1, 2
* 0 -> 4
* 0 -> 5
* 1 -> 6
* 2 -> 7
* 4 -> 5
* 4 -> 7
* 4 -> 3
* 7 -> 3
*/
TEST(DataFlowGraph, TS) {
DataFlowGraph graph;
for (int i = 0; i < 8; i++) {
auto* node = graph.nodes.Create(Node::Type::kValue);
node->SetName("node-" + std::to_string(i));
}
auto add_link = [&](int i, int j) {
Node* source = graph.nodes.GetMutable(i);
Node* target = graph.nodes.GetMutable(j);
target->inlinks.push_back(source);
source->outlinks.push_back(target);
};
add_link(0, 4);
add_link(0, 5);
add_link(1, 6);
add_link(2, 7);
add_link(4, 5);
add_link(4, 7);
add_link(4, 3);
add_link(7, 3);
graph.Build();
auto its = GraphTraits<DataFlowGraph>(graph).nodes_in_TS();
std::vector<int> sorted_ids;
for (auto it = its.begin(); it != its.end(); ++it) {
LOG(INFO) << it->name();
sorted_ids.push_back(it->id());
}
// Assert a occurs prior to b in the sorted_ids.
auto assert_positive_sequence_pair = [&](int a, int b) {
auto a_offset = std::find(sorted_ids.begin(), sorted_ids.end(), a);
auto b_offset = std::find(sorted_ids.begin(), sorted_ids.end(), b);
ASSERT_LT(a_offset, b_offset);
};
assert_positive_sequence_pair(2, 7);
assert_positive_sequence_pair(7, 3);
assert_positive_sequence_pair(4, 3);
assert_positive_sequence_pair(0, 4);
assert_positive_sequence_pair(0, 5);
assert_positive_sequence_pair(1, 6);
assert_positive_sequence_pair(4, 5);
assert_positive_sequence_pair(4, 7);
}
TEST(DataFlowGraph, Build_ProgramDesc) {
auto desc = LoadProgramDesc(FLAGS_inference_model_dir + "/__model__");
DataFlowGraph graph;
graph.Build(desc);
ASSERT_EQ(graph.nodes.size(), 38UL);
}
void SetOp(framework::ProgramDesc* prog, const std::string& type,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetInput("Xs", inputs);
op->SetOutput("Xs", outputs);
op->SetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(framework::OpRole::kForward));
}
TEST(DataFlowGraph, Build_IR_Graph) {
framework::ProgramDesc prog;
for (auto& v : std::vector<std::string>({"a", "b", "c", "d", "e", "f"})) {
auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(framework::proto::VarType::SELECTED_ROWS);
if (v == "c") {
var->SetPersistable(true);
}
}
SetOp(&prog, "OP0", std::vector<std::string>({"a"}),
std::vector<std::string>({"b"}));
SetOp(&prog, "OP1", std::vector<std::string>({"a"}),
std::vector<std::string>({"c"}));
SetOp(&prog, "mul", std::vector<std::string>({"b", "c"}),
std::vector<std::string>({"d"}));
SetOp(&prog, "elementwise_add", std::vector<std::string>({"d", "e"}),
std::vector<std::string>({"f"}));
DataFlowGraph graph;
framework::ir::Graph ir_graph(prog);
graph.Build(ir_graph);
ASSERT_EQ(graph.nodes.size(), ir_graph.Nodes().size());
}
} // namespace analysis
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file implements the transformation from fluid ProgramDesc to data flow
* graph.
*/
#pragma once
#include <string>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
namespace paddle {
namespace inference {
namespace analysis {
class DataFlowGraphToFluidPass final : public DataFlowGraphPass {
public:
DataFlowGraphToFluidPass() = default;
bool Initialize(Argument *argument) override;
bool Finalize() override;
void Run(DataFlowGraph *graph) override;
std::string repr() const override { return "DFG to fluid"; }
std::string description() const override {
return "Transform a DFG to a Fluid ProgramDesc";
}
AnalysisPass *CreateGraphvizDebugerPass() const override;
protected:
// Add a Fluid Op into the ProgramDesc.
void AddFluidOp(Node *node);
// Add a EngineOp into the ProgramDesc.
void AddEngineOp(Node *node);
private:
framework::proto::ProgramDesc *desc_;
Argument *argument_;
};
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include <glog/logging.h>
#include <google/protobuf/text_format.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/io.h"
namespace paddle {
namespace inference {
namespace analysis {
TEST(DataFlowGraph, Test) {
Argument argument(FLAGS_inference_model_dir);
FluidToDataFlowGraphPass pass0;
DataFlowGraphToFluidPass pass1;
ASSERT_TRUE(pass0.Initialize(&argument));
ASSERT_TRUE(pass1.Initialize(&argument));
pass0.Run(argument.main_dfg.get());
pass1.Run(argument.main_dfg.get());
pass0.Finalize();
pass1.Finalize();
LOG(INFO) << argument.main_dfg->nodes.size();
}
}; // namespace analysis
}; // namespace inference
}; // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
namespace paddle {
namespace inference {
namespace analysis {
int DFG_GraphvizDrawPass::counter_{0};
void DFG_GraphvizDrawPass::Run(DataFlowGraph *graph) {
auto content = Draw(graph);
auto dot_path = GenDotPath();
std::ofstream file(dot_path);
file.write(content.c_str(), content.size());
file.close();
auto png_path = dot_path.substr(0, dot_path.size() - 4) + ".png";
std::string message;
VLOG(30) << "draw to " << png_path;
ExecShellCommand("dot -Tpng " + dot_path + " -o " + png_path, &message);
}
std::string DFG_GraphvizDrawPass::Draw(DataFlowGraph *graph) {
Dot dot;
// Add nodes
for (size_t i = 0; i < graph->nodes.size(); i++) {
const Node &node = graph->nodes.Get(i);
if (config_.display_deleted_node || !node.deleted()) {
dot.AddNode(node.repr(), node.dot_attrs());
}
}
// Add edges
for (size_t i = 0; i < graph->nodes.size(); i++) {
const Node &node = graph->nodes.Get(i);
if (!config_.display_deleted_node && node.deleted()) continue;
for (auto &out : node.outlinks) {
if (!config_.display_deleted_node && out->deleted()) continue;
dot.AddEdge(node.repr(), out->repr(), {});
}
}
return dot.Build();
}
} // namespace analysis
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file create an DFG_GraphvizDrawPass which helps to draw a data flow
* graph's structure using graphviz.
*/
#pragma once
#include <fstream>
#include <string>
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/dot.h"
namespace paddle {
namespace inference {
namespace analysis {
/*
* Output a dot file and write to some place.
*/
class DFG_GraphvizDrawPass : public DataFlowGraphPass {
public:
struct Config {
Config(const std::string &dir, const std::string &id,
bool display_deleted_node = false)
: dir(dir), id(id), display_deleted_node(display_deleted_node) {}
// The directory to store the .dot or .png files.
const std::string dir;
// The identifier for this dot file.
const std::string id;
// Whether to display deleted nodes, default false.
const bool display_deleted_node;
};
explicit DFG_GraphvizDrawPass(const Config &config) : config_(config) {}
bool Initialize(Argument *argument) override { return true; }
void Run(DataFlowGraph *graph) override;
bool Finalize() override { return true; }
std::string repr() const override { return "DFG graphviz drawer"; }
std::string description() const override {
return "Debug a DFG by draw with graphviz";
}
protected:
// A counter to add a number prefix to the debugger image output so that they
// will sort in the triggered order.
static int counter_;
// Path of the dot file to output.
std::string GenDotPath() const {
return config_.dir + "/" + std::to_string(counter_++) + "-graph_" +
config_.id + ".dot";
}
virtual std::string Draw(DataFlowGraph *graph);
Config config_;
};
} // namespace analysis
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include <gtest/gtest.h>
#include <fstream>
#include <string>
#include "paddle/fluid/inference/analysis/ut_helper.h"
namespace paddle {
namespace inference {
namespace analysis {
TEST(DFG_GraphvizDrawPass, dfg_graphviz_draw_pass_tester) {
Argument argument(FLAGS_inference_model_dir);
FluidToDataFlowGraphPass pass0;
ASSERT_TRUE(pass0.Initialize(&argument));
pass0.Run(argument.main_dfg.get());
// auto dfg = ProgramDescToDFG(*argument.origin_program_desc);
DFG_GraphvizDrawPass::Config config("./", "test");
DFG_GraphvizDrawPass pass(config);
pass.Initialize(&argument);
pass.Run(argument.main_dfg.get());
// test content
std::ifstream file("./0-graph_test.dot");
ASSERT_TRUE(file.is_open());
std::string line;
int no{0};
while (std::getline(file, line)) {
no++;
}
// DFG is sensitive to ProgramDesc, be careful to change the existing models.
ASSERT_EQ(no, 83);
}
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <glog/logging.h>
#include <string>
#include <vector>
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
namespace paddle {
namespace inference {
namespace analysis {
bool FluidToDataFlowGraphPass::Initialize(Argument *argument) {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument);
if (argument->origin_program_desc) {
LOG(WARNING) << "argument's origin_program_desc is already set, might "
"duplicate called";
}
if (!argument->fluid_model_program_path) {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_dir);
argument->fluid_model_program_path.reset(
new std::string(*argument->fluid_model_dir + "/__model__"));
}
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_program_path);
auto program = LoadProgramDesc(*argument->fluid_model_program_path);
argument->origin_program_desc.reset(
new framework::proto::ProgramDesc(program));
if (!argument->main_dfg) {
argument->main_dfg.reset(new DataFlowGraph);
}
desc_ = argument->origin_program_desc.get();
return true;
}
bool FluidToDataFlowGraphPass::Finalize() { return true; }
void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
PADDLE_ENFORCE(graph);
PADDLE_ENFORCE(desc_);
graph->Build(*desc_);
}
namespace {
class DFG_DebuggerPass : public DFG_GraphvizDrawPass {
public:
using Config = DFG_GraphvizDrawPass::Config;
explicit DFG_DebuggerPass(const Config &config)
: DFG_GraphvizDrawPass(config) {}
std::string repr() const override { return "fluid-to-dfg-debuger-pass"; }
bool Finalize() override { return true; }
};
}
AnalysisPass *FluidToDataFlowGraphPass::CreateGraphvizDebugerPass() const {
return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config(
FLAGS_IA_graphviz_log_root, "fluid-to-dfg-debuger"));
}
} // namespace analysis
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file implements the transformation from data flow graph to fluid
* ProgramDesc.
*/
#pragma once
#include <string>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
namespace paddle {
namespace inference {
namespace analysis {
/*
* Transform a FluidDesc to a SSA.
*/
class FluidToDataFlowGraphPass final : public DataFlowGraphPass {
public:
FluidToDataFlowGraphPass() = default;
bool Initialize(Argument *argument) override;
bool Finalize() override;
void Run(DataFlowGraph *graph) override;
std::string repr() const override { return "fluid-to-data-flow-graph"; }
std::string description() const override {
return "transform a fluid ProgramDesc to a data flow graph.";
}
AnalysisPass *CreateGraphvizDebugerPass() const override;
private:
framework::proto::ProgramDesc const *desc_;
};
} // namespace analysis
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/flags.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
namespace paddle {
namespace inference {
namespace analysis {
static const char kFluidToIrPassesAttr[] = "__fluid_to_ir_passes__";
class FluidToIrPass final : public DataFlowGraphPass {
public:
FluidToIrPass() = default;
bool Initialize(Argument *argument) override {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument);
PADDLE_ENFORCE(argument->Has(kFluidToIrPassesAttr),
"argument need the attr %s", kFluidToIrPassesAttr);
argument_ = argument;
if (argument->origin_program_desc) {
LOG(WARNING) << "argument's origin_program_desc is already set, might "
"duplicate called";
}
// set fluid model program path
if (!argument->fluid_model_program_path) {
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_dir);
argument->fluid_model_program_path.reset(
new std::string(*argument->fluid_model_dir + "/__model__"));
}
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->fluid_model_program_path);
// Load program.
auto program = LoadProgramDesc(*argument->fluid_model_program_path);
argument->origin_program_desc.reset(
new framework::proto::ProgramDesc(program));
// Create main data flow graph.
if (!argument->main_dfg) {
argument->main_dfg.reset(new DataFlowGraph);
}
argument->Set("ir_program_desc", new ProgramDesc(program));
LOG(INFO) << "Loading parameters";
// Load parameters to argument if needed.
if (argument->fluid_model_dir || (argument->fluid_model_program_path &&
argument->fluid_model_param_path)) {
#define SAFE_GET(ATTR) std::string ATTR = argument->ATTR ? *argument->ATTR : "";
SAFE_GET(fluid_model_dir);
SAFE_GET(fluid_model_program_path);
SAFE_GET(fluid_model_param_path);
#undef SAFE_GET
EnableParamModify(fluid_model_dir, fluid_model_program_path,
fluid_model_param_path);
}
return true;
}
bool Finalize() override { return true; }
void Run(DataFlowGraph *graph) override {
// Call all the IR Passes
IRPassManager ir_passes(argument_->Get<ProgramDesc>("ir_program_desc"),
nullptr);
// Pass the scope from analysis to IR if needed.
if (argument_->Has(framework::ir::kParamScopeAttr)) {
// Here the address is passed, attention that IR doesn't own the scope, so
// the real scope in analysis should live during the IR phase.
ir_passes.graph().Set(
framework::ir::kParamScopeAttr,
new framework::Scope *(&argument_->Get<framework::Scope>(
framework::ir::kParamScopeAttr)));
}
if (FLAGS_IA_enable_ir) {
const auto &ir_passes_to_apply =
argument_->Get<std::vector<std::string>>(kFluidToIrPassesAttr);
ir_passes.Apply(ir_passes_to_apply);
}
PADDLE_ENFORCE(argument_->main_dfg.get());
argument_->main_dfg->Build(ir_passes.graph());
// inherit the arguments from ir.
if (ir_passes.graph().Has(framework::ir::kFuseStatisAttr)) {
argument_->Set(
framework::ir::kFuseStatisAttr,
new std::unordered_map<std::string, int>(
ir_passes.graph().Get<std::unordered_map<std::string, int>>(
framework::ir::kFuseStatisAttr)));
}
}
void EnableParamModify(const std::string &model_dir,
const std::string &prog_file,
const std::string &param_file);
std::string repr() const override { return "fluid-to-ir-pass"; }
private:
// Load parameters from a single file or from a directory.
bool LoadParams(framework::Scope *scope, const std::string &dir,
const std::string &prog_file, const std::string &param_file);
private:
Argument *argument_{nullptr};
};
} // namespace analysis
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/graph_traits.h"
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file defines the GraphTraits<X> template class that should be specified
* by classes that want to be iteratable by generic graph iterators.
*
* This file also defines the marker class Inverse that is used to iterate over
* graphs in a graph defined, inverse ordering...
*/
#pragma once
#include "paddle/fluid/inference/analysis/helper.h"
namespace paddle {
namespace inference {
namespace analysis {
/*
* This class should be specialized by different graph types...
* That's why the base class is empty.
*/
template <typename GraphType>
struct GraphTraits {
// using NodesBFSIterator = xxx
// NodesBFSIterator nodes_begin();
// NodesBFSIterator nodes_end();
};
/*
* Inverse - This class is used as a marker class to tell the graph iterator to
* iterate in a graph defined Inverse order.
*/
template <typename GraphType>
struct Inverse {
const GraphType &graph;
explicit Inverse(const GraphType &graph) : graph(graph) {}
};
/*
* Provide a partial specialization of GraphTraits so that the inverse of an
* inverse turns into the original graph.
*/
template <typename GraphType>
struct GraphTraits<Inverse<Inverse<GraphType>>> : GraphTraits<GraphType> {};
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -101,20 +101,20 @@ class OrderedRegistry { ...@@ -101,20 +101,20 @@ class OrderedRegistry {
public: public:
T *Register(const std::string &name, T *x) { T *Register(const std::string &name, T *x) {
PADDLE_ENFORCE(!dic_.count(name), "duplicate key [%s]", name); PADDLE_ENFORCE(!dic_.count(name), "duplicate key [%s]", name);
dic_[name] = data_.size(); dic_[name] = elements_.size();
data_.emplace_back(std::unique_ptr<T>(x)); elements_.emplace_back(std::unique_ptr<T>(x));
return data_.back().get(); return elements_.back().get();
} }
T *Lookup(const std::string &name) { T *Lookup(const std::string &name) {
auto it = dic_.find(name); auto it = dic_.find(name);
if (it == dic_.end()) return nullptr; if (it == dic_.end()) return nullptr;
return data_[it->second].get(); return elements_[it->second].get();
} }
protected: protected:
std::unordered_map<std::string, int> dic_; std::unordered_map<std::string, int> dic_;
std::vector<std::unique_ptr<T>> data_; std::vector<std::unique_ptr<T>> elements_;
}; };
template <typename T> template <typename T>
......
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/string/pretty_log.h" #include "paddle/fluid/string/pretty_log.h"
namespace paddle { namespace paddle {
...@@ -27,21 +29,33 @@ using string::PrettyLogEndl; ...@@ -27,21 +29,33 @@ using string::PrettyLogEndl;
using string::PrettyLog; using string::PrettyLog;
using string::Style; using string::Style;
IRPassManager::IRPassManager(const ProgramDesc &program, IRPassManager::IRPassManager(Argument *argument) {
framework::Scope *scope) ARGUMENT_CHECK_FIELD(argument, main_program);
: program_(program) { graph_ = std::unique_ptr<Graph>(new Graph(argument->main_program()));
graph_.reset(new framework::ir::Graph(program)); if (argument->Has("scope")) {
if (scope) graph_->Set(framework::ir::kParamScopeAttr,
graph_->Set(framework::ir::kParamScopeAttr, new framework::Scope *(scope)); new framework::Scope *(
const_cast<framework::Scope *>(&argument->scope())));
}
ARGUMENT_CHECK_FIELD(argument, ir_analysis_passes);
CreatePasses(argument, argument->ir_analysis_passes());
} }
void IRPassManager::Apply(const std::vector<std::string> &passes) { void IRPassManager::CreatePasses(Argument *argument,
// Apply all the passes const std::vector<std::string> &passes) {
std::string pre_pass; std::string pre_pass;
int pass_num = 0; int pass_num = 0;
for (const std::string &pass_name : passes) { for (const std::string &pass_name : passes) {
PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass_name);
auto pass = framework::ir::PassRegistry::Instance().Get(pass_name); auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
// Set some pass attributes.
if (pass_name == "ir_analysis_pass") {
pass->Set("tensorrt_node_teller",
new SubgraphDetector::NodeInsideSubgraphTeller(
argument->tensorrt_node_teller()));
}
if (pass_name == "graph_viz_pass") { if (pass_name == "graph_viz_pass") {
std::string dot_file_path = std::to_string(pass_num) + "_ir_" + std::string dot_file_path = std::to_string(pass_num) + "_ir_" +
(pre_pass.empty() ? "origin" : pre_pass) + (pre_pass.empty() ? "origin" : pre_pass) +
...@@ -49,11 +63,47 @@ void IRPassManager::Apply(const std::vector<std::string> &passes) { ...@@ -49,11 +63,47 @@ void IRPassManager::Apply(const std::vector<std::string> &passes) {
pass->Set("graph_viz_path", new std::string(std::move(dot_file_path))); pass->Set("graph_viz_path", new std::string(std::move(dot_file_path)));
pass_num++; pass_num++;
} }
graph_ = pass->Apply(std::move(graph_));
if (pass_name == "tensorrt_subgraph_pass") {
PADDLE_ENFORCE(argument->tensorrt_node_teller_valid());
pass->SetNotOwned("tensorrt_node_teller",
argument->tensorrt_node_teller_ptr());
pass->Set("workspace_size", new int(argument->tensorrt_workspace_size()));
pass->Set("max_batch_size", new int(argument->tensorrt_max_batch_size()));
}
// graph_ = pass->Apply(std::move(graph_));
pre_pass = pass_name; pre_pass = pass_name;
passes_.emplace_back(std::move(pass));
} }
} }
std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) {
if (passes_.empty()) {
return graph;
}
PADDLE_ENFORCE(graph.get());
// Apply all the passes
for (const auto &pass : passes_) {
PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass->Type());
graph = pass->Apply(std::move(graph));
}
return std::move(graph);
}
framework::proto::ProgramDesc IRPassManager::AcquireProgram(
std::unique_ptr<Graph> *graph, const ProgramDesc &program) const {
auto pass =
framework::ir::PassRegistry::Instance().Get("graph_to_program_pass");
ProgramDesc desc(program);
pass->SetNotOwned("program", &desc);
auto *the_graph = graph->release();
*graph = pass->Apply(std::unique_ptr<Graph>(the_graph));
return *desc.Proto();
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -20,27 +20,38 @@ ...@@ -20,27 +20,38 @@
* for inference. * for inference.
*/ */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/argument.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
using framework::ProgramDesc; using framework::ProgramDesc;
using framework::ir::Graph;
class IRPassManager final { class IRPassManager final {
public: public:
IRPassManager(const ProgramDesc &program, framework::Scope *scope); explicit IRPassManager(Argument *argument);
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph);
void Apply(const std::vector<std::string> &passes); framework::proto::ProgramDesc AcquireProgram(
std::unique_ptr<Graph> *graph, const ProgramDesc &program) const;
framework::ir::Graph &graph() const { return *graph_; } framework::ir::Graph &graph() const { return *graph_; }
private: private:
std::unique_ptr<framework::ir::Graph> graph_; void CreatePasses(Argument *argument, const std::vector<std::string> &passes);
ProgramDesc program_;
std::unique_ptr<Graph> graph_;
std::vector<std::unique_ptr<framework::ir::Pass>> passes_;
}; };
} // namespace analysis } // namespace analysis
......
cc_library(subgraph_detector SRCS subgraph_detector.cc DEPS proto_desc)
cc_library(tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass.cc DEPS subgraph_detector)
set(analysis_deps ${analysis_deps}
subgraph_detector tensorrt_subgraph_pass
CACHE INTERNAL "")
set(INFER_IR_PASSES ${INFER_IR_PASSES} tensorrt_subgraph_pass CACHE INTERNAL "")
...@@ -12,46 +12,110 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,46 +12,110 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/inference/analysis/subgraph_splitter.h" #include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include <string>
#include <utility>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
const char *SubGraphSplitter::kMarkerAttrName = using framework::ir::Node;
"_sub_graph_splitter_inside_sub_graph";
std::vector<std::vector<Node *>> SubGraphSplitter::operator()() { std::pair<std::vector<Node *>, std::vector<Node *>>
ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
std::unordered_set<Node *> nodes(graph.begin(), graph.end());
std::unordered_set<Node *> inputs;
std::unordered_set<Node *> outputs;
// Input a Value, check whether its inlink is in the subgraph.
auto inlink_in_subgraph = [&](Node *n) {
for (auto *in : n->inputs) {
if (nodes.count(in)) return true;
}
return false;
};
for (auto &node : graph) {
for (auto *in : node->inputs) {
// The Value that is written by nodes inside a sub-graph shouldn't be the
// input of the sub-graph.
if (!nodes.count(in) && in->IsVar() && !inlink_in_subgraph(in)) {
inputs.insert(in);
}
}
for (auto *out : node->outputs) {
if (!nodes.count(out) && out->IsVar()) {
outputs.insert(out);
}
}
}
return std::make_pair(std::vector<Node *>(inputs.begin(), inputs.end()),
std::vector<Node *>(outputs.begin(), outputs.end()));
}
// Filter the Intermediate results of the subgraph node.
void FilterRedundantOutputOfSubGraph(Graph *graph) {
std::vector<Node *> op_nodes;
for (auto &node : TopologicalSort(*graph)) {
if (node.IsVar() || Agent(&node).deleted()) {
continue;
}
op_nodes.push_back(&node);
}
size_t op_num = op_nodes.size();
for (size_t i = 0; i < op_num; i++) {
if (op_nodes[i]->IsOp()) continue;
std::unordered_set<std::string> follow_up_input_names;
for (size_t j = i + 1; j < op_num; j++) {
for (auto *in : op_nodes[j]->inputs) {
follow_up_input_names.insert(in->Name());
}
}
std::vector<Node *> filtered_subgraph_outlinks;
for (auto *out : op_nodes[i]->outputs) {
if (follow_up_input_names.count(out->Name())) {
filtered_subgraph_outlinks.push_back(out);
} else {
Agent(out).set_deleted(true);
}
}
// The filtered_subgraph_outlinks may be empty.
op_nodes[i]->outputs = filtered_subgraph_outlinks;
}
}
std::vector<std::vector<Node *>> SubgraphDetector::operator()() {
MarkNodesInsideSubGraph(); MarkNodesInsideSubGraph();
return ExtractSubGraphs(); return ExtractSubGraphs();
} }
// Mark the output variables inside a subgraph with the func. // Mark the output variables inside a subgraph with the func.
inline void MarkOutLinksInSubGraph(const Function *func) { inline void MarkOutLinksInSubGraph(const Node *func) {
for (auto *var : func->outlinks) { for (auto *var : func->outputs) {
var->attr(SubGraphSplitter::kMarkerAttrName).Bool() = true; Agent(var).set_marked(true);
} }
} }
void SubGraphSplitter::MarkNodesInsideSubGraph() { void SubgraphDetector::MarkNodesInsideSubGraph() {
for (auto &node : GraphTraits<DataFlowGraph>(*graph_).nodes()) { for (auto &node : framework::ir::GraphTraits::DFS(*graph_)) {
if (node_inside_subgraph_teller_(&node)) { if (node_inside_subgraph_teller_(&node)) {
node.attr(kMarkerAttrName).Bool() = true; Agent(&node).set_marked(true);
if (node.type() == Node::Type::kFunction) { if (node.IsOp()) {
// If a function is inside the sub-graph, mark all the output variables // If a function is inside the sub-graph, mark all the output variables
// to be inside too, so that two marked functions will be inside a same // to be inside too, so that two marked functions will be inside a same
// sub-graph, lets take a example: A_function->var->B_function, if // sub-graph, lets take a example: A_function->var->B_function, if
// A_function is marked, var should also be marked, so that B_function // A_function is marked, var should also be marked, so that B_function
// will be in the same sub-graph with A_function if B_function is // will be in the same sub-graph with A_function if B_function is
// marked. // marked.
MarkOutLinksInSubGraph(static_cast<const Function *>(&node)); MarkOutLinksInSubGraph(&node);
} }
} }
} }
} }
const char *kUnionFindParent = "_sub_graph_splitter_union_find_parent_";
// Use the Union Find(UF) algorithm to find fully connected sub-graphs, if node // Use the Union Find(UF) algorithm to find fully connected sub-graphs, if node
// a's output is node b, that is a and b is in the same sub-graph. The UF // a's output is node b, that is a and b is in the same sub-graph. The UF
// algorithm will group them to the same cluster. // algorithm will group them to the same cluster.
...@@ -60,8 +124,8 @@ using node_map_t = std::unordered_map<int, Node *>; ...@@ -60,8 +124,8 @@ using node_map_t = std::unordered_map<int, Node *>;
int UnionFindGetAncestor(const node_map_t &node_map, size_t id) { int UnionFindGetAncestor(const node_map_t &node_map, size_t id) {
int tmp = id; int tmp = id;
do { do {
tmp = node_map.at(tmp)->attr(kUnionFindParent).Int32(); tmp = Agent(node_map.at(tmp)).union_find_parent();
} while (node_map.at(tmp)->attr(kUnionFindParent).Int32() != tmp); } while (Agent(node_map.at(tmp)).union_find_parent() != tmp);
return tmp; return tmp;
} }
// Make this two node share the same ancestor. // Make this two node share the same ancestor.
...@@ -69,9 +133,9 @@ int UnionFindGetAncestor(const node_map_t &node_map, size_t id) { ...@@ -69,9 +133,9 @@ int UnionFindGetAncestor(const node_map_t &node_map, size_t id) {
void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) { void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) {
int a_ancestor = UnionFindGetAncestor(node_map, a); int a_ancestor = UnionFindGetAncestor(node_map, a);
int b_ancestor = UnionFindGetAncestor(node_map, b); int b_ancestor = UnionFindGetAncestor(node_map, b);
node_map.at(b_ancestor)->attr(kUnionFindParent).Int32() = a_ancestor; Agent(node_map.at(b_ancestor)).set_union_find_parent(a_ancestor);
node_map.at(a)->attr(kUnionFindParent).Int32() = a_ancestor; Agent(node_map.at(a)).set_union_find_parent(a_ancestor);
node_map.at(b)->attr(kUnionFindParent).Int32() = a_ancestor; Agent(node_map.at(b)).set_union_find_parent(a_ancestor);
} }
// This is a simple representation of a graph. // This is a simple representation of a graph.
...@@ -195,16 +259,21 @@ void FlexibleDFS(const std::vector<BriefNode *> &source, bool reverse, ...@@ -195,16 +259,21 @@ void FlexibleDFS(const std::vector<BriefNode *> &source, bool reverse,
} }
} }
std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() { std::vector<std::vector<Node *>> SubgraphDetector::ExtractSubGraphs() {
// Run the Extract algorithm to find all subgraphs. // Run the Extract algorithm to find all subgraphs.
std::vector<Node *> marked_nodes; std::vector<Node *> marked_nodes;
// We use brief_node_map to represent the original graph in order to avoid // We use brief_node_map to represent the original graph in order to avoid
// changing the original graph. // changing the original graph.
std::unordered_map<int, BriefNode *> brief_node_map; std::unordered_map<int, BriefNode *> brief_node_map;
for (auto &node : GraphTraits<DataFlowGraph>(*graph_).nodes_in_TS()) { std::unordered_set<int32_t> valid_node_ids;
for (auto *node : graph_->Nodes()) {
valid_node_ids.insert(node->id());
}
for (auto &node : framework::ir::GraphTraits::TS(*graph_)) {
brief_node_map[node.id()] = new BriefNode(&node); brief_node_map[node.id()] = new BriefNode(&node);
if (node.attr(kMarkerAttrName).Bool()) { if (Agent(&node).marked()) {
marked_nodes.push_back(&node); marked_nodes.push_back(&node);
} }
} }
...@@ -213,26 +282,34 @@ std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() { ...@@ -213,26 +282,34 @@ std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
node_map_t node_map; // id to ptr node_map_t node_map; // id to ptr
for (auto *n : marked_nodes) { for (auto *n : marked_nodes) {
// n's parent == n.id means it is the ancestor // n's parent == n.id means it is the ancestor
n->attr(kUnionFindParent).Int32() = n->id(); Agent(n).set_union_find_parent(n->id());
node_map[n->id()] = n; node_map[n->id()] = n;
} }
// create breif node map // create breif node map
for (auto &itr : brief_node_map) { for (auto &itr : brief_node_map) {
for (Node *node : itr.second->node->inlinks) { for (Node *node : itr.second->node->inputs) {
itr.second->inlinks.push_back(brief_node_map[node->id()]); if (!valid_node_ids.count(node->id())) {
LOG(INFO) << "invalid node id " << node->id();
continue;
}
itr.second->inlinks.push_back(brief_node_map.at(node->id()));
} }
for (Node *node : itr.second->node->outlinks) { for (Node *node : itr.second->node->outputs) {
itr.second->outlinks.push_back(brief_node_map[node->id()]); if (!valid_node_ids.count(node->id())) {
LOG(INFO) << "invalid node id " << node->id();
continue;
}
itr.second->outlinks.push_back(brief_node_map.at(node->id()));
} }
} }
for (auto &itr : brief_node_map) { for (auto &itr : brief_node_map) {
BriefNode *brief_node = itr.second; BriefNode *brief_node = itr.second;
if (!brief_node->node->attr(kMarkerAttrName).Bool()) { if (!Agent(brief_node->node).marked()) {
VLOG(40) << brief_node->node->id() << " node not a trt candicate."; VLOG(4) << brief_node->node->id() << " node not a trt candidate.";
continue; continue;
} }
...@@ -254,7 +331,7 @@ std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() { ...@@ -254,7 +331,7 @@ std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
std::unordered_set<BriefNode *> contract_nodes; std::unordered_set<BriefNode *> contract_nodes;
for (auto *out : brief_node->outlinks) { for (auto *out : brief_node->outlinks) {
// must be an trt candidate // must be an trt candidate
if (!out->node->attr(kMarkerAttrName).Bool()) continue; if (!Agent(out->node).marked()) continue;
// get all dst input nodes except src. // get all dst input nodes except src.
std::vector<BriefNode *> source_nodes; std::vector<BriefNode *> source_nodes;
for (auto *n : out->inlinks) { for (auto *n : out->inlinks) {
...@@ -289,9 +366,8 @@ std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() { ...@@ -289,9 +366,8 @@ std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
std::unordered_map<int /*ancestor*/, std::vector<Node *>> clusters; std::unordered_map<int /*ancestor*/, std::vector<Node *>> clusters;
for (auto *n : marked_nodes) { for (auto *n : marked_nodes) {
if (n->type() == Node::Type::kFunction) { if (n->IsOp()) {
clusters[UnionFindGetAncestor(node_map, clusters[UnionFindGetAncestor(node_map, Agent(n).union_find_parent())]
n->attr(kUnionFindParent).Int32())]
.push_back(n); .push_back(n);
} }
} }
...@@ -304,28 +380,59 @@ std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() { ...@@ -304,28 +380,59 @@ std::vector<std::vector<Node *>> SubGraphSplitter::ExtractSubGraphs() {
return result; return result;
} }
void SubGraphFuse::operator()() { ReplaceNodesWithSubGraphs(); } void SubGraphFuser::operator()() { ReplaceNodesWithSubGraphs(); }
void RemoveIntermediateOutputInSubgraph(const std::vector<Node *> &subgraph,
Graph *graph,
std::vector<Node *> *outputs) {
std::unordered_set<Node *> subgraph_set(subgraph.begin(), subgraph.end());
std::unordered_set<Node *> valid_output;
for (auto *output : *outputs) {
int num_used = 0;
for (auto *node : output->outputs) {
if (!subgraph_set.count(node)) ++num_used;
if (num_used > 0) valid_output.insert(output);
}
}
outputs->assign(valid_output.begin(), valid_output.end());
}
void DetachDeletedNodes(framework::ir::Graph *graph) {
std::unordered_set<const Node *> nodes;
for (auto *node : graph->Nodes()) {
if (Agent(node).deleted()) {
node->inputs.clear();
node->outputs.clear();
}
}
}
void SubGraphFuse::ReplaceNodesWithSubGraphs() { void SubGraphFuser::ReplaceNodesWithSubGraphs() {
auto subgraphs = SubGraphSplitter(graph_, node_inside_subgraph_teller_)(); auto subgraphs = SubgraphDetector(graph_, node_inside_subgraph_teller_)();
for (auto &subgraph : subgraphs) { for (auto &subgraph : subgraphs) {
if (subgraph.size() <= argument_->Get<int>("minimum_subgraph_size")) if (subgraph.size() <= min_subgraph_size_) continue;
continue; LOG(INFO) << "detect a subgraph size " << subgraph.size();
std::unordered_set<Node *> subgraph_uniq(subgraph.begin(), subgraph.end()); std::unordered_set<Node *> subgraph_uniq(subgraph.begin(), subgraph.end());
// replace this sub-graph with the first node. Two steps: 1. Create a Block // replace this sub-graph with the first node. Two steps: 1. Create a Block
// Node that contains this subgraph 2. Mark the nodes inside the sub-graph // Node that contains this subgraph 2. Mark the nodes inside the sub-graph
// as deleted. 3. Replace the deleted node with the new Block Node. // as deleted. 3. Replace the deleted node with the new Block Node.
auto *block_node = static_cast<FunctionBlock *>( framework::OpDesc empty_desc;
graph_->nodes.Create(Node::Type::kFunctionBlock)); empty_desc.SetType("tensorrt_engine");
auto *block_node = graph_->CreateOpNode(&empty_desc);
Agent(block_node).set_subgraph({});
auto io = ExtractInputAndOutputOfSubGraph(subgraph); auto io = ExtractInputAndOutputOfSubGraph(subgraph);
block_node->inlinks = std::move(io.first); block_node->inputs = std::move(io.first);
block_node->outlinks = std::move(io.second); block_node->outputs = std::move(io.second);
RemoveIntermediateOutputInSubgraph(subgraph, graph_, &block_node->outputs);
for (auto *node : subgraph) { for (auto *node : subgraph) {
// TODO(Superjomn) need a unified mechanism to treat deleted node in each // TODO(Superjomn) need a unified mechanism to treat deleted node in each
// pass. // pass.
node->SetDeleted(); Agent(node).set_deleted(true);
block_node->subgraph.push_back(node); Agent(block_node).subgraph()->push_back(node);
} }
// Change all the sub-graph's inputs and outputs corresponding inlink and // Change all the sub-graph's inputs and outputs corresponding inlink and
...@@ -339,16 +446,92 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() { ...@@ -339,16 +446,92 @@ void SubGraphFuse::ReplaceNodesWithSubGraphs() {
std::unordered_set<Node *> uniq(nodes.begin(), nodes.end()); std::unordered_set<Node *> uniq(nodes.begin(), nodes.end());
nodes.assign(uniq.begin(), uniq.end()); nodes.assign(uniq.begin(), uniq.end());
}; };
for (auto *i : block_node->inlinks) { for (auto *i : block_node->inputs) {
inlink_or_outlink_cleaner(i->outlinks); inlink_or_outlink_cleaner(i->outputs);
} }
for (auto *&o : block_node->outlinks) { for (auto *&o : block_node->outputs) {
inlink_or_outlink_cleaner(o->inlinks); inlink_or_outlink_cleaner(o->inputs);
} }
} }
// DetachDeletedNodes(graph_);
FilterRedundantOutputOfSubGraph(graph_); FilterRedundantOutputOfSubGraph(graph_);
} }
inline bool CheckNodeIndegreeEquals(const Node &node, size_t n) {
return node.inputs.size() == n;
}
NodesTSIterator::NodesTSIterator(const std::vector<Node *> &source) {
PADDLE_ENFORCE(!source.empty(),
"Start points of topological sorting should not be empty!");
// CHECK all the inputs' in-degree is 0
for (auto *node : source) {
PADDLE_ENFORCE(CheckNodeIndegreeEquals(*node, 0));
}
std::unordered_set<Node *> visited;
std::unordered_set<Node *> to_visit{source.begin(), source.end()};
std::vector<Node *> inlink_visited;
while (!to_visit.empty()) {
std::vector<Node *> queue(to_visit.begin(), to_visit.end());
for (auto *p : queue) {
if (Agent(p).deleted()) {
visited.insert(p);
to_visit.erase(p);
}
inlink_visited.clear();
std::copy_if(p->inputs.begin(), p->inputs.end(),
std::back_inserter(inlink_visited),
[&](Node *x) -> bool { return visited.count(x) != 0; });
if (inlink_visited.size() == p->inputs.size()) {
sorted_.push_back(p);
for (auto *_ : p->outputs) {
if (!visited.count(_)) {
to_visit.insert(_);
}
}
to_visit.erase(p);
visited.insert(p);
}
}
}
}
NodesTSIterator::NodesTSIterator(const NodesTSIterator &other)
: sorted_(other.sorted_), cursor_(other.cursor_) {}
Node &NodesTSIterator::operator*() {
PADDLE_ENFORCE_LT(cursor_, sorted_.size());
return *sorted_[cursor_];
}
NodesTSIterator &NodesTSIterator::operator++() {
if (++cursor_ >= sorted_.size()) {
sorted_.clear();
cursor_ = 0;
}
return *this;
}
NodesTSIterator &NodesTSIterator::operator=(const NodesTSIterator &other) {
cursor_ = other.cursor_;
sorted_ = other.sorted_;
return *this;
}
bool NodesTSIterator::operator==(const NodesTSIterator &other) {
return sorted_ == other.sorted_ && cursor_ == other.cursor_;
}
Node *NodesTSIterator::operator->() {
PADDLE_ENFORCE_LT(cursor_, sorted_.size());
return sorted_[cursor_];
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file defines the the class to partition a graph.
*/
#pragma once
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/helper.h"
namespace paddle {
namespace inference {
namespace analysis {
using framework::ir::Graph;
const char kIsFunctionNode[] = "__is_function_node__";
const char kFunctionNodeSubGraph[] = "__function_node_sub_graph__";
const char kSubgraphSplitterMarkerAttrName[] =
"_sub_graph_splitter_inside_sub_graph";
/*
* Detect the nodes in a sub-graph that meet some conditions. This class doesn't
* modify the graph.
*/
class SubgraphDetector {
public:
// Tell whether a node is inside a sub-graph.
using NodeInsideSubgraphTeller =
std::function<bool(const framework::ir::Node *)>;
SubgraphDetector(Graph *graph, const NodeInsideSubgraphTeller &teller)
: graph_(graph), node_inside_subgraph_teller_(teller) {}
std::vector<std::vector<framework::ir::Node *>> operator()();
protected:
// Mark the nodes inside the accepted sub-graph using
// node_inside_subgraph_teller.
void MarkNodesInsideSubGraph();
// Merge the marked nodes into sub-graphs and return the sub-graphs.
std::vector<std::vector<framework::ir::Node *>> ExtractSubGraphs();
private:
Graph *graph_;
NodeInsideSubgraphTeller node_inside_subgraph_teller_;
};
/*
* SubGraphFuser - Replace some nodes with the sub-graph node they are inside.
* To some extent, the TensorRT engine is just a fusion op for a model.
*/
class SubGraphFuser {
public:
using NodeInsideSubgraphTeller = SubgraphDetector::NodeInsideSubgraphTeller;
SubGraphFuser(Graph *graph, const NodeInsideSubgraphTeller &teller,
int min_subgraph_size)
: graph_(graph),
node_inside_subgraph_teller_(teller),
min_subgraph_size_{min_subgraph_size} {}
// The main method which run all the logic.
void operator()();
protected:
// Remove the nodes inside sub-graphs and replace with the SubGraphNode.
void ReplaceNodesWithSubGraphs();
private:
Graph *graph_;
NodeInsideSubgraphTeller node_inside_subgraph_teller_;
int min_subgraph_size_;
};
struct NodeWrapper {
bool deleted{false};
bool marked{false};
int union_find_parent{-1};
std::vector<framework::ir::Node *> subgraph;
};
/*
* ir::Node agent for subgraph detector.
*/
struct Agent {
explicit Agent(framework::ir::Node *x) : x_(x) {}
NodeWrapper &wrapper() {
if (!x_->IsWrappedBy<NodeWrapper>()) {
x_->WrappedBy<NodeWrapper>(new NodeWrapper);
}
return x_->template Wrapper<NodeWrapper>();
}
bool deleted() { return wrapper().deleted; }
void set_deleted(bool x) { wrapper().deleted = x; }
bool marked() { return wrapper().marked; }
void set_marked(bool x) { wrapper().marked = x; }
void set_subgraph(const std::vector<framework::ir::Node *> &x) {
wrapper().subgraph = x;
}
int union_find_parent() { return wrapper().union_find_parent; }
void set_union_find_parent(int v) { wrapper().union_find_parent = v; }
std::vector<framework::ir::Node *> *subgraph() { return &wrapper().subgraph; }
std::vector<framework::ir::Node *> &inputs() { return x_->inputs; }
std::vector<framework::ir::Node *> &outputs() { return x_->outputs; }
private:
framework::ir::Node *x_;
};
// Topological sorting iterator on nodes.
struct NodesTSIterator
: public std::iterator<std::forward_iterator_tag, framework::ir::Node *> {
NodesTSIterator() = default;
explicit NodesTSIterator(const std::vector<framework::ir::Node *> &source);
NodesTSIterator(NodesTSIterator &&other)
: sorted_(std::move(other.sorted_)), cursor_(other.cursor_) {
other.cursor_ = 0;
}
NodesTSIterator(const NodesTSIterator &other);
framework::ir::Node &operator*();
NodesTSIterator &operator++();
// TODO(Superjomn) current implementation just compare the first
// element, need to compare the graph and all the elements in the queue and
// set.
NodesTSIterator &operator=(const NodesTSIterator &other);
bool operator==(const NodesTSIterator &other);
bool operator!=(const NodesTSIterator &other) { return !(*this == other); }
framework::ir::Node *operator->();
private:
std::vector<framework::ir::Node *> sorted_;
size_t cursor_{0};
};
// The nodes those have no input will be treated as start points.
static std::vector<framework::ir::Node *> ExtractStartPoints(const Graph &g) {
std::vector<framework::ir::Node *> result;
for (auto *node : g.Nodes()) {
if (node->inputs.empty()) {
result.push_back(node);
}
}
return result;
}
static iterator_range<NodesTSIterator> TopologicalSort(const Graph &g) {
auto start_points = ExtractStartPoints(g);
PADDLE_ENFORCE(!start_points.empty());
NodesTSIterator x(start_points);
return iterator_range<NodesTSIterator>(NodesTSIterator(start_points),
NodesTSIterator());
}
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -12,120 +12,91 @@ ...@@ -12,120 +12,91 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h" #include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
#include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/framework/proto_desc.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/io.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
using framework::proto::ProgramDesc; using framework::ir::Node;
std::vector<std::string> ExtractParameters( std::vector<std::string> ExtractParameters(
const std::vector<std::unique_ptr<Node>> &nodes); const std::unordered_set<Node *> &nodes);
bool DataFlowGraphToFluidPass::Initialize(Argument *argument) { std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl(
ANALYSIS_ARGUMENT_CHECK_FIELD(argument)
ANALYSIS_ARGUMENT_CHECK_FIELD(argument->origin_program_desc) std::unique_ptr<framework::ir::Graph> graph) const {
// The transformed_program_desc should inherit all the VarDesc and BlockDesc framework::ir::FusePassBase::Init("tensorrt_subgraph_pass", graph.get());
// from the original program desc. The operators of the main block(the first
// block) should rewritten by data flow graph. auto teller =
argument->transformed_program_desc.reset( Get<SubgraphDetector::NodeInsideSubgraphTeller>("tensorrt_node_teller");
new ProgramDesc(*argument->origin_program_desc));
argument->transformed_program_desc->mutable_blocks(framework::kRootBlockIndex)
->clear_ops();
desc_ = argument->transformed_program_desc.get();
argument_ = argument;
return true;
}
bool DataFlowGraphToFluidPass::Finalize() { return true; } SubGraphFuser fuser(graph.get(), teller, 2 /*min subgraph size*/);
fuser();
void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) { for (auto *node : graph->Nodes()) {
// FilterRedundantOutputOfSubGraph(graph); if (node->IsOp() && !Agent(node).subgraph()->empty()) {
for (auto &node : GraphTraits<DataFlowGraph>(*graph).nodes_in_TS()) { CreateTensorRTOp(node, graph.get());
if (node.deleted()) continue;
switch (node.type()) { std::unordered_set<const Node *> nodes2remove(
case Node::Type::kFunction: { Agent(node).subgraph()->begin(), Agent(node).subgraph()->end());
AddFluidOp(&node); framework::ir::GraphSafeRemoveNodes(graph.get(), nodes2remove);
} break;
case Node::Type::kFunctionBlock: {
AddEngineOp(&node);
} break;
default:
continue;
} }
} }
if (argument_->Has(framework::ir::kParamScopeAttr)) { std::unordered_set<const Node *> nodes2remove;
LOG(WARNING) << "parameter changes in the scope takes effect"; for (auto *node : graph->Nodes()) {
if (node->IsOp() && Agent(node).deleted()) {
nodes2remove.insert(node);
} }
}
framework::ir::GraphSafeRemoveNodes(graph.get(), nodes2remove);
PADDLE_ENFORCE(argument_->transformed_program_desc.get()); return graph;
} }
void DataFlowGraphToFluidPass::AddFluidOp(Node *node) { void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
PADDLE_ENFORCE(node); Graph *graph) const {
PADDLE_ENFORCE(node->IsFunction()); auto *op_desc = node->Op();
PADDLE_ENFORCE(node->pb_desc() || !node->pb_msg().empty(), static int counter{0};
"node has invalid protobuf repr."); auto &subgraph = *Agent(node).subgraph();
PADDLE_ENFORCE(!subgraph.empty());
// currently only the main block is analyzed.
PADDLE_ENFORCE(desc_);
auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
auto *op = main_block->add_ops();
if (node->pb_desc()) { // An fake block desc.
auto *ori_op = static_cast<framework::proto::OpDesc *>(node->pb_desc()); framework::proto::BlockDesc block_proto;
*op = framework::BlockDesc block_desc(nullptr, &block_proto);
*ori_op; // copy the attributes, by default, these will not be changed block_desc.Proto()->set_parent_idx(-1);
// by analysis phrase. block_desc.Proto()->set_idx(0);
// The inputs and outputs of the existing ops are not changed by tensorrt for (auto *node : subgraph) {
// subgraph pass. auto *op = block_desc.AppendOp();
// NOTE It might be changed by other passes in the long run. *op->Proto() = *node->Op()->Proto();
} else {
op->ParseFromString(node->pb_msg());
} }
}
void CreateTrtEngineOp(Node *node, Argument *argument,
framework::proto::BlockDesc *block) {
PADDLE_ENFORCE(argument->main_dfg.get());
const DataFlowGraph &graph = *(argument->main_dfg);
static int counter{0};
PADDLE_ENFORCE(node->IsFunctionBlock());
framework::OpDesc desc;
auto *func = static_cast<FunctionBlock *>(node);
// collect inputs // collect inputs
std::unordered_set<std::string> input_names; std::unordered_set<std::string> input_names;
std::unordered_set<std::string> input_names_with_id; std::unordered_set<std::string> input_names_with_id;
for (auto *x : func->inlinks) { for (auto *x : node->inputs) {
input_names.insert(x->name()); input_names.insert(x->Name());
input_names_with_id.insert(x->name() + std::to_string(x->id())); input_names_with_id.insert(x->Name() + std::to_string(x->id()));
} }
desc.SetInput( op_desc->SetInput(
"Xs", std::vector<std::string>(input_names.begin(), input_names.end())); "Xs", std::vector<std::string>(input_names.begin(), input_names.end()));
std::unordered_set<std::string> output_names; std::unordered_set<std::string> output_names;
std::unordered_set<std::string> output_names_with_id; std::unordered_set<std::string> output_names_with_id;
for (auto *x : func->outlinks) { for (auto *x : node->outputs) {
output_names.insert(x->name()); output_names.insert(x->Name());
output_names_with_id.insert(x->name() + std::to_string(x->id())); output_names_with_id.insert(x->Name() + std::to_string(x->id()));
} }
desc.SetOutput( op_desc->SetOutput(
"Ys", std::vector<std::string>(output_names.begin(), output_names.end())); "Ys", std::vector<std::string>(output_names.begin(), output_names.end()));
desc.SetType("tensorrt_engine"); op_desc->SetType("tensorrt_engine");
std::unordered_map<std::string, std::string> output_name_map; std::unordered_map<std::string, std::string> output_name_map;
...@@ -134,7 +105,7 @@ void CreateTrtEngineOp(Node *node, Argument *argument, ...@@ -134,7 +105,7 @@ void CreateTrtEngineOp(Node *node, Argument *argument,
// Why we do this? // Why we do this?
// During the transition from fluid OP to tensorrt OP, we map // During the transition from fluid OP to tensorrt OP, we map
// the input and output Tensor(fluid data structure) of fluid OP // the input and output Tensor(fluid data structure) of fluid OP
// to the correspondin ITensor (trt data structure) through the // to the corresponding ITensor (trt data structure) through the
// Tensor name. When we set up ITensor for an variable, we must // Tensor name. When we set up ITensor for an variable, we must
// ensure that it has not been set before. // ensure that it has not been set before.
// If there is variable in the fluid graph, which is not only the // If there is variable in the fluid graph, which is not only the
...@@ -142,21 +113,22 @@ void CreateTrtEngineOp(Node *node, Argument *argument, ...@@ -142,21 +113,22 @@ void CreateTrtEngineOp(Node *node, Argument *argument,
// So we have to rename the variable in the subgraph to make sure // So we have to rename the variable in the subgraph to make sure
// it is either an OP's input or an OP's output. // it is either an OP's input or an OP's output.
auto subgraph_nodes = func->subgraph; auto &subgraph_nodes = *Agent(node).subgraph();
for (int index = 0; index < block->ops_size(); index++) { for (int index = 0; index < block_desc.OpSize(); index++) {
framework::proto::OpDesc *op = block->mutable_ops(index); framework::proto::OpDesc *op = block_desc.Op(index)->Proto();
auto correspond_node = subgraph_nodes[index]; auto correspond_node = subgraph_nodes[index];
PADDLE_ENFORCE_EQ(correspond_node->name(), op->type()); PADDLE_ENFORCE_EQ(correspond_node->Name(), op->type());
std::unordered_map<std::string, size_t> var2id; std::unordered_map<std::string, size_t> var2id;
for (auto *in_var : correspond_node->inlinks) { for (auto *in_var : correspond_node->inputs) {
var2id[in_var->name()] = in_var->id(); var2id[in_var->Name()] = in_var->id();
} }
// rename for the input variables of op inside subgraph // rename for the input variables of op inside subgraph
for (int i = 0; i < op->inputs_size(); i++) { for (int i = 0; i < op->inputs_size(); i++) {
framework::proto::OpDesc_Var *in_var = op->mutable_inputs(i); // one input
auto *in_var = op->mutable_inputs(i);
std::vector<std::string> replaced_names; std::vector<std::string> replaced_names;
for (int k = 0; k < in_var->arguments_size(); k++) { for (int k = 0; k < in_var->arguments_size(); k++) { // all the arguments
std::string arg_value = in_var->arguments(k); std::string arg_value = in_var->arguments(k);
std::string arg_value_with_id = std::string arg_value_with_id =
arg_value + std::to_string(var2id[arg_value]); arg_value + std::to_string(var2id[arg_value]);
...@@ -172,8 +144,8 @@ void CreateTrtEngineOp(Node *node, Argument *argument, ...@@ -172,8 +144,8 @@ void CreateTrtEngineOp(Node *node, Argument *argument,
} }
} }
var2id.clear(); var2id.clear();
for (auto out_var : correspond_node->outlinks) { for (auto out_var : correspond_node->outputs) {
var2id[out_var->name()] = out_var->id(); var2id[out_var->Name()] = out_var->id();
} }
// rename for the output variables of op inside subgraph // rename for the output variables of op inside subgraph
...@@ -195,91 +167,54 @@ void CreateTrtEngineOp(Node *node, Argument *argument, ...@@ -195,91 +167,54 @@ void CreateTrtEngineOp(Node *node, Argument *argument,
} }
} }
} }
// When tensorrt engine runs at the end of the operation, // When tensorrt engine runs at the end of the operation,
// output_mapping help us copy the data from the renamed ITensor // output_mapping help us copy the data from the renamed ITensor
// to Tensor. // to Tensor.
std::vector<std::string> output_mapping; std::vector<std::string> output_mapping;
for (auto name : output_names) { for (auto name : output_names) {
// LOG(INFO) << name << " " << output_name_map.size();
PADDLE_ENFORCE(output_name_map.count(name) != 0); PADDLE_ENFORCE(output_name_map.count(name) != 0);
output_mapping.push_back(output_name_map[name]); output_mapping.push_back(output_name_map[name]);
} }
PADDLE_ENFORCE(!block->vars().empty(), "the block has no var-desc"); *block_desc.Proto()->mutable_vars() =
const_cast<framework::ProgramDesc *>(&graph->program())
->Proto()
->blocks(0)
.vars();
PADDLE_ENFORCE(!block_desc.Proto()->vars().empty(),
"the block has no var-desc");
PADDLE_ENFORCE(!output_mapping.empty());
// Set attrs // Set attrs
SetAttr(op_desc->Proto(), "subgraph",
SetAttr(desc.Proto(), "subgraph", block->SerializeAsString()); block_desc.Proto()->SerializeAsString());
SetAttr(desc.Proto(), "max_batch_size", argument->Get<int>("max_batch_size")); SetAttr(op_desc->Proto(), "max_batch_size", Get<int>("max_batch_size"));
SetAttr(desc.Proto(), "workspace_size", argument->Get<int>("workspace_size")); SetAttr(op_desc->Proto(), "workspace_size", Get<int>("workspace_size"));
SetAttr(desc.Proto(), "engine_uniq_key", "trt-" + std::to_string(counter++)); SetAttr(op_desc->Proto(), "engine_uniq_key",
SetAttr(desc.Proto(), "parameters", ExtractParameters(graph.nodes.nodes())); "trt-" + std::to_string(counter++));
SetAttr(desc.Proto(), "output_name_mapping", output_mapping); SetAttr(op_desc->Proto(), "parameters", ExtractParameters(graph->Nodes()));
node->SetPbMsg(desc.Proto()->SerializeAsString()); SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping);
} }
std::vector<std::string> ExtractParameters( std::vector<std::string> ExtractParameters(
const std::vector<std::unique_ptr<Node>> &nodes) { const std::unordered_set<Node *> &nodes) {
std::vector<std::string> parameters; std::vector<std::string> parameters;
for (const auto &node : nodes) { for (const auto &node : nodes) {
if (!node->IsValue()) continue; if (!node->IsVar()) continue;
PADDLE_ENFORCE(!node->pb_msg().empty(), "pb_msg should be set first"); if (node->Var()->Persistable()) {
framework::proto::VarDesc var; parameters.push_back(node->Name());
var.ParseFromString(node->pb_msg());
if (var.persistable()) {
parameters.push_back(var.name());
} }
} }
return parameters; return parameters;
} }
void DataFlowGraphToFluidPass::AddEngineOp(Node *node) {
// TODO(Superjomn) Here need to expose some arguments for default setting.
PADDLE_ENFORCE(node->IsFunctionBlock());
auto *block_node = static_cast<FunctionBlock *>(node);
framework::proto::BlockDesc proto;
framework::BlockDesc block_desc(nullptr, &proto);
block_desc.Proto()->set_parent_idx(-1);
block_desc.Proto()->set_idx(0);
VLOG(40) << "origin variable size: "
<< argument_->origin_program_desc->blocks(0).vars().size();
VLOG(40) << "transformed variable size: "
<< block_desc.Proto()->vars().size();
// copy ops.
for (auto *node : block_node->subgraph) {
auto *op = block_desc.AppendOp();
PADDLE_ENFORCE(!node->pb_msg().empty());
op->Proto()->ParseFromString(node->pb_msg());
}
*block_desc.Proto()->mutable_vars() =
argument_->origin_program_desc->blocks(0).vars();
PADDLE_ENFORCE(!block_desc.Proto()->vars().empty());
CreateTrtEngineOp(node, argument_, block_desc.Proto());
auto *main_block = desc_->mutable_blocks(framework::kRootBlockIndex);
auto *op = main_block->add_ops();
PADDLE_ENFORCE(!node->pb_msg().empty(), "failed to set desc for block");
op->ParseFromString(node->pb_msg());
}
namespace {
class DFG_DebuggerPass : public DFG_GraphvizDrawPass {
public:
using Config = DFG_GraphvizDrawPass::Config;
explicit DFG_DebuggerPass(const Config &config)
: DFG_GraphvizDrawPass(config) {}
std::string repr() const override { return "dfg-to-fluid-debuger-pass"; }
bool Finalize() override { return true; }
};
} // namespace
AnalysisPass *DataFlowGraphToFluidPass::CreateGraphvizDebugerPass() const {
return new DFG_DebuggerPass(DFG_GraphvizDrawPass::Config(
FLAGS_IA_graphviz_log_root,
"data_flow_graph_to_fluid_graphviz_debugger"));
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
REGISTER_PASS(tensorrt_subgraph_pass,
paddle::inference::analysis::TensorRtSubgraphPass)
.RequirePassAttr("tensorrt_node_teller")
.RequirePassAttr("max_batch_size")
.RequirePassAttr("workspace_size");
...@@ -12,31 +12,24 @@ ...@@ -12,31 +12,24 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/inference/analysis/model_store_pass.h" #pragma once
#include <paddle/fluid/framework/ir/fuse_pass_base.h>
#include <gflags/gflags.h> #include "paddle/fluid/framework/ir/pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/inference/analysis/analyzer.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
DEFINE_string(inference_model_dir, "", "Model path"); class TensorRtSubgraphPass : public framework::ir::FusePassBase {
public:
TEST(DFG_StorePass, test) { std::unique_ptr<framework::ir::Graph> ApplyImpl(
Analyzer analyzer; std::unique_ptr<framework::ir::Graph> graph) const override;
Argument argument(FLAGS_inference_model_dir);
argument.model_output_store_path.reset(
new std::string("./_dfg_store_pass_tmp"));
// disable storage in alalyzer
FLAGS_IA_output_storage_path = "";
analyzer.Run(&argument);
ModelStorePass pass; private:
pass.Initialize(&argument); void CreateTensorRTOp(framework::ir::Node *x,
pass.Run(argument.main_dfg.get()); framework::ir::Graph *graph) const;
} void CleanIntermediateOutputs(framework::ir::Node *node);
};
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <stdlib.h>
#include <string>
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/inference/analysis/model_store_pass.h"
namespace paddle {
namespace inference {
namespace analysis {
void ModelStorePass::Run(DataFlowGraph *x) {
if (!argument_->fluid_model_param_path) {
PADDLE_ENFORCE_NOT_NULL(argument_->fluid_model_dir);
argument_->fluid_model_param_path.reset(
new std::string(*argument_->fluid_model_dir + "param"));
}
PADDLE_ENFORCE_NOT_NULL(argument_->model_output_store_path);
// Directly copy param file to destination.
std::stringstream ss;
// NOTE these commands only works on linux.
ss << "mkdir -p " << *argument_->model_output_store_path;
VLOG(30) << "run command: " << ss.str();
PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0);
ss.str("");
ss << "cp " << *argument_->fluid_model_dir << "/*"
<< " " << *argument_->model_output_store_path;
VLOG(30) << "run command: " << ss.str();
PADDLE_ENFORCE_EQ(system(ss.str().c_str()), 0);
// Store program
PADDLE_ENFORCE_NOT_NULL(argument_->transformed_program_desc,
"program desc is not transformed, should call "
"DataFlowGraphToFluidPass first.");
VLOG(30) << "store analyzed program to "
<< *argument_->model_output_store_path;
const std::string program_output_path =
*argument_->model_output_store_path + "/__model__";
std::ofstream file(program_output_path, std::ios::binary);
PADDLE_ENFORCE(file.is_open(), "failed to open %s to write.",
program_output_path);
const std::string serialized_message =
argument_->transformed_program_desc->SerializeAsString();
file.write(serialized_message.c_str(), serialized_message.size());
}
bool ModelStorePass::Finalize() { return true; }
} // namespace analysis
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/node.h"
#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace analysis {
std::vector<Dot::Attr> Value::dot_attrs() const {
return std::vector<Dot::Attr>({Dot::Attr("style", "filled,rounded"),
Dot::Attr("shape", "box"),
Dot::Attr("fillcolor", "red")});
}
std::vector<Dot::Attr> Function::dot_attrs() const {
return std::vector<Dot::Attr>({Dot::Attr("style", "filled,rounded"),
Dot::Attr("shape", "diamond"),
Dot::Attr("fillcolor", "yellow")});
}
Node *NodeMap::Create(Node::Type type) {
switch (type) {
case Node::Type::kFunction:
nodes_.emplace_back(new Function);
break;
case Node::Type::kValue:
nodes_.emplace_back(new Value);
break;
case Node::Type::kFunctionBlock:
nodes_.emplace_back(new FunctionBlock);
break;
default:
PADDLE_THROW("Not supported node type.");
}
nodes_.back()->id_ = size() - 1;
return nodes_.back().get();
}
Node *NodeMap::GetMutable(size_t id) {
PADDLE_ENFORCE_GT(size(), id);
return nodes_[id].get();
}
const Node &NodeMap::Get(size_t id) const {
PADDLE_ENFORCE_GT(size(), id);
return *nodes_[id].get();
}
void NodeMap::Delete(size_t id) {
PADDLE_ENFORCE_LT(id, size());
nodes_[id]->SetDeleted();
}
} // namespace analysis
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file defines the Node class and its subclasses. A Node is the basis
* analysis element in a computation graph.
* There are basically two kinds of nodes, the function node and value node.
*/
#pragma once
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/inference/analysis/device.h"
#include "paddle/fluid/inference/analysis/dot.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle {
namespace inference {
namespace analysis {
class NodeMap;
// A helper class to maintain the status from Pass.
struct AnyAttr {
using any_t =
boost::variant<bool, float, int32_t, int64_t, void *, std::string>;
// NOTE T should be a primary type or a struct combined by several primary
// types.
// NOTE the STL containers should not use here.
// Some usages
// Attr attr;
// attr.Bool() = true;
bool &Bool() { return As<bool>(); }
float &Float() { return As<float>(); }
int32_t &Int32() { return As<int32_t>(); }
int64_t &Int64() { return As<int64_t>(); }
void *&Pointer() { return As<void *>(); }
std::string &String() { return As<std::string>(); }
template <typename T>
T &As() {
if (type_index_ == typeid(AnyAttr)) {
type_index_ = typeid(T);
any_data_ = T();
} else {
PADDLE_ENFORCE(type_index_ == typeid(T), "fetch error type");
}
return boost::get<T>(any_data_);
}
private:
any_t any_data_;
std::type_index type_index_{typeid(AnyAttr)};
};
/*
* Node Representation.
*
* This is a very important class for analysis. It is the base class of all
* nodes computed by a program that may be used as operands to other nodes.
* Node is the super class of other important classes such as Function and
* Value, some nodes can have a name.
*/
class Node {
public:
// Node type. NOTE the new node types should add here.
enum class Type { kNone = -1, kFunction, kValue, kFunctionBlock };
Node() = default;
// Cast to a subclass type, Function for example.
template <typename Subclass>
Subclass &As() {
return *dynamic_cast<Subclass *>(this);
}
// Formatted representation of this Node.
virtual std::string repr() const {
return name() + "(" + std::to_string(id()) + ")";
}
// DOT node representation. One Node type can customize its own node
// representation.
virtual std::vector<Dot::Attr> dot_attrs() const {
return std::vector<Dot::Attr>({Dot::Attr("style", "filled")});
}
// Get an additional attribute and convert it to T data type. NOTE this will
// silently create a new attribute if not exists.
AnyAttr &attr(const std::string &name) const { return attrs_[name]; }
int id() const { return id_; }
// The Protobuf description is set/get with a void* to decouple Node interface
// from a specific kind of Protobuf message.
void SetPbDesc(void *pb) { attr("pb_desc").Pointer() = pb; }
void *pb_desc() const { return attr("pb_desc").Pointer(); }
void SetPbMsg(const std::string &s) { attr("pb_msg").String() = s; }
const std::string &pb_msg() const { return attr("pb_msg").String(); }
void SetDeleted() { deleted_ = true; }
bool deleted() const { return deleted_; }
void SetName(const std::string &name) { name_ = name; }
const std::string &name() const { return name_; }
void SetType(Type type) { type_ = type; }
Type type() const { return type_; }
// Input links.
std::vector<Node *> inlinks;
// Output links.
std::vector<Node *> outlinks;
// Type checks.
bool IsFunction() const { return type_ == Node::Type::kFunction; }
bool IsValue() const { return type_ == Node::Type::kValue; }
bool IsFunctionBlock() const { return type_ == Node::Type::kFunctionBlock; }
virtual ~Node() {}
friend class NodeMap;
PADDLE_DISALLOW_COPY_AND_ASSIGN(Node);
protected:
// The id number not the name is a node's unique identifier in the computation
// graph.
int id_{-1};
std::string name_;
Type type_{Type::kNone};
// Mark this node is deleted by some pass.
bool deleted_{false};
mutable std::unordered_map<std::string, AnyAttr> attrs_;
};
class Function;
/*
* Value represents a value node, it has some attributes including dims, data
* type and so on.
*/
class Value : public Node {
public:
enum class DataType { kInt32, kInt64, kFloat32, kFloat64 };
using Dims = std::vector<int>;
void SetDataType(DataType data_type) { data_type_ = data_type; }
DataType data_type() const { return data_type_; }
void SetDims(const Dims &dims) { dims_ = dims; }
const Dims &dims() const { return dims_; }
Device device() const { return device_; }
void SetDevice(Device device) { device_ = device; }
std::vector<Dot::Attr> dot_attrs() const override;
PADDLE_DISALLOW_COPY_AND_ASSIGN(Value);
protected:
Value() { SetType(Node::Type::kValue); }
friend class NodeMap;
private:
DataType data_type_;
Dims dims_;
Device device_;
};
/*
* Function represents any kind of executable concepts that takes several Values
* as input, and outputs several Values.
*/
class Function : public Node {
public:
std::vector<Dot::Attr> dot_attrs() const override;
// Get the operator's type from Desc.
const std::string &func_type() const { return func_type_; }
// Set the operator's type.
void SetFuncType(const std::string &func_type) { func_type_ = func_type; }
PADDLE_DISALLOW_COPY_AND_ASSIGN(Function);
protected:
std::string func_type_;
Function() { SetType(Node::Type::kFunction); }
friend class NodeMap;
};
/*
* FunctionBlock is a Node that contains a sub-graph multiple Node.
*/
struct FunctionBlock : public Node {
std::string repr() const override { return "block-" + std::to_string(id()); }
std::vector<Node *> subgraph;
protected:
FunctionBlock() { SetType(Node::Type::kFunctionBlock); }
friend class NodeMap;
};
class NodeMap {
public:
// Create a new node with type.
Node *Create(Node::Type type);
// Get a node by its id.
Node *GetMutable(size_t id);
const Node &Get(size_t id) const;
void Delete(size_t id);
const std::vector<std::unique_ptr<Node>> &nodes() const { return nodes_; }
size_t size() const { return nodes_.size(); }
private:
std::vector<std::unique_ptr<Node>> nodes_;
std::unordered_map<std::string, Node *> map_;
};
} // namespace analysis
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/node.h"
#include <gtest/gtest.h>
namespace paddle {
namespace inference {
namespace analysis {
TEST(NodeAttr, bool) {
AnyAttr x;
x.Bool() = true;
ASSERT_EQ(x.Bool(), true);
}
TEST(NodeAttr, int32) {
AnyAttr x;
x.Int32() = 32;
ASSERT_EQ(x.Int32(), 32);
}
TEST(NodeAttr, string) {
AnyAttr x;
x.String() = "Hello";
ASSERT_EQ(x.String(), "Hello");
}
TEST(Node, Attr) {
// Node is an abstract class, use Value instead for they share the same Attr
// logic.
NodeMap nodes;
auto* node = nodes.Create(Node::Type::kValue);
node->attr("v0").Int32() = 2008;
ASSERT_EQ(node->attr("v0").Int32(), 2008);
node->attr("str").String() = "hello world";
ASSERT_EQ(node->attr("str").String(), "hello world");
}
} // namespace analysis
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/pass_manager.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace inference {
namespace analysis {
bool PassManager::Initialize(Argument* argument) {
argument_ = argument;
for (auto& pass : data_) {
VLOG(30) << "Initializing pass [" << pass->repr() << "]";
if (!pass->Initialize(argument)) {
LOG(ERROR) << "Failed to initialize pass [" << pass->repr() << "]";
return false;
}
}
return true;
}
void DfgPassManager::RunAll() {
PADDLE_ENFORCE(argument_);
VLOG(30) << "Total " << data_.size() << " Analysys passes";
for (auto& pass : data_) {
string::PrettyLogEndl(string::Style::H1(), "* Running Analysis pass [%s]",
pass->repr());
pass->Run(argument_->main_dfg.get());
}
}
} // namespace analysis
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file defines the logic of pass management. The analysis for inference is
* a pipeline of Passes, a PassManager is a agency that helps to manage the
* executation of the Passes.
*
* There are two modes of Passes, the first one is called NodePass and takes
* an Node as input and output; the second one is called DFGPass and takes a
* DFG(Data Flow Graph) as input and output. It is hard to put all the passes in
* the same pipeline, there are two kinds of PassManagers, both takes a DFG as
* input and output a DFG, but the Passes inside are different:
*
* 1. NodePassManager: the passes inside are all NodePasses, it can have
* different graph trivial algorithm, for example, DFS_NodePassManager will
* trigger the passes in depth first order;
* 2. DfgPassManager: the passes inside are all DfgPasses.
*/
#pragma once
#include <string>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/analysis_pass.h"
namespace paddle {
namespace inference {
namespace analysis {
/*
* PassManager is the base class for all pass managers, a pass manager has
* several Pass-es registered, and execute them in the linear order.
*/
class PassManager : public OrderedRegistry<AnalysisPass> {
public:
PassManager() = default;
// Call all the passes' Initialize methods. The desc and data_flow_graph are
// globally shared, so pass them as the arguemnts for all the pass managers.
virtual bool Initialize(const Argument& argument) { return false; }
virtual bool Initialize(Argument* argument);
// Call all the passes' Finalize methods.
virtual bool Finalize() {
for (auto& pass : data_) {
if (!pass->Finalize()) {
LOG(ERROR) << "Failed to finalize pass [" << pass->repr() << "]";
return false;
}
}
return true;
}
// Run all the passes.
virtual void RunAll() = 0;
// Short identifier.
virtual std::string repr() const = 0;
// Long description.
virtual std::string description() const = 0;
virtual ~PassManager() = default;
protected:
Argument* argument_{nullptr};
};
/*
* A pass manager that process a DFG.
*/
class DfgPassManager : public PassManager {
public:
DfgPassManager() = default;
void RunAll() override;
virtual ~DfgPassManager() = default;
};
} // namespace analysis
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h"
#include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include "paddle/fluid/inference/analysis/pass_manager.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
namespace paddle {
namespace inference {
namespace analysis {
class TestDfgPassManager final : public DfgPassManager {
public:
TestDfgPassManager() = default;
virtual ~TestDfgPassManager() = default;
// Short identifier.
std::string repr() const override { return "test-pass-manager"; }
// Long description.
std::string description() const override { return "test doc"; }
};
TEST(PassManager, DFG_pass_manager) {
TestDfgPassManager manager;
DFG_GraphvizDrawPass::Config config("./", "dfg.dot");
manager.Register("fluid-to-flow-graph", new FluidToDataFlowGraphPass);
manager.Register("graphviz", new DFG_GraphvizDrawPass(config));
manager.Register("dfg-to-fluid", new DataFlowGraphToFluidPass);
Argument argument(FLAGS_inference_model_dir);
ASSERT_TRUE(&argument);
ASSERT_TRUE(manager.Initialize(&argument));
manager.RunAll();
}
} // namespace analysis
} // namespace inference
} // namespace paddle
cc_library(ir_graph_build_pass SRCS ir_graph_build_pass.cc DEPS analysis_pass argument ir_pass_manager)
cc_library(ir_analysis_pass SRCS ir_analysis_pass.cc DEPS analysis_pass argument ir_pass_manager)
cc_library(analysis_passes SRCS passes.cc DEPS ir_graph_build_pass ir_analysis_pass)
set(analysis_deps ${analysis_deps}
ir_graph_build_pass
ir_analysis_pass
analysis_passes
CACHE INTERNAL "")
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace inference {
namespace analysis {
void IrAnalysisComposePass::RunImpl(Argument *argument) {
ARGUMENT_CHECK_FIELD(argument, ir_analysis_passes);
if (argument->use_tensorrt_valid() && argument->use_tensorrt()) {
InitTensorRTAttrs(argument);
}
ApplyIrPasses(argument);
CollectFusionStatis(argument);
}
std::string IrAnalysisComposePass::repr() const {
return "ir-analysis-compose-pass";
}
void IrAnalysisComposePass::InitTensorRTAttrs(Argument *argument) {
if (argument->use_tensorrt_valid() && argument->use_tensorrt()) {
LOG(INFO) << "Initing TensorRT pass";
argument->SetTensorRtNodeTeller([](const framework::ir::Node *node) {
std::unordered_set<std::string> teller_set(
{"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
"depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
"elementwise_add", "dropout"});
if (!node->IsOp()) return false;
if (teller_set.count(node->Op()->Type())) {
return true;
} else {
return false;
}
});
}
}
void IrAnalysisComposePass::ApplyIrPasses(Argument *argument) {
std::vector<std::string> passes({
"ir_graph_build_pass", "ir_analysis_pass",
});
for (const auto &pass : passes) {
VLOG(2) << "Run pass " << pass;
auto *the_pass = PassRegistry::Global().Retreive(pass);
the_pass->Run(argument);
}
}
void IrAnalysisComposePass::CollectFusionStatis(Argument *argument) {
if (!argument->main_graph().Has(framework::ir::kFuseStatisAttr)) {
LOG(INFO) << "argument has no fuse statis";
return;
}
argument->SetFusionStatis(
argument->main_graph().Get<Argument::fusion_statis_t>(
framework::ir::kFuseStatisAttr));
}
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -12,42 +12,35 @@ ...@@ -12,42 +12,35 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
/*
* This file defines ModelStorePass, which store the runtime DFG to a Paddle
* model in the disk, and that model can be reloaded for prediction.
*/
#pragma once #pragma once
#include <string> #include <string>
#include <vector>
#include "paddle/fluid/inference/analysis/analysis_pass.h" #include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/passes/passes.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
class ModelStorePass : public DataFlowGraphPass { /*
* The analysis pass to run a list of IR passes (like a function call).
* Currently, it should be the first pass of analysis phase.
*/
class IrAnalysisComposePass : public AnalysisPass {
public: public:
bool Initialize(Argument* argument) override { void RunImpl(Argument* argument) override;
if (!argument) { std::string repr() const override;
LOG(ERROR) << "invalid argument";
return false;
}
argument_ = argument;
return true;
}
void Run(DataFlowGraph* x) override; private:
void InitTensorRTAttrs(Argument* argument);
std::string repr() const override { return "DFG-store-pass"; } void ApplyIrPasses(Argument* argument);
std::string description() const override {
return R"DD(This file defines ModelStorePass, which store the runtime DFG to a Paddle
model in the disk, and that model can be reloaded for prediction again.)DD";
}
bool Finalize() override; void CollectFusionStatis(Argument* argument);
private: // Assign a Scope for IR passes to modify the weights.
Argument* argument_{nullptr}; void AssignScopeToModify(Argument* argument);
}; };
} // namespace analysis } // namespace analysis
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/passes/ir_analysis_pass.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
namespace paddle {
namespace inference {
namespace analysis {
void IrAnalysisPass::RunImpl(Argument* argument) {
ARGUMENT_CHECK_FIELD(argument, ir_analysis_passes);
ARGUMENT_CHECK_FIELD(argument, main_program);
ARGUMENT_CHECK_FIELD(argument, scope);
auto* the_graph = argument->ReleaseMainGraph();
auto graph = std::unique_ptr<Graph>(the_graph);
// Apply passes.
IRPassManager the_ir_manager(argument);
graph = the_ir_manager.Apply(std::move(graph));
PADDLE_ENFORCE_GT(graph->Nodes().size(), 0);
argument->SetIrAnalyzedProgram(new framework::proto::ProgramDesc(
the_ir_manager.AcquireProgram(&graph, argument->main_program())));
argument->SetMainGraph(graph.release());
}
std::string IrAnalysisPass::repr() const { return "ir-analysis-pass"; }
} // namespace analysis
} // namespace inference
} // namespace paddle
...@@ -12,20 +12,25 @@ ...@@ -12,20 +12,25 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
/*
* This file contains all the flags that declared in Node::Attr.
*
* The Node::Attr is designed to share information between different passes, one
* can get other's attributes in a Node by the flags in this file.
*/
#pragma once #pragma once
#include <string>
#include "paddle/fluid/inference/analysis/analysis_pass.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
#define DECLARE_NODE_ATTR(flag__) const char ATTR_##flag__[] = #flag__; /*
* Perform IR analysis passes.
DECLARE_NODE_ATTR(supported_by_tensorrt) // bool *
* It is used to fuse some
*/
class IrAnalysisPass : public AnalysisPass {
public:
void RunImpl(Argument* argument) override;
std::string repr() const override;
};
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
......
...@@ -12,49 +12,62 @@ ...@@ -12,49 +12,62 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/inference/analysis/fluid_to_ir_pass.h" #include "paddle/fluid/inference/analysis/passes/ir_graph_build_pass.h"
#include <paddle/fluid/framework/ir/fuse_pass_base.h>
#include <string>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/io.h" #include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
extern void ReadBinaryFile(const std::string &filename, std::string *contents);
namespace analysis { namespace analysis {
void FluidToIrPass::EnableParamModify(const std::string &model_dir, void IrGraphBuildPass::RunImpl(Argument *argument) {
const std::string &prog_file, if (!argument->scope_valid()) {
const std::string &param_file) { argument->SetScope(new framework::Scope);
PADDLE_ENFORCE(argument_); }
argument_->Set(framework::ir::kParamScopeAttr, new framework::Scope);
// Load parameters.
VLOG(30) << "Loading parameters from " << model_dir;
LoadParams(&argument_->Get<framework::Scope>(framework::ir::kParamScopeAttr),
model_dir, prog_file, param_file);
}
bool FluidToIrPass::LoadParams(framework::Scope *scope, const std::string &dir, if (argument->model_dir_valid()) {
const std::string &prog_file, auto program = LoadModel(argument->model_dir(), argument->scope_ptr());
const std::string &param_file) { argument->SetMainProgram(program.release());
platform::CPUPlace place; } else if (argument->model_program_path_valid() &&
platform::CPUDeviceContext ctx(place); argument->model_params_path_valid()) {
framework::Executor executor(place); auto program =
PADDLE_ENFORCE(argument_->origin_program_desc.get()); LoadModel(argument->model_program_path(), argument->model_params_path(),
framework::ProgramDesc program(*argument_->origin_program_desc); argument->scope_ptr());
if ((!prog_file.empty()) && (!param_file.empty())) { argument->SetMainProgram(program.release());
LOG(INFO) << "load single model file from " << prog_file;
Load(&executor, scope, prog_file, param_file);
} else if (!dir.empty()) {
LOG(INFO) << "load from dir " << dir;
Load(&executor, scope, dir);
} else { } else {
LOG(ERROR) << "failed to load parameters"; PADDLE_THROW(
return false; "either model_dir or (program path and parameter path) should be set.");
} }
return true;
auto graph = std::unique_ptr<Graph>(new Graph(argument->main_program()));
argument->SetMainGraph(graph.release());
argument->main_graph().Set(framework::ir::kParamScopeAttr,
new framework::Scope *(argument->scope_ptr()));
} }
std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
const std::string &path, framework::Scope *scope) {
platform::CPUPlace place;
framework::Executor exe(place);
return Load(&exe, scope, path);
}
std::unique_ptr<framework::ProgramDesc> IrGraphBuildPass::LoadModel(
const std::string &program_path, const std::string &params_path,
framework::Scope *scope) {
platform::CPUPlace place;
framework::Executor exe(place);
return Load(&exe, scope, program_path, params_path);
}
std::string IrGraphBuildPass::repr() const { return "ir-graph-build-pass"; }
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
此差异已折叠。
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_pass_builder.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册