diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 9323882e4ee90ff4b9534a3e42d2c58b038f0af8..403c3dc4e907f06aeb54c1310df7d57fe8871798 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -39,6 +39,7 @@ cc_library(graph_helper SRCS graph_helper.cc DEPS graph) cc_library(pass SRCS pass.cc DEPS graph node graph_helper) cc_library(graph_traits SRCS graph_traits.cc DEPS graph) cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits) +cc_library(subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor) cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass) cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass) @@ -99,7 +100,7 @@ endif() if(WITH_NGRAPH) cc_library(ngraph_subgraph_pass SRCS ngraph_subgraph_pass.cc DEPS ngraph_bridge - analysis_helper subgraph_detector graph_pattern_detector pass fuse_pass_base ${op_library_DEPS}) + subgraph_detector fuse_pass_base ${op_library_DEPS}) set(pass_file ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h) file(APPEND ${pass_file} "USE_PASS(ngraph_subgraph_pass);\n") set(INFER_IR_PASSES ${INFER_IR_PASSES} ngraph_subgraph_pass CACHE INTERNAL "") diff --git a/paddle/fluid/framework/ir/ngraph_subgraph_pass.cc b/paddle/fluid/framework/ir/ngraph_subgraph_pass.cc index 6198fab7dcaf7cce229532e50c34e516c1697ba4..9778b6215ad117697a1ccbe88e922df7c4e6f568 100644 --- a/paddle/fluid/framework/ir/ngraph_subgraph_pass.cc +++ b/paddle/fluid/framework/ir/ngraph_subgraph_pass.cc @@ -20,8 +20,7 @@ #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/ngraph_subgraph_pass.h" -#include "paddle/fluid/inference/analysis/helper.h" -#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h" +#include "paddle/fluid/framework/ir/subgraph_detector.h" #include "paddle/fluid/operators/ngraph/ngraph_bridge.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/string/pretty_log.h" @@ -30,8 +29,6 @@ namespace paddle { namespace framework { namespace ir { -namespace ANAT = paddle::inference::analysis; - std::string GenerateEngineKey(const std::set &engine_inputs, const std::set &engine_outputs, const std::string &size) { @@ -59,19 +56,18 @@ void NgraphSubgraphPass::ApplyImpl(Graph *graph) const { return !paddle::operators::NgraphBridge::isRegister(op_type); }; - ANAT::SubGraphFuser fuser(graph, teller, 0, "ngraph_engine"); + SubGraphFuser fuser(graph, teller, 0, "ngraph_engine"); fuser(); for (auto *node : graph->Nodes()) { - if (node->IsOp() && !ANAT::Agent(node).subgraph()->empty()) { + if (node->IsOp() && !Agent(node).subgraph()->empty()) { OpDesc *op_desc = node->Op(); op_desc->SetType("ngraph_engine"); CreateNgraphEngineOp(node, graph); std::unordered_set nodes2remove( - ANAT::Agent(node).subgraph()->begin(), - ANAT::Agent(node).subgraph()->end()); + Agent(node).subgraph()->begin(), Agent(node).subgraph()->end()); GraphSafeRemoveNodes(graph, nodes2remove); } @@ -79,7 +75,7 @@ void NgraphSubgraphPass::ApplyImpl(Graph *graph) const { std::unordered_set nodes2remove; for (auto *node : graph->Nodes()) { - if (node->IsOp() && ANAT::Agent(node).deleted()) { + if (node->IsOp() && Agent(node).deleted()) { nodes2remove.insert(node); } } @@ -116,7 +112,7 @@ void UpdateNgraphIO(Node *node, Graph *graph, return; } - auto &subgraph = *ANAT::Agent(node).subgraph(); + auto &subgraph = *Agent(node).subgraph(); std::unordered_set inputs; std::unordered_set outputs; for (auto *node : subgraph) { @@ -138,7 +134,7 @@ void UpdateNgraphIO(Node *node, Graph *graph, } void NgraphSubgraphPass::CreateNgraphEngineOp(Node *node, Graph *graph) const { - auto &subgraph = *ANAT::Agent(node).subgraph(); + auto &subgraph = *Agent(node).subgraph(); PADDLE_ENFORCE_NE(subgraph.empty(), true, "subgraph cannot be empty"); framework::proto::BlockDesc block_proto; diff --git a/paddle/fluid/inference/analysis/ir_passes/subgraph_detector.cc b/paddle/fluid/framework/ir/subgraph_detector.cc similarity index 96% rename from paddle/fluid/inference/analysis/ir_passes/subgraph_detector.cc rename to paddle/fluid/framework/ir/subgraph_detector.cc index 064f947aaa7ca75c6497ddf76d4d78c5557fdeb8..f705fca4e8699b8923202921ab8dc374c689042f 100644 --- a/paddle/fluid/inference/analysis/ir_passes/subgraph_detector.cc +++ b/paddle/fluid/framework/ir/subgraph_detector.cc @@ -1,474 +1,472 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h" -#include -#include -#include -#include -#include "paddle/fluid/framework/ir/graph_helper.h" -#include "paddle/fluid/framework/ir/graph_pattern_detector.h" -#include "paddle/fluid/framework/ir/node.h" - -DECLARE_bool(use_ngraph); - -namespace paddle { -namespace inference { -namespace analysis { - -using framework::ir::Node; - -std::pair, std::vector> -ExtractInputAndOutputOfSubGraph(std::vector &graph) { // NOLINT - std::unordered_set nodes(graph.begin(), graph.end()); - std::unordered_set inputs; - std::unordered_set outputs; - // Input a Value, check whether its inlink is in the subgraph. - auto inlink_in_subgraph = [&](Node *n) { - for (auto *in : n->inputs) { - if (nodes.count(in)) return true; - } - return false; - }; - - for (auto &node : graph) { - for (auto *in : node->inputs) { - // The Value that is written by nodes inside a sub-graph shouldn't be the - // input of the sub-graph. - if (!nodes.count(in) && in->IsVar() && !inlink_in_subgraph(in)) { - inputs.insert(in); - } - } - for (auto *out : node->outputs) { - if (!nodes.count(out) && out->IsVar()) { - outputs.insert(out); - } - } - } - return std::make_pair(std::vector(inputs.begin(), inputs.end()), - std::vector(outputs.begin(), outputs.end())); -} - -// Filter the Intermediate results of the subgraph node. -void FilterRedundantOutputOfSubGraph(Graph *graph) { - std::vector op_nodes; - for (auto &node : TopologicalSort(*graph)) { - if (node.IsVar() || Agent(&node).deleted()) { - continue; - } - op_nodes.push_back(&node); - } - size_t op_num = op_nodes.size(); - for (size_t i = 0; i < op_num; i++) { - if (op_nodes[i]->IsOp()) continue; - std::unordered_set follow_up_input_names; - for (size_t j = i + 1; j < op_num; j++) { - for (auto *in : op_nodes[j]->inputs) { - follow_up_input_names.insert(in->Name()); - } - } - std::vector filtered_subgraph_outlinks; - for (auto *out : op_nodes[i]->outputs) { - if (follow_up_input_names.count(out->Name())) { - filtered_subgraph_outlinks.push_back(out); - } else { - Agent(out).set_deleted(true); - } - } - // The filtered_subgraph_outlinks may be empty. - op_nodes[i]->outputs = filtered_subgraph_outlinks; - } -} - -std::vector> SubgraphDetector::operator()() { - MarkNodesInsideSubGraph(); - return ExtractSubGraphs(); -} - -// Mark the output variables inside a subgraph with the func. -inline void MarkOutLinksInSubGraph(const Node *func) { - for (auto *var : func->outputs) { - Agent(var).set_marked(true); - } -} - -void SubgraphDetector::MarkNodesInsideSubGraph() { - for (auto &node : framework::ir::GraphTraits::DFS(*graph_)) { - if (node_inside_subgraph_teller_(&node)) { - Agent(&node).set_marked(true); - if (node.IsOp()) { - // If a function is inside the sub-graph, mark all the output variables - // to be inside too, so that two marked functions will be inside a same - // sub-graph, lets take a example: A_function->var->B_function, if - // A_function is marked, var should also be marked, so that B_function - // will be in the same sub-graph with A_function if B_function is - // marked. - MarkOutLinksInSubGraph(&node); - } - } - } -} - -// Use the Union Find(UF) algorithm to find fully connected sub-graphs, if node -// a's output is node b, that is a and b is in the same sub-graph. The UF -// algorithm will group them to the same cluster. -using node_map_t = std::unordered_map; -// Find the ancestor id of a node. -int UnionFindGetAncestor(const node_map_t &node_map, size_t id) { - int tmp = id; - do { - tmp = Agent(node_map.at(tmp)).union_find_parent(); - } while (Agent(node_map.at(tmp)).union_find_parent() != tmp); - return tmp; -} -// Make this two node share the same ancestor. -// TODO(Superjom) bad performance, make a balanced tree latter. -void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) { - int a_ancestor = UnionFindGetAncestor(node_map, a); - int b_ancestor = UnionFindGetAncestor(node_map, b); - Agent(node_map.at(b_ancestor)).set_union_find_parent(a_ancestor); - Agent(node_map.at(a)).set_union_find_parent(a_ancestor); - Agent(node_map.at(b)).set_union_find_parent(a_ancestor); -} - -// This is a simple representation of a graph. -// The BriefNode hold the pointer of the Node. -// This is to avoid changing the original graph -// in the process of trt graph analysis. -struct BriefNode { - explicit BriefNode(Node *n) { node = n; } - Node *node; - std::vector inlinks; - std::vector outlinks; -}; - -// Union two adjacent BriefNode. -// Suppose we have two adjacent nodes src and dst. -// We will perform the following operations: -// 1. add all inputs(except src) of dst to src inlinks. -// 2. add all outputs of dst to src outlinks. -// 3. change all the dst's inputs and outputs -// corresponding inlinks and outlinks to src node. -// 4. delete all dst's inlinks and outlinks. -void UnionContractedNodes(const std::unordered_map &node_map, - int src_id, int dst_id) { - // merge the two adjacent nodes into one node. - BriefNode *src_node = node_map.at(src_id); - BriefNode *dst_node = node_map.at(dst_id); - - std::unordered_set inputs(src_node->inlinks.begin(), - src_node->inlinks.end()); - std::unordered_set outputs; - - for (auto *n : src_node->outlinks) { - if (n != dst_node) outputs.insert(n); - } - - // Add the inlinks and outlinks of dst node to src node. - std::vector dst_in_nodes = dst_node->inlinks; - for (BriefNode *node : dst_in_nodes) { - if (node != src_node) { - inputs.insert(node); - } - } - - std::vector dst_out_nodes = dst_node->outlinks; - for (BriefNode *node : dst_out_nodes) { - outputs.insert(node); - } - -// update the dst and src node's inlinks and outlinks. -#ifdef __clang__ - src_node->inlinks = std::vector(inputs.begin(), inputs.end()); - src_node->outlinks = std::vector(outputs.begin(), outputs.end()); - dst_node->inlinks.clear(); - dst_node->outlinks.clear(); -#else - src_node->inlinks = - std::move(std::vector(inputs.begin(), inputs.end())); - src_node->outlinks = - std::move(std::vector(outputs.begin(), outputs.end())); - dst_node->inlinks.clear(); - dst_node->outlinks.clear(); -#endif - - auto inlink_or_outlink_cleaner = [&](std::vector &nodes) { - for (auto *&n : nodes) { - if (n == src_node || n == dst_node) { - n = src_node; - } - } - }; - // Change all the dst inputs and outputs corresponding inlink and - // outlink to the src node. - for (auto *node : src_node->inlinks) { - inlink_or_outlink_cleaner(node->outlinks); - } - - for (auto *node : src_node->outlinks) { - inlink_or_outlink_cleaner(node->inlinks); - } -} - -// FlexibleDFS -// If reverse is true, do reverse dfs. -// If enter func is not nullptr, calls enter(node) before visiting any children -// of node. -// If leave func not nullptr, calls leave(node) after visiting all parents of -// node. -void FlexibleDFS(const std::vector &source, bool reverse, - const std::function &enter, - const std::function &leave) { - typedef struct { - const BriefNode *node; - bool leave; - } FNode; - - std::vector stack; - for (auto &node : source) { - stack.push_back(FNode{node, false}); - } - std::unordered_set visited; - while (!stack.empty()) { - auto fnode = stack.back(); - stack.pop_back(); - - if (fnode.leave) { - if (leave && !leave(fnode.node)) return; - } - if (visited.count(fnode.node)) continue; - visited.insert(fnode.node); - - if (enter && !enter(fnode.node)) return; - - if (leave) stack.push_back(FNode{fnode.node, true}); - const std::vector iter_nodes = - reverse == true ? fnode.node->inlinks : fnode.node->outlinks; - for (const BriefNode *node : iter_nodes) { - if (!visited.count(node)) { - stack.push_back(FNode{node, false}); - } - } - } -} - -std::vector> SubgraphDetector::ExtractSubGraphs() { - // Run the Extract algorithm to find all subgraphs. - std::vector marked_nodes; - // We use brief_node_map to represent the original graph in order to avoid - // changing the original graph. - std::unordered_map brief_node_map; - - std::unordered_set valid_node_ids; - for (auto *node : graph_->Nodes()) { - valid_node_ids.insert(node->id()); - } - - for (auto &node : framework::ir::GraphTraits::TS(*graph_)) { - brief_node_map[node.id()] = new BriefNode(&node); - if (Agent(&node).marked()) { - marked_nodes.push_back(&node); - } - } - - // extract sub-graphs in the marked node set, use Union Find algorithm. - node_map_t node_map; // id to ptr - for (auto *n : marked_nodes) { - // n's parent == n.id means it is the ancestor - Agent(n).set_union_find_parent(n->id()); - node_map[n->id()] = n; - } - - // create breif node map - for (auto &itr : brief_node_map) { - for (Node *node : itr.second->node->inputs) { - if (!valid_node_ids.count(node->id())) { - LOG(INFO) << "invalid node id " << node->id(); - continue; - } - itr.second->inlinks.push_back(brief_node_map.at(node->id())); - } - - for (Node *node : itr.second->node->outputs) { - if (!valid_node_ids.count(node->id())) { - LOG(INFO) << "invalid node id " << node->id(); - continue; - } - itr.second->outlinks.push_back(brief_node_map.at(node->id())); - } - } - - for (auto &itr : brief_node_map) { - BriefNode *brief_node = itr.second; - - if (!Agent(brief_node->node).marked()) { - VLOG(4) << brief_node->node->id() << " node not a trt candidate."; - continue; - } - - // Our algorithm must guarantee that: - // 1. The graph is always directed acyclic graph(DAG). - // 2. If there is a path in the subgraph from X to Y (X and Y are both - // nodes in the subgraph), then all paths from X to Y are in the - // subgraph. - // - // In order to achieve the above guarantee. - // For adjacent nodes src -> dst. - // 1. Get all dst input nodes except src. - // 2. Reverse DFS from those input nodes - // 3. If there is a path from input nodes to src, - // then the src and dst nodes can not be fused into one node, - // otherwise it can be done. - - while (true) { - std::unordered_set contract_nodes; - for (auto *out : brief_node->outlinks) { - // must be an trt candidate - if (!Agent(out->node).marked()) continue; - // get all dst input nodes except src. - std::vector source_nodes; - for (auto *n : out->inlinks) { - if (n != brief_node) { - source_nodes.push_back(n); - } - } - - // Reverse DFS from the source_nodes. - bool have_excess_path = false; - FlexibleDFS(source_nodes, true, nullptr, - [&have_excess_path, brief_node](const BriefNode *n) { - if (n == brief_node) { - have_excess_path = true; - return false; - } - return true; - }); - if (have_excess_path) continue; - contract_nodes.insert(out); - } - if (contract_nodes.empty()) break; - - for (auto dst_node : contract_nodes) { - UnionFindCombine(node_map, brief_node->node->id(), - dst_node->node->id()); - UnionContractedNodes(brief_node_map, brief_node->node->id(), - dst_node->node->id()); - } - } - } - - std::unordered_map> clusters; - for (auto *n : marked_nodes) { - if (n->IsOp()) { - clusters[UnionFindGetAncestor(node_map, Agent(n).union_find_parent())] - .push_back(n); - } - } - std::vector> result; - std::for_each(clusters.begin(), clusters.end(), - [&](const decltype(clusters)::value_type &it) { - result.push_back(it.second); - }); - - return result; -} - -void SubGraphFuser::operator()() { ReplaceNodesWithSubGraphs(); } - -void RemoveIntermediateOutputInSubgraph(const std::vector &subgraph, - Graph *graph, - std::vector *outputs) { - std::unordered_set subgraph_set(subgraph.begin(), subgraph.end()); - std::unordered_set valid_output; - - for (auto *output : *outputs) { - int num_used = 0; - for (auto *node : output->outputs) { - if (!subgraph_set.count(node)) ++num_used; - if (num_used > 0) valid_output.insert(output); - } - } - - // In use for ngraph subgraph pass for parallel executor, - // this will remove all nodes, bypass this and let ngraph - // subgraph pass to process outputs - if (FLAGS_use_ngraph && valid_output.size() == 0) return; - - outputs->assign(valid_output.begin(), valid_output.end()); -} - -void DetachDeletedNodes(framework::ir::Graph *graph) { - std::unordered_set nodes; - for (auto *node : graph->Nodes()) { - if (Agent(node).deleted()) { - node->inputs.clear(); - node->outputs.clear(); - } - } -} - -void SubGraphFuser::ReplaceNodesWithSubGraphs() { - auto subgraphs = SubgraphDetector(graph_, node_inside_subgraph_teller_)(); - for (auto &subgraph : subgraphs) { - if (subgraph.size() <= (size_t)min_subgraph_size_) continue; - std::unordered_set subgraph_uniq(subgraph.begin(), subgraph.end()); - // replace this sub-graph with the first node. Two steps: 1. Create a Block - // Node that contains this subgraph 2. Mark the nodes inside the sub-graph - // as deleted. 3. Replace the deleted node with the new Block Node. - framework::OpDesc empty_desc; - empty_desc.SetType(name_); - auto *block_node = graph_->CreateOpNode(&empty_desc); - Agent(block_node).set_subgraph({}); - auto io = ExtractInputAndOutputOfSubGraph(subgraph); - block_node->inputs = std::move(io.first); - block_node->outputs = std::move(io.second); - - RemoveIntermediateOutputInSubgraph(subgraph, graph_, &block_node->outputs); - - for (auto *node : subgraph) { - // TODO(Superjomn) need a unified mechanism to treat deleted node in each - // pass. - Agent(node).set_deleted(true); - Agent(block_node).subgraph()->push_back(node); - } - - // Change all the sub-graph's inputs and outputs corresponding inlink and - // outlink to this sub-graph node. - auto inlink_or_outlink_cleaner = [&](std::vector &nodes) { - for (auto *&n : nodes) { - if (subgraph_uniq.count(n)) { - n = block_node; - } - } - std::unordered_set uniq(nodes.begin(), nodes.end()); - nodes.assign(uniq.begin(), uniq.end()); - }; - for (auto *i : block_node->inputs) { - inlink_or_outlink_cleaner(i->outputs); - } - for (auto *&o : block_node->outputs) { - inlink_or_outlink_cleaner(o->inputs); - } - } - // DetachDeletedNodes(graph_); - FilterRedundantOutputOfSubGraph(graph_); -} - -inline bool CheckNodeIndegreeEquals(const Node &node, size_t n) { - return node.inputs.size() == n; -} - -} // namespace analysis -} // namespace inference -} // namespace paddle +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/ir/subgraph_detector.h" +#include +#include +#include +#include +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/node.h" + +DECLARE_bool(use_ngraph); + +namespace paddle { +namespace framework { +namespace ir { + +std::pair, std::vector> +ExtractInputAndOutputOfSubGraph(std::vector &graph) { // NOLINT + std::unordered_set nodes(graph.begin(), graph.end()); + std::unordered_set inputs; + std::unordered_set outputs; + // Input a Value, check whether its inlink is in the subgraph. + auto inlink_in_subgraph = [&](Node *n) { + for (auto *in : n->inputs) { + if (nodes.count(in)) return true; + } + return false; + }; + + for (auto &node : graph) { + for (auto *in : node->inputs) { + // The Value that is written by nodes inside a sub-graph shouldn't be the + // input of the sub-graph. + if (!nodes.count(in) && in->IsVar() && !inlink_in_subgraph(in)) { + inputs.insert(in); + } + } + for (auto *out : node->outputs) { + if (!nodes.count(out) && out->IsVar()) { + outputs.insert(out); + } + } + } + return std::make_pair(std::vector(inputs.begin(), inputs.end()), + std::vector(outputs.begin(), outputs.end())); +} + +// Filter the Intermediate results of the subgraph node. +void FilterRedundantOutputOfSubGraph(Graph *graph) { + std::vector op_nodes; + for (auto &node : TopologicalSort(*graph)) { + if (node.IsVar() || Agent(&node).deleted()) { + continue; + } + op_nodes.push_back(&node); + } + size_t op_num = op_nodes.size(); + for (size_t i = 0; i < op_num; i++) { + if (op_nodes[i]->IsOp()) continue; + std::unordered_set follow_up_input_names; + for (size_t j = i + 1; j < op_num; j++) { + for (auto *in : op_nodes[j]->inputs) { + follow_up_input_names.insert(in->Name()); + } + } + std::vector filtered_subgraph_outlinks; + for (auto *out : op_nodes[i]->outputs) { + if (follow_up_input_names.count(out->Name())) { + filtered_subgraph_outlinks.push_back(out); + } else { + Agent(out).set_deleted(true); + } + } + // The filtered_subgraph_outlinks may be empty. + op_nodes[i]->outputs = filtered_subgraph_outlinks; + } +} + +std::vector> SubgraphDetector::operator()() { + MarkNodesInsideSubGraph(); + return ExtractSubGraphs(); +} + +// Mark the output variables inside a subgraph with the func. +inline void MarkOutLinksInSubGraph(const Node *func) { + for (auto *var : func->outputs) { + Agent(var).set_marked(true); + } +} + +void SubgraphDetector::MarkNodesInsideSubGraph() { + for (auto &node : framework::ir::GraphTraits::DFS(*graph_)) { + if (node_inside_subgraph_teller_(&node)) { + Agent(&node).set_marked(true); + if (node.IsOp()) { + // If a function is inside the sub-graph, mark all the output variables + // to be inside too, so that two marked functions will be inside a same + // sub-graph, lets take a example: A_function->var->B_function, if + // A_function is marked, var should also be marked, so that B_function + // will be in the same sub-graph with A_function if B_function is + // marked. + MarkOutLinksInSubGraph(&node); + } + } + } +} + +// Use the Union Find(UF) algorithm to find fully connected sub-graphs, if node +// a's output is node b, that is a and b is in the same sub-graph. The UF +// algorithm will group them to the same cluster. +using node_map_t = std::unordered_map; +// Find the ancestor id of a node. +int UnionFindGetAncestor(const node_map_t &node_map, size_t id) { + int tmp = id; + do { + tmp = Agent(node_map.at(tmp)).union_find_parent(); + } while (Agent(node_map.at(tmp)).union_find_parent() != tmp); + return tmp; +} +// Make this two node share the same ancestor. +// TODO(Superjom) bad performance, make a balanced tree latter. +void UnionFindCombine(const node_map_t &node_map, size_t a, size_t b) { + int a_ancestor = UnionFindGetAncestor(node_map, a); + int b_ancestor = UnionFindGetAncestor(node_map, b); + Agent(node_map.at(b_ancestor)).set_union_find_parent(a_ancestor); + Agent(node_map.at(a)).set_union_find_parent(a_ancestor); + Agent(node_map.at(b)).set_union_find_parent(a_ancestor); +} + +// This is a simple representation of a graph. +// The BriefNode hold the pointer of the Node. +// This is to avoid changing the original graph +// in the process of trt graph analysis. +struct BriefNode { + explicit BriefNode(Node *n) { node = n; } + Node *node; + std::vector inlinks; + std::vector outlinks; +}; + +// Union two adjacent BriefNode. +// Suppose we have two adjacent nodes src and dst. +// We will perform the following operations: +// 1. add all inputs(except src) of dst to src inlinks. +// 2. add all outputs of dst to src outlinks. +// 3. change all the dst's inputs and outputs +// corresponding inlinks and outlinks to src node. +// 4. delete all dst's inlinks and outlinks. +void UnionContractedNodes(const std::unordered_map &node_map, + int src_id, int dst_id) { + // merge the two adjacent nodes into one node. + BriefNode *src_node = node_map.at(src_id); + BriefNode *dst_node = node_map.at(dst_id); + + std::unordered_set inputs(src_node->inlinks.begin(), + src_node->inlinks.end()); + std::unordered_set outputs; + + for (auto *n : src_node->outlinks) { + if (n != dst_node) outputs.insert(n); + } + + // Add the inlinks and outlinks of dst node to src node. + std::vector dst_in_nodes = dst_node->inlinks; + for (BriefNode *node : dst_in_nodes) { + if (node != src_node) { + inputs.insert(node); + } + } + + std::vector dst_out_nodes = dst_node->outlinks; + for (BriefNode *node : dst_out_nodes) { + outputs.insert(node); + } + +// update the dst and src node's inlinks and outlinks. +#ifdef __clang__ + src_node->inlinks = std::vector(inputs.begin(), inputs.end()); + src_node->outlinks = std::vector(outputs.begin(), outputs.end()); + dst_node->inlinks.clear(); + dst_node->outlinks.clear(); +#else + src_node->inlinks = + std::move(std::vector(inputs.begin(), inputs.end())); + src_node->outlinks = + std::move(std::vector(outputs.begin(), outputs.end())); + dst_node->inlinks.clear(); + dst_node->outlinks.clear(); +#endif + + auto inlink_or_outlink_cleaner = [&](std::vector &nodes) { + for (auto *&n : nodes) { + if (n == src_node || n == dst_node) { + n = src_node; + } + } + }; + // Change all the dst inputs and outputs corresponding inlink and + // outlink to the src node. + for (auto *node : src_node->inlinks) { + inlink_or_outlink_cleaner(node->outlinks); + } + + for (auto *node : src_node->outlinks) { + inlink_or_outlink_cleaner(node->inlinks); + } +} + +// FlexibleDFS +// If reverse is true, do reverse dfs. +// If enter func is not nullptr, calls enter(node) before visiting any children +// of node. +// If leave func not nullptr, calls leave(node) after visiting all parents of +// node. +void FlexibleDFS(const std::vector &source, bool reverse, + const std::function &enter, + const std::function &leave) { + typedef struct { + const BriefNode *node; + bool leave; + } FNode; + + std::vector stack; + for (auto &node : source) { + stack.push_back(FNode{node, false}); + } + std::unordered_set visited; + while (!stack.empty()) { + auto fnode = stack.back(); + stack.pop_back(); + + if (fnode.leave) { + if (leave && !leave(fnode.node)) return; + } + if (visited.count(fnode.node)) continue; + visited.insert(fnode.node); + + if (enter && !enter(fnode.node)) return; + + if (leave) stack.push_back(FNode{fnode.node, true}); + const std::vector iter_nodes = + reverse == true ? fnode.node->inlinks : fnode.node->outlinks; + for (const BriefNode *node : iter_nodes) { + if (!visited.count(node)) { + stack.push_back(FNode{node, false}); + } + } + } +} + +std::vector> SubgraphDetector::ExtractSubGraphs() { + // Run the Extract algorithm to find all subgraphs. + std::vector marked_nodes; + // We use brief_node_map to represent the original graph in order to avoid + // changing the original graph. + std::unordered_map brief_node_map; + + std::unordered_set valid_node_ids; + for (auto *node : graph_->Nodes()) { + valid_node_ids.insert(node->id()); + } + + for (auto &node : framework::ir::GraphTraits::TS(*graph_)) { + brief_node_map[node.id()] = new BriefNode(&node); + if (Agent(&node).marked()) { + marked_nodes.push_back(&node); + } + } + + // extract sub-graphs in the marked node set, use Union Find algorithm. + node_map_t node_map; // id to ptr + for (auto *n : marked_nodes) { + // n's parent == n.id means it is the ancestor + Agent(n).set_union_find_parent(n->id()); + node_map[n->id()] = n; + } + + // create breif node map + for (auto &itr : brief_node_map) { + for (Node *node : itr.second->node->inputs) { + if (!valid_node_ids.count(node->id())) { + LOG(INFO) << "invalid node id " << node->id(); + continue; + } + itr.second->inlinks.push_back(brief_node_map.at(node->id())); + } + + for (Node *node : itr.second->node->outputs) { + if (!valid_node_ids.count(node->id())) { + LOG(INFO) << "invalid node id " << node->id(); + continue; + } + itr.second->outlinks.push_back(brief_node_map.at(node->id())); + } + } + + for (auto &itr : brief_node_map) { + BriefNode *brief_node = itr.second; + + if (!Agent(brief_node->node).marked()) { + VLOG(4) << brief_node->node->id() << " node not a trt candidate."; + continue; + } + + // Our algorithm must guarantee that: + // 1. The graph is always directed acyclic graph(DAG). + // 2. If there is a path in the subgraph from X to Y (X and Y are both + // nodes in the subgraph), then all paths from X to Y are in the + // subgraph. + // + // In order to achieve the above guarantee. + // For adjacent nodes src -> dst. + // 1. Get all dst input nodes except src. + // 2. Reverse DFS from those input nodes + // 3. If there is a path from input nodes to src, + // then the src and dst nodes can not be fused into one node, + // otherwise it can be done. + + while (true) { + std::unordered_set contract_nodes; + for (auto *out : brief_node->outlinks) { + // must be an trt candidate + if (!Agent(out->node).marked()) continue; + // get all dst input nodes except src. + std::vector source_nodes; + for (auto *n : out->inlinks) { + if (n != brief_node) { + source_nodes.push_back(n); + } + } + + // Reverse DFS from the source_nodes. + bool have_excess_path = false; + FlexibleDFS(source_nodes, true, nullptr, + [&have_excess_path, brief_node](const BriefNode *n) { + if (n == brief_node) { + have_excess_path = true; + return false; + } + return true; + }); + if (have_excess_path) continue; + contract_nodes.insert(out); + } + if (contract_nodes.empty()) break; + + for (auto dst_node : contract_nodes) { + UnionFindCombine(node_map, brief_node->node->id(), + dst_node->node->id()); + UnionContractedNodes(brief_node_map, brief_node->node->id(), + dst_node->node->id()); + } + } + } + + std::unordered_map> clusters; + for (auto *n : marked_nodes) { + if (n->IsOp()) { + clusters[UnionFindGetAncestor(node_map, Agent(n).union_find_parent())] + .push_back(n); + } + } + std::vector> result; + std::for_each(clusters.begin(), clusters.end(), + [&](const decltype(clusters)::value_type &it) { + result.push_back(it.second); + }); + + return result; +} + +void SubGraphFuser::operator()() { ReplaceNodesWithSubGraphs(); } + +void RemoveIntermediateOutputInSubgraph(const std::vector &subgraph, + Graph *graph, + std::vector *outputs) { + std::unordered_set subgraph_set(subgraph.begin(), subgraph.end()); + std::unordered_set valid_output; + + for (auto *output : *outputs) { + int num_used = 0; + for (auto *node : output->outputs) { + if (!subgraph_set.count(node)) ++num_used; + if (num_used > 0) valid_output.insert(output); + } + } + + // In use for ngraph subgraph pass for parallel executor, + // this will remove all nodes, bypass this and let ngraph + // subgraph pass to process outputs + if (FLAGS_use_ngraph && valid_output.size() == 0) return; + + outputs->assign(valid_output.begin(), valid_output.end()); +} + +void DetachDeletedNodes(framework::ir::Graph *graph) { + std::unordered_set nodes; + for (auto *node : graph->Nodes()) { + if (Agent(node).deleted()) { + node->inputs.clear(); + node->outputs.clear(); + } + } +} + +void SubGraphFuser::ReplaceNodesWithSubGraphs() { + auto subgraphs = SubgraphDetector(graph_, node_inside_subgraph_teller_)(); + for (auto &subgraph : subgraphs) { + if (subgraph.size() <= (size_t)min_subgraph_size_) continue; + std::unordered_set subgraph_uniq(subgraph.begin(), subgraph.end()); + // replace this sub-graph with the first node. Two steps: 1. Create a Block + // Node that contains this subgraph 2. Mark the nodes inside the sub-graph + // as deleted. 3. Replace the deleted node with the new Block Node. + framework::OpDesc empty_desc; + empty_desc.SetType(name_); + auto *block_node = graph_->CreateOpNode(&empty_desc); + Agent(block_node).set_subgraph({}); + auto io = ExtractInputAndOutputOfSubGraph(subgraph); + block_node->inputs = std::move(io.first); + block_node->outputs = std::move(io.second); + + RemoveIntermediateOutputInSubgraph(subgraph, graph_, &block_node->outputs); + + for (auto *node : subgraph) { + // TODO(Superjomn) need a unified mechanism to treat deleted node in each + // pass. + Agent(node).set_deleted(true); + Agent(block_node).subgraph()->push_back(node); + } + + // Change all the sub-graph's inputs and outputs corresponding inlink and + // outlink to this sub-graph node. + auto inlink_or_outlink_cleaner = [&](std::vector &nodes) { + for (auto *&n : nodes) { + if (subgraph_uniq.count(n)) { + n = block_node; + } + } + std::unordered_set uniq(nodes.begin(), nodes.end()); + nodes.assign(uniq.begin(), uniq.end()); + }; + for (auto *i : block_node->inputs) { + inlink_or_outlink_cleaner(i->outputs); + } + for (auto *&o : block_node->outputs) { + inlink_or_outlink_cleaner(o->inputs); + } + } + // DetachDeletedNodes(graph_); + FilterRedundantOutputOfSubGraph(graph_); +} + +inline bool CheckNodeIndegreeEquals(const Node &node, size_t n) { + return node.inputs.size() == n; +} + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h b/paddle/fluid/framework/ir/subgraph_detector.h similarity index 75% rename from paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h rename to paddle/fluid/framework/ir/subgraph_detector.h index 26201541f67e3bf8546bc38dbf6823a3dc05a3ee..d329198eeab704a75a586d231bc77bd6033e1ec4 100644 --- a/paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h +++ b/paddle/fluid/framework/ir/subgraph_detector.h @@ -1,160 +1,154 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -/* - * This file defines the the class to partition a graph. - */ - -#pragma once - -#include -#include -#include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/graph_traits.h" -#include "paddle/fluid/framework/ir/node.h" -#include "paddle/fluid/inference/analysis/argument.h" -#include "paddle/fluid/inference/analysis/helper.h" - -namespace paddle { -namespace inference { -namespace analysis { - -using framework::ir::Graph; -using framework::ir::NodesTSIterator; - -const char kIsFunctionNode[] = "__is_function_node__"; -const char kFunctionNodeSubGraph[] = "__function_node_sub_graph__"; -const char kSubgraphSplitterMarkerAttrName[] = - "_sub_graph_splitter_inside_sub_graph"; - -/* - * Detect the nodes in a sub-graph that meet some conditions. This class doesn't - * modify the graph. - */ -class SubgraphDetector { - public: - // Tell whether a node is inside a sub-graph. - using NodeInsideSubgraphTeller = - std::function; - - SubgraphDetector(Graph *graph, const NodeInsideSubgraphTeller &teller) - : graph_(graph), node_inside_subgraph_teller_(teller) {} - - std::vector> operator()(); - - protected: - // Mark the nodes inside the accepted sub-graph using - // node_inside_subgraph_teller. - void MarkNodesInsideSubGraph(); - - // Merge the marked nodes into sub-graphs and return the sub-graphs. - std::vector> ExtractSubGraphs(); - - private: - Graph *graph_; - NodeInsideSubgraphTeller node_inside_subgraph_teller_; -}; - -/* - * SubGraphFuser - Replace some nodes with the sub-graph node they are inside. - * To some extent, the TensorRT engine is just a fusion op for a model. - */ -class SubGraphFuser { - public: - using NodeInsideSubgraphTeller = SubgraphDetector::NodeInsideSubgraphTeller; - - SubGraphFuser(Graph *graph, const NodeInsideSubgraphTeller &teller, - int min_subgraph_size, std::string name = "anakin_engine") - : graph_(graph), - node_inside_subgraph_teller_(teller), - min_subgraph_size_{min_subgraph_size}, - name_{name} {} - - // The main method which run all the logic. - void operator()(); - - protected: - // Remove the nodes inside sub-graphs and replace with the SubGraphNode. - void ReplaceNodesWithSubGraphs(); - - private: - Graph *graph_; - NodeInsideSubgraphTeller node_inside_subgraph_teller_; - int min_subgraph_size_; - const std::string name_; -}; - -struct NodeWrapper { - bool deleted{false}; - bool marked{false}; - int union_find_parent{-1}; - std::vector subgraph; -}; - -/* - * ir::Node agent for subgraph detector. - */ -struct Agent { - explicit Agent(framework::ir::Node *x) : x_(x) {} - - NodeWrapper &wrapper() { - if (!x_->IsWrappedBy()) { - x_->WrappedBy(new NodeWrapper); - } - return x_->template Wrapper(); - } - - bool deleted() { return wrapper().deleted; } - void set_deleted(bool x) { wrapper().deleted = x; } - - bool marked() { return wrapper().marked; } - void set_marked(bool x) { wrapper().marked = x; } - - void set_subgraph(const std::vector &x) { - wrapper().subgraph = x; - } - - int union_find_parent() { return wrapper().union_find_parent; } - void set_union_find_parent(int v) { wrapper().union_find_parent = v; } - - std::vector *subgraph() { return &wrapper().subgraph; } - std::vector &inputs() { return x_->inputs; } - std::vector &outputs() { return x_->outputs; } - - private: - framework::ir::Node *x_; -}; - -// The nodes those have no input will be treated as start points. -static std::vector ExtractStartPoints(const Graph &g) { - std::vector result; - for (auto *node : g.Nodes()) { - if (node->inputs.empty()) { - result.push_back(node); - } - } - return result; -} - -static iterator_range TopologicalSort(const Graph &g) { - auto start_points = ExtractStartPoints(g); - PADDLE_ENFORCE(!start_points.empty()); - NodesTSIterator x(start_points); - return iterator_range(NodesTSIterator(start_points), - NodesTSIterator()); -} - -} // namespace analysis -} // namespace inference -} // namespace paddle +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_traits.h" +#include "paddle/fluid/framework/ir/node.h" + +namespace paddle { +namespace framework { +namespace ir { + +const char kIsFunctionNode[] = "__is_function_node__"; +const char kFunctionNodeSubGraph[] = "__function_node_sub_graph__"; +const char kSubgraphSplitterMarkerAttrName[] = + "_sub_graph_splitter_inside_sub_graph"; + +/* + * Detect the nodes in a sub-graph that meet some conditions. This class doesn't + * modify the graph. + */ +class SubgraphDetector { + public: + // Tell whether a node is inside a sub-graph. + using NodeInsideSubgraphTeller = std::function; + + SubgraphDetector(Graph *graph, const NodeInsideSubgraphTeller &teller) + : graph_(graph), node_inside_subgraph_teller_(teller) {} + + std::vector> operator()(); + + protected: + // Mark the nodes inside the accepted sub-graph using + // node_inside_subgraph_teller. + void MarkNodesInsideSubGraph(); + + // Merge the marked nodes into sub-graphs and return the sub-graphs. + std::vector> ExtractSubGraphs(); + + private: + Graph *graph_; + NodeInsideSubgraphTeller node_inside_subgraph_teller_; +}; + +/* + * SubGraphFuser - Replace some nodes with the sub-graph node they are inside. + * To some extent, the TensorRT engine is just a fusion op for a model. + */ +class SubGraphFuser { + public: + using NodeInsideSubgraphTeller = SubgraphDetector::NodeInsideSubgraphTeller; + + SubGraphFuser(Graph *graph, const NodeInsideSubgraphTeller &teller, + int min_subgraph_size, std::string name = "anakin_engine") + : graph_(graph), + node_inside_subgraph_teller_(teller), + min_subgraph_size_{min_subgraph_size}, + name_{name} {} + + // The main method which run all the logic. + void operator()(); + + protected: + // Remove the nodes inside sub-graphs and replace with the SubGraphNode. + void ReplaceNodesWithSubGraphs(); + + private: + Graph *graph_; + NodeInsideSubgraphTeller node_inside_subgraph_teller_; + int min_subgraph_size_; + const std::string name_; +}; + +struct NodeWrapper { + bool deleted{false}; + bool marked{false}; + int union_find_parent{-1}; + std::vector subgraph; +}; + +/* + * ir::Node agent for subgraph detector. + */ +struct Agent { + explicit Agent(Node *x) : x_(x) {} + + NodeWrapper &wrapper() { + if (!x_->IsWrappedBy()) { + x_->WrappedBy(new NodeWrapper); + } + return x_->template Wrapper(); + } + + bool deleted() { return wrapper().deleted; } + void set_deleted(bool x) { wrapper().deleted = x; } + + bool marked() { return wrapper().marked; } + void set_marked(bool x) { wrapper().marked = x; } + + void set_subgraph(const std::vector &x) { + wrapper().subgraph = x; + } + + int union_find_parent() { return wrapper().union_find_parent; } + void set_union_find_parent(int v) { wrapper().union_find_parent = v; } + + std::vector *subgraph() { return &wrapper().subgraph; } + std::vector &inputs() { return x_->inputs; } + std::vector &outputs() { return x_->outputs; } + + private: + Node *x_; +}; + +// The nodes those have no input will be treated as start points. +static std::vector ExtractStartPoints(const Graph &g) { + std::vector result; + for (auto *node : g.Nodes()) { + if (node->inputs.empty()) { + result.push_back(node); + } + } + return result; +} + +static iterator_range TopologicalSort(const Graph &g) { + auto start_points = ExtractStartPoints(g); + PADDLE_ENFORCE_GT( + start_points.size(), 0U, + platform::errors::InvalidArgument( + "Expected the number of graph's start points >= 1. Expected %d.", + start_points.size())); + NodesTSIterator x(start_points); + return iterator_range(NodesTSIterator(start_points), + NodesTSIterator()); +} + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index c8486f5151ca0f905c175f3d19c3bebf248dda71..a1f8ff478011c3b5baeb30cd9e3692c376bd7b80 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -24,7 +24,6 @@ #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/inference/analysis/argument.h" -#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h" #include "paddle/fluid/string/pretty_log.h" namespace paddle { diff --git a/paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt b/paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt index ddadbc6df4aa3f95b271b011edb85a8d7077796f..3a76bb27482db779e645650cbc52e0f6d61121ba 100644 --- a/paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/ir_passes/CMakeLists.txt @@ -1,13 +1,10 @@ -cc_library(subgraph_detector SRCS subgraph_detector.cc subgraph_util.cc DEPS proto_desc) -if(WITH_TESTING) - add_dependencies(subgraph_detector gtest) -endif() +cc_library(subgraph_util SRCS subgraph_util.cc DEPS subgraph_detector) if (WITH_GPU AND TENSORRT_FOUND) - cc_library(tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass.cc DEPS subgraph_detector tensorrt_op_teller) + cc_library(tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass.cc DEPS subgraph_util tensorrt_op_teller) set(analysis_deps ${analysis_deps} - subgraph_detector tensorrt_subgraph_pass + subgraph_util tensorrt_subgraph_pass CACHE INTERNAL "") set(pass_file ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h) @@ -16,10 +13,10 @@ if (WITH_GPU AND TENSORRT_FOUND) endif() if (ANAKIN_SUBGRAPH) - cc_library(anakin_subgraph_pass SRCS anakin_subgraph_pass.cc DEPS subgraph_detector anakin_op_teller) + cc_library(anakin_subgraph_pass SRCS anakin_subgraph_pass.cc DEPS subgraph_util anakin_op_teller) set(analysis_deps ${analysis_deps} - subgraph_detector anakin_subgraph_pass + subgraph_util anakin_subgraph_pass CACHE INTERNAL "") set(pass_file ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h) diff --git a/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc index a6c6f33cf779f6117d0dda9a9eca279bd846ac84..b27896f98f78774ec9a9caa5809351a31347eeaf 100644 --- a/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc @@ -22,11 +22,11 @@ #include #include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/subgraph_detector.h" #include "paddle/fluid/inference/anakin/convert/op_converter.h" #include "paddle/fluid/inference/anakin/op_teller.h" #include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.h" -#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h" #include "paddle/fluid/string/pretty_log.h" namespace paddle { @@ -50,7 +50,7 @@ void analysis::AnakinSubgraphPass::ApplyImpl( return anakin::OpTeller::Global().Tell(node->Op()->Type(), *node->Op()); }; - SubGraphFuser fuser(graph, teller, 6 /* min_subgraph_size */); + framework::ir::SubGraphFuser fuser(graph, teller, 6 /* min_subgraph_size */); fuser(); std::vector graph_param_names = @@ -61,17 +61,18 @@ void analysis::AnakinSubgraphPass::ApplyImpl( std::vector repetitive_params; for (auto *node : graph->Nodes()) { - if (node->IsOp() && !Agent(node).subgraph()->empty()) { + if (node->IsOp() && !framework::ir::Agent(node).subgraph()->empty()) { CreateAnakinOp(node, graph, graph_param_names, &repetitive_params); std::unordered_set nodes2remove( - Agent(node).subgraph()->begin(), Agent(node).subgraph()->end()); + framework::ir::Agent(node).subgraph()->begin(), + framework::ir::Agent(node).subgraph()->end()); framework::ir::GraphSafeRemoveNodes(graph, nodes2remove); } } std::unordered_set nodes2remove; for (auto *node : graph->Nodes()) { - if (node->IsOp() && Agent(node).deleted()) { + if (node->IsOp() && framework::ir::Agent(node).deleted()) { nodes2remove.insert(node); } } @@ -96,11 +97,11 @@ std::string GenerateAnakinEngineKey(const std::set &engine_inputs, } void AnakinSubgraphPass::CreateAnakinOp( - framework::ir::Node *node, Graph *graph, + framework::ir::Node *node, framework::ir::Graph *graph, const std::vector &graph_params, std::vector *repetitive_params) const { auto *op_desc = node->Op(); - auto &subgraph = *Agent(node).subgraph(); + auto &subgraph = *framework::ir::Agent(node).subgraph(); PADDLE_ENFORCE(!subgraph.empty()); framework::ProgramDesc *program_desc = @@ -164,7 +165,7 @@ void AnakinSubgraphPass::CreateAnakinOp( graph_var_map[node->Name()] = node; } } - auto &subgraph_nodes = *Agent(node).subgraph(); + auto &subgraph_nodes = *framework::ir::Agent(node).subgraph(); // The following procedure is used to rename all the intermediate // variables and the output variables of the subgraph. diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index a173c899db69c4ac97a833893a4dc625acad5e6f..38106141b6950ac04660066ab7ed0f612bcb7173 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -17,8 +17,8 @@ #include #include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/subgraph_detector.h" #include "paddle/fluid/inference/analysis/helper.h" -#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h" #include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/engine.h" @@ -40,9 +40,9 @@ void analysis::TensorRtSubgraphPass::ApplyImpl( return tensorrt::OpTeller::Global().Tell(node->Op()->Type(), *node->Op()); }; - SubGraphFuser fuser(graph, teller, - Get("min_subgraph_size") /*min subgraph size*/, - "tensorrt_engine"); + framework::ir::SubGraphFuser fuser( + graph, teller, Get("min_subgraph_size") /*min subgraph size*/, + "tensorrt_engine"); fuser(); std::vector graph_param_names = @@ -52,18 +52,19 @@ void analysis::TensorRtSubgraphPass::ApplyImpl( std::vector repetitive_params; for (auto *node : graph->Nodes()) { - if (node->IsOp() && !Agent(node).subgraph()->empty()) { + if (node->IsOp() && !framework::ir::Agent(node).subgraph()->empty()) { CreateTensorRTOp(node, graph, graph_param_names, &repetitive_params); std::unordered_set nodes2remove( - Agent(node).subgraph()->begin(), Agent(node).subgraph()->end()); + framework::ir::Agent(node).subgraph()->begin(), + framework::ir::Agent(node).subgraph()->end()); framework::ir::GraphSafeRemoveNodes(graph, nodes2remove); } } std::unordered_set nodes2remove; for (auto *node : graph->Nodes()) { - if (node->IsOp() && Agent(node).deleted()) { + if (node->IsOp() && framework::ir::Agent(node).deleted()) { nodes2remove.insert(node); } } @@ -88,11 +89,11 @@ std::string GenerateEngineKey(const std::set &engine_inputs, } void TensorRtSubgraphPass::CreateTensorRTOp( - framework::ir::Node *node, Graph *graph, + framework::ir::Node *node, framework::ir::Graph *graph, const std::vector &graph_params, std::vector *repetitive_params) const { auto *op_desc = node->Op(); - auto &subgraph = *Agent(node).subgraph(); + auto &subgraph = *framework::ir::Agent(node).subgraph(); PADDLE_ENFORCE(!subgraph.empty()); framework::ProgramDesc *program_desc = @@ -161,7 +162,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp( if (precision_mode == AnalysisConfig::Precision::kHalf) enable_fp16 = true; auto enable_int8 = Get("enable_int8"); auto use_calib_mode = Get("use_calib_mode"); - auto &subgraph_nodes = *Agent(node).subgraph(); + auto &subgraph_nodes = *framework::ir::Agent(node).subgraph(); // The following procedure is used to rename all the intermediate // variables and the output variables of the subgraph.