From 4623a9b9351e9bd8fbd5d90501813bb36c018191 Mon Sep 17 00:00:00 2001 From: zhupengyang <1165938320@qq.com> Date: Mon, 16 Mar 2020 09:25:05 +0800 Subject: [PATCH] split subgraph by config file (#3157) --- lite/core/mir/subgraph/subgraph_detector.cc | 78 ++++++++++++++++++- lite/core/mir/subgraph/subgraph_detector.h | 5 ++ .../mir/subgraph/subgraph_detector_test.cc | 2 +- lite/utils/env.h | 3 + 4 files changed, 86 insertions(+), 2 deletions(-) diff --git a/lite/core/mir/subgraph/subgraph_detector.cc b/lite/core/mir/subgraph/subgraph_detector.cc index 65fb11ff2c..994f346ced 100644 --- a/lite/core/mir/subgraph/subgraph_detector.cc +++ b/lite/core/mir/subgraph/subgraph_detector.cc @@ -22,6 +22,9 @@ #include "lite/core/mir/pass_registry.h" #include "lite/core/mir/pattern_matcher.h" #include "lite/operators/subgraph_op.h" +#include "lite/utils/env.h" +#include "lite/utils/io.h" +#include "lite/utils/string.h" namespace paddle { namespace lite { @@ -209,8 +212,81 @@ void SubgraphDetector::FlexibleDFS( } } +std::unordered_set SubgraphDetector::GetExcludedNodesFromConfigFile() { + // get exclude nodes from config file + std::unordered_set excluded_nodes; + std::string config_file_path = + GetStringFromEnv(SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE); + if (!IsFileExists(config_file_path)) { + return excluded_nodes; + } + std::vector lines = ReadLines(config_file_path); + + for (std::string line : lines) { + std::vector node_info = Split(line, ":"); + std::string op_type = node_info.at(0); + std::vector in_vars_name; + if (node_info.size() > 1) { + in_vars_name = Split(node_info.at(1), ","); + } + std::vector out_vars_name; + if (node_info.size() > 2) { + out_vars_name = Split(node_info.at(2), ","); + } + + for (auto &node : graph_->mutable_nodes()) { + if (node.IsArg()) continue; + auto stmt = node.stmt(); + if (op_type != stmt->op_type()) continue; + auto in_nodes = node.inlinks; + auto out_nodes = node.outlinks; + if (in_vars_name.size() > in_nodes.size() || + out_vars_name.size() > out_nodes.size()) { + continue; + } + + bool matched = true; + + for (auto in_var_name : in_vars_name) { + bool find_var = false; + for (auto *in_node : in_nodes) { + if (in_node->arg()->name == in_var_name) { + find_var = true; + break; + } + } + if (!find_var) { + matched = false; + break; + } + } + + for (auto out_var_name : out_vars_name) { + bool find_var = false; + for (auto *out_node : out_nodes) { + if (out_node->arg()->name == out_var_name) { + find_var = true; + break; + } + } + if (!find_var) { + matched = false; + break; + } + } + + if (matched) { + excluded_nodes.insert(&node); + } + } + } + + return excluded_nodes; +} + void SubgraphDetector::InitNodes(node_map_t *nodes) { // Initialize and mark the subgraph detector nodes based on teller. + std::unordered_set excluded_nodes = GetExcludedNodesFromConfigFile(); for (auto &it : *nodes) { for (auto &in_node : it.first->inlinks) { it.second->inlinks.push_back((*nodes)[in_node]); @@ -218,7 +294,7 @@ void SubgraphDetector::InitNodes(node_map_t *nodes) { for (auto &out_node : it.first->outlinks) { it.second->outlinks.push_back((*nodes)[out_node]); } - if (teller_(it.first)) { + if (teller_(it.first) && excluded_nodes.count(it.first) == 0) { it.second->marked = true; if (it.first->IsStmt()) { // If a function is inside the subgraph, mark all the output variables diff --git a/lite/core/mir/subgraph/subgraph_detector.h b/lite/core/mir/subgraph/subgraph_detector.h index b6873655e9..567f2446a2 100644 --- a/lite/core/mir/subgraph/subgraph_detector.h +++ b/lite/core/mir/subgraph/subgraph_detector.h @@ -63,6 +63,7 @@ class SubgraphDetector { node_dat_t* UnionFindAncestor(); void UnionFindCombine(node_dat_t* candidate); }; + SubgraphDetector(SSAGraph* graph, const SubgraphTeller& teller) : graph_(graph), teller_(teller) {} std::vector> operator()(); @@ -71,7 +72,11 @@ class SubgraphDetector { bool reverse, const std::function& enter, const std::function& leave); + + std::unordered_set GetExcludedNodesFromConfigFile(); + void InitNodes(node_map_t* nodes); + std::vector> ExtractSubgraphs(node_map_t* nodes); protected: diff --git a/lite/core/mir/subgraph/subgraph_detector_test.cc b/lite/core/mir/subgraph/subgraph_detector_test.cc index 3b0d7c5cd5..e96a080d57 100644 --- a/lite/core/mir/subgraph/subgraph_detector_test.cc +++ b/lite/core/mir/subgraph/subgraph_detector_test.cc @@ -220,8 +220,8 @@ TEST(Subgraph, detect_custom_model) { }; std::vector> subgraphs = mir::SubgraphDetector(graph.get(), teller)(); - ASSERT_EQ(subgraphs.size(), 1); mir::SubgraphVisualizer(graph.get(), subgraphs)(); + ASSERT_EQ(subgraphs.size(), 1); } } // namespace lite diff --git a/lite/utils/env.h b/lite/utils/env.h index 86af8c9e7e..3048c84b42 100644 --- a/lite/utils/env.h +++ b/lite/utils/env.h @@ -19,6 +19,9 @@ #include #include +#define SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE \ + "SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE" + namespace paddle { namespace lite { -- GitLab