提交 4623a9b9 编写于 作者: Z zhupengyang 提交者: GitHub

split subgraph by config file (#3157)

上级 99d7121f
......@@ -22,6 +22,9 @@
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/pattern_matcher.h"
#include "lite/operators/subgraph_op.h"
#include "lite/utils/env.h"
#include "lite/utils/io.h"
#include "lite/utils/string.h"
namespace paddle {
namespace lite {
......@@ -209,8 +212,81 @@ void SubgraphDetector::FlexibleDFS(
}
}
std::unordered_set<Node *> SubgraphDetector::GetExcludedNodesFromConfigFile() {
// get exclude nodes from config file
std::unordered_set<Node *> excluded_nodes;
std::string config_file_path =
GetStringFromEnv(SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE);
if (!IsFileExists(config_file_path)) {
return excluded_nodes;
}
std::vector<std::string> lines = ReadLines(config_file_path);
for (std::string line : lines) {
std::vector<std::string> node_info = Split(line, ":");
std::string op_type = node_info.at(0);
std::vector<std::string> in_vars_name;
if (node_info.size() > 1) {
in_vars_name = Split(node_info.at(1), ",");
}
std::vector<std::string> out_vars_name;
if (node_info.size() > 2) {
out_vars_name = Split(node_info.at(2), ",");
}
for (auto &node : graph_->mutable_nodes()) {
if (node.IsArg()) continue;
auto stmt = node.stmt();
if (op_type != stmt->op_type()) continue;
auto in_nodes = node.inlinks;
auto out_nodes = node.outlinks;
if (in_vars_name.size() > in_nodes.size() ||
out_vars_name.size() > out_nodes.size()) {
continue;
}
bool matched = true;
for (auto in_var_name : in_vars_name) {
bool find_var = false;
for (auto *in_node : in_nodes) {
if (in_node->arg()->name == in_var_name) {
find_var = true;
break;
}
}
if (!find_var) {
matched = false;
break;
}
}
for (auto out_var_name : out_vars_name) {
bool find_var = false;
for (auto *out_node : out_nodes) {
if (out_node->arg()->name == out_var_name) {
find_var = true;
break;
}
}
if (!find_var) {
matched = false;
break;
}
}
if (matched) {
excluded_nodes.insert(&node);
}
}
}
return excluded_nodes;
}
void SubgraphDetector::InitNodes(node_map_t *nodes) {
// Initialize and mark the subgraph detector nodes based on teller.
std::unordered_set<Node *> excluded_nodes = GetExcludedNodesFromConfigFile();
for (auto &it : *nodes) {
for (auto &in_node : it.first->inlinks) {
it.second->inlinks.push_back((*nodes)[in_node]);
......@@ -218,7 +294,7 @@ void SubgraphDetector::InitNodes(node_map_t *nodes) {
for (auto &out_node : it.first->outlinks) {
it.second->outlinks.push_back((*nodes)[out_node]);
}
if (teller_(it.first)) {
if (teller_(it.first) && excluded_nodes.count(it.first) == 0) {
it.second->marked = true;
if (it.first->IsStmt()) {
// If a function is inside the subgraph, mark all the output variables
......
......@@ -63,6 +63,7 @@ class SubgraphDetector {
node_dat_t* UnionFindAncestor();
void UnionFindCombine(node_dat_t* candidate);
};
SubgraphDetector(SSAGraph* graph, const SubgraphTeller& teller)
: graph_(graph), teller_(teller) {}
std::vector<std::vector<Node*>> operator()();
......@@ -71,7 +72,11 @@ class SubgraphDetector {
bool reverse,
const std::function<bool(const node_dat_t*)>& enter,
const std::function<bool(const node_dat_t*)>& leave);
std::unordered_set<Node*> GetExcludedNodesFromConfigFile();
void InitNodes(node_map_t* nodes);
std::vector<std::vector<Node*>> ExtractSubgraphs(node_map_t* nodes);
protected:
......
......@@ -220,8 +220,8 @@ TEST(Subgraph, detect_custom_model) {
};
std::vector<std::vector<mir::Node*>> subgraphs =
mir::SubgraphDetector(graph.get(), teller)();
ASSERT_EQ(subgraphs.size(), 1);
mir::SubgraphVisualizer(graph.get(), subgraphs)();
ASSERT_EQ(subgraphs.size(), 1);
}
} // namespace lite
......
......@@ -19,6 +19,9 @@
#include <iostream>
#include <string>
#define SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE \
"SUBGRAPH_CUSTOM_PARTITION_CONFIG_FILE"
namespace paddle {
namespace lite {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册