未验证 提交 d60b8d61 编写于 作者: T tensor-tang 提交者: GitHub

refine the npu graph and subgraph (#1959)

* fix attr and refine subgraph pass test=develop

* refine the npu pass functions

* fix test

test=develop
上级 3191ec5e
......@@ -134,61 +134,6 @@ std::string GenerateNPUProgramPass::BuildNPUGraph(
return model_name;
}
// Builds the description of a "graph_op" operator.
//   model_name    - stored in the op's "model_name" attribute.
//   in_var_names  - variable names bound to the "Inputs" slot.
//   out_var_names - variable names bound to the "Outputs" slot.
// Returns the filled-in descriptor by value.
cpp::OpDesc GenerateNPUProgramPass::GenGraphOpDesc(
    const std::string& model_name,
    const std::vector<std::string>& in_var_names,
    const std::vector<std::string>& out_var_names) {
  cpp::OpDesc graph_desc;
  graph_desc.SetType("graph_op");
  // Wire the input/output slots before attaching the attribute.
  graph_desc.SetInput("Inputs", in_var_names);
  graph_desc.SetOutput("Outputs", out_var_names);
  graph_desc.SetAttr("model_name", model_name);
  return graph_desc;
}
// Creates a "graph_op" instruction node in `graph` and links it to the
// surrounding variable nodes.
//   model_name      - forwarded to GenGraphOpDesc as the op's model name.
//   scope           - scope the new op is attached to.
//   valid_places    - places handed to GraphCreateInstructNode.
//   in_data_vars    - data inputs; their names become the op's "Inputs" and
//                     each node is linked into the new op node.
//   in_wgt_vars     - weight inputs; linked into the new op node only (their
//                     names are not listed in the op desc).
//   out_data_vars   - data outputs; their names become the op's "Outputs" and
//                     the new op node is linked to each of them.
//   out_unused_vars - outputs that are linked from the new op node but not
//                     listed in the op desc.
void GenerateNPUProgramPass::InsertNewNode(
    const std::unique_ptr<SSAGraph>& graph,
    const std::string& model_name,
    Scope* scope,
    const std::vector<Place>& valid_places,
    std::unordered_set<Node*> in_data_vars,
    std::unordered_set<Node*> in_wgt_vars,
    std::unordered_set<Node*> out_data_vars,
    std::unordered_set<Node*> out_unused_vars) {
  // Gather the argument names that go into the op description.
  std::vector<std::string> input_names;
  for (auto* var : in_data_vars) {
    input_names.push_back(var->AsArg().name);
  }
  std::vector<std::string> output_names;
  for (auto* var : out_data_vars) {
    output_names.push_back(var->AsArg().name);
  }

  // Create the op from its description and register it in the graph.
  auto graph_op = LiteOpRegistry::Global().Create("graph_op");
  graph_op->Attach(GenGraphOpDesc(model_name, input_names, output_names),
                   scope);
  auto* graph_op_node = graph->GraphCreateInstructNode(graph_op, valid_places);

  // Inputs (data first, then weights) feed the new node.
  // NOTE: the link helpers are macros, so the loop bodies stay braced.
  for (auto& src : in_data_vars) {
    IR_NODE_LINK_TO(src, graph_op_node);
  }
  for (auto& src : in_wgt_vars) {
    IR_NODE_LINK_TO(src, graph_op_node);
  }
  // The new node feeds the outputs (data first, then unused ones).
  for (auto& dst : out_data_vars) {
    IR_OP_VAR_LINK(graph_op_node, dst);
  }
  for (auto& dst : out_unused_vars) {
    IR_OP_VAR_LINK(graph_op_node, dst);
  }

  // Give the picked kernel a fresh context created for its own target.
  auto& new_stmt = graph_op_node->AsStmt();
  new_stmt.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
      new_stmt.picked_kernel().target()));
}
void GenerateNPUProgramPass::GenNPUSubgraph(
const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& op_nodes,
......@@ -219,29 +164,8 @@ void GenerateNPUProgramPass::GenNPUSubgraph(
GraphSafeRemoveNodes(graph.get(), nodes2rm);
}
void GenerateNPUProgramPass::GenAllNPUSubgraph(
const std::unique_ptr<SSAGraph>& graph, int sub_num) {
std::unordered_map<int, std::unordered_set<Node*>> all_op_nodes;
for (auto& item : graph->StmtTopologicalOrder()) {
if (!item->IsStmt()) continue;
auto& stmt = item->AsStmt();
int sub_id = stmt.subgraph_id();
if (sub_id < 1) continue;
if (all_op_nodes.count(sub_id) == 0) {
all_op_nodes[sub_id] = std::unordered_set<Node*>();
}
all_op_nodes.at(sub_id).insert(item);
}
for (int id = 1; id <= sub_num; ++id) {
LOG(INFO) << "Converting subgraph_id:" << id;
GenNPUSubgraph(graph, all_op_nodes.at(id), id);
}
}
void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
VLOG(3) << "Before NPU Pass \n" << Visualize(graph.get());
LOG(INFO) << "Before NPU Pass \n" << Visualize(graph.get());
const auto& bridges = lite::npu::bridge::Factory::Instance();
const auto& op_map = bridges.AllFunctions();
std::vector<std::string> supported_op_types;
......@@ -252,15 +176,22 @@ void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
try {
int num_subgraph = FuseSubgraph(graph, supported_op_types);
LOG(INFO) << "detected " << num_subgraph << " NPU subgraph";
InferOnce(graph);
GenAllNPUSubgraph(graph, num_subgraph);
auto op_nodes_all = ClassifySubgraph(graph);
CHECK_EQ(op_nodes_all.size(), num_subgraph);
int id = 1;
for (auto& op_nodes : op_nodes_all) {
LOG(INFO) << "Converting subgraph_id:" << id;
GenNPUSubgraph(graph, op_nodes.second, id);
LOG(INFO) << "After NPU Pass Subgraph " << id << "\n"
<< Visualize(graph.get());
id++;
}
} catch (...) {
LOG(WARNING) << "Build NPU graph failed";
throw std::runtime_error("Build NPU graph failed");
}
VLOG(3) << "After NPU Pass \n" << Visualize(graph.get());
for (auto& item : graph->StmtTopologicalOrder()) {
if (item->IsStmt()) {
auto& stmt = item->AsStmt();
......
......@@ -51,25 +51,10 @@ class GenerateNPUProgramPass : public SubgraphProgramPass {
const std::unordered_set<Node*>& out_data_vars,
int sub_id);
cpp::OpDesc GenGraphOpDesc(const std::string& model_name,
const std::vector<std::string>& in_var_names,
const std::vector<std::string>& out_var_names);
void InsertNewNode(const std::unique_ptr<SSAGraph>& graph,
const std::string& model_name,
Scope* scope,
const std::vector<Place>& valid_places,
std::unordered_set<Node*> in_data_vars,
std::unordered_set<Node*> in_wgt_vars,
std::unordered_set<Node*> out_data_vars,
std::unordered_set<Node*> out_unused_vars);
void GenNPUSubgraph(const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& nodes_all,
const std::unordered_set<Node*>& op_nodes,
int sub_id);
void GenAllNPUSubgraph(const std::unique_ptr<SSAGraph>& graph, int sub_num);
private:
std::vector<Instruction> insts_;
};
......
......@@ -26,6 +26,77 @@ namespace lite {
namespace mir {
namespace subgraph {
std::unordered_map<int, std::unordered_set<Node*>>
SubgraphProgramPass::ClassifySubgraph(const std::unique_ptr<SSAGraph>& graph) {
std::unordered_map<int, std::unordered_set<Node*>> op_nodes;
for (auto& item : graph->StmtTopologicalOrder()) {
if (!item->IsStmt()) continue;
auto& stmt = item->AsStmt();
int sub_id = stmt.subgraph_id();
if (sub_id < 1) continue;
if (!op_nodes.count(sub_id)) {
op_nodes[sub_id] = std::unordered_set<Node*>();
}
op_nodes.at(sub_id).insert(item);
}
return op_nodes;
}
// Generates the op description of a "graph_op".
//   model_name    - value of the "model_name" attribute.
//   in_var_names  - names placed in the "Inputs" slot.
//   out_var_names - names placed in the "Outputs" slot.
// Returns the populated descriptor by value.
cpp::OpDesc SubgraphProgramPass::GenGraphOpDesc(
    const std::string& model_name,
    const std::vector<std::string>& in_var_names,
    const std::vector<std::string>& out_var_names) {
  cpp::OpDesc graph_desc;
  graph_desc.SetType("graph_op");
  // Input/output slots first, then the attribute.
  graph_desc.SetInput("Inputs", in_var_names);
  graph_desc.SetOutput("Outputs", out_var_names);
  graph_desc.SetAttr("model_name", model_name);
  return graph_desc;
}
// Inserts a new "graph_op" instruction node into `graph` and connects it to
// the given variable nodes.
//   model_name      - passed to GenGraphOpDesc as the op's model name.
//   scope           - scope the created op is attached to.
//   valid_places    - places handed to GraphCreateInstructNode.
//   in_data_vars    - data inputs: named in the op's "Inputs" and linked into
//                     the new node.
//   in_wgt_vars     - weight inputs: linked into the new node, but not named
//                     in the op desc.
//   out_data_vars   - data outputs: named in the op's "Outputs" and linked
//                     from the new node.
//   out_unused_vars - outputs linked from the new node, but not named in the
//                     op desc.
void SubgraphProgramPass::InsertNewNode(
    const std::unique_ptr<SSAGraph>& graph,
    const std::string& model_name,
    Scope* scope,
    const std::vector<Place>& valid_places,
    std::unordered_set<Node*> in_data_vars,
    std::unordered_set<Node*> in_wgt_vars,
    std::unordered_set<Node*> out_data_vars,
    std::unordered_set<Node*> out_unused_vars) {
  // Collect the names that describe the op's data inputs and outputs.
  std::vector<std::string> input_names;
  for (auto* var : in_data_vars) {
    input_names.push_back(var->AsArg().name);
  }
  std::vector<std::string> output_names;
  for (auto* var : out_data_vars) {
    output_names.push_back(var->AsArg().name);
  }

  // Instantiate the op, attach its description, and add it to the graph.
  auto graph_op = LiteOpRegistry::Global().Create("graph_op");
  graph_op->Attach(GenGraphOpDesc(model_name, input_names, output_names),
                   scope);
  auto* graph_op_node = graph->GraphCreateInstructNode(graph_op, valid_places);

  // Link order: data inputs, weight inputs, data outputs, unused outputs.
  // NOTE: the link helpers are macros, so the loop bodies stay braced.
  for (auto& src : in_data_vars) {
    IR_NODE_LINK_TO(src, graph_op_node);
  }
  for (auto& src : in_wgt_vars) {
    IR_NODE_LINK_TO(src, graph_op_node);
  }
  for (auto& dst : out_data_vars) {
    IR_OP_VAR_LINK(graph_op_node, dst);
  }
  for (auto& dst : out_unused_vars) {
    IR_OP_VAR_LINK(graph_op_node, dst);
  }

  // Assign a context created for the picked kernel's own target.
  auto& new_stmt = graph_op_node->AsStmt();
  new_stmt.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
      new_stmt.picked_kernel().target()));
}
void SubgraphProgramPass::SortHelper(
Node* node,
const std::unordered_set<Node*>& nodes_all,
......@@ -170,7 +241,6 @@ void SubgraphProgramPass::ChangeAllOutConnectedID(Node* node,
auto& stmt = node->AsStmt();
if (stmt.subgraph_id() == from_id) {
stmt.SetSubgraphID(to_id);
nodes2rm_[to_id].insert(node);
for (auto& i : node->outlinks) {
ChangeAllOutConnectedID(i, to_id, from_id);
}
......@@ -191,22 +261,12 @@ void SubgraphProgramPass::ChangeAllOutConnectedID(Node* node,
if (!all_out_op_supported) {
return;
}
nodes2rm_[to_id].insert(node);
for (auto& i : node->outlinks) {
CHECK(i->IsStmt());
auto& stmt = i->AsStmt();
if (stmt.subgraph_id() == from_id) {
stmt.SetSubgraphID(to_id);
nodes2rm_[to_id].insert(i);
for (auto& o : i->outlinks) {
for (auto& j : o->outlinks) {
if (j->IsStmt()) {
auto& Nstmt = j->AsStmt();
if (Nstmt.subgraph_id() < from_id) {
o_nodes_[to_id].insert(o);
}
}
}
ChangeAllOutConnectedID(o, to_id, from_id);
}
}
......@@ -230,47 +290,11 @@ int SubgraphProgramPass::FuseSubgraphID(
}
}
}
if (inputvar == 1) {
for (auto& i : item->outlinks) i_nodes_[sub_id].insert(i);
}
}
if (stmt.subgraph_id() != 0) continue;
ChangeAllOutConnectedID(item, sub_id);
sub_id++;
}
for (auto& i : nodes2rm_) {
for (auto& item : i.second) {
if (item->IsStmt()) {
auto& stmt = item->AsStmt();
LOG(INFO) << "nodes2rm_:" << stmt.op_type();
} else if (item->IsArg()) {
auto& arg = item->AsArg();
LOG(INFO) << "nodes2rm_:" << arg.name;
}
}
}
for (auto& i : i_nodes_) {
for (auto& item : i.second) {
if (item->IsStmt()) {
auto& stmt = item->AsStmt();
LOG(INFO) << "i_nodes_: " << i.first << " " << stmt.op_type();
} else if (item->IsArg()) {
auto& arg = item->AsArg();
LOG(INFO) << "i_nodes_: " << i.first << " " << arg.name;
}
}
}
for (auto& i : o_nodes_) {
for (auto& item : i.second) {
if (item->IsStmt()) {
auto& stmt = item->AsStmt();
LOG(INFO) << "o_nodes_:" << i.first << " " << stmt.op_type();
} else if (item->IsArg()) {
auto& arg = item->AsArg();
LOG(INFO) << "o_nodes_: " << i.first << " " << arg.name;
}
}
}
return sub_id - 1;
}
......@@ -278,12 +302,7 @@ int SubgraphProgramPass::FuseSubgraph(
const std::unique_ptr<SSAGraph>& graph,
const std::vector<std::string>& supported_op_types) {
InitSubgraphID(graph, supported_op_types);
nodes2rm_.clear();
i_nodes_.clear();
o_nodes_.clear();
int num_subgraph = FuseSubgraphID(graph);
LOG(INFO) << "detected " << num_subgraph << " subgraph";
return num_subgraph;
return FuseSubgraphID(graph);
}
} // namespace subgraph
} // namespace mir
......
......@@ -55,6 +55,24 @@ class SubgraphProgramPass : public ProgramPass {
void ChangeAllOutConnectedID(Node* node, int to_id, int from_id = 0);
// The functions below could be useful in child classes //
// classify node by subgraph id
std::unordered_map<int, std::unordered_set<Node*>> ClassifySubgraph(
const std::unique_ptr<SSAGraph>& graph);
// generate the graph op desc
cpp::OpDesc GenGraphOpDesc(const std::string& model_name,
const std::vector<std::string>& in_var_names,
const std::vector<std::string>& out_var_names);
// insert a new graph op node
void InsertNewNode(const std::unique_ptr<SSAGraph>& graph,
const std::string& model_name,
Scope* scope,
const std::vector<Place>& valid_places,
std::unordered_set<Node*> in_data_vars,
std::unordered_set<Node*> in_wgt_vars,
std::unordered_set<Node*> out_data_vars,
std::unordered_set<Node*> out_unused_vars);
// Sort and return the topology order of nodes set
std::vector<Node*> GetTopologicalOrder(
......@@ -79,18 +97,6 @@ class SubgraphProgramPass : public ProgramPass {
const std::unordered_set<Node*>& nodes_all,
std::unordered_set<const Node*>* visited_nodes,
std::vector<Node*>* ret);
// {1: {nodes2rm_in_subgraph1, ...},
// 2: {nodes2rm_in_subgraph2, ...}}
// delete nodes
std::unordered_map<int, std::unordered_set<Node*>> nodes2rm_;
// std::unordered_map<int, std::unordered_set<const Node*>> nodes2rm_;
// inputs nodes
std::unordered_map<int, std::unordered_set<Node*>> i_nodes_;
// std::unordered_map<int, std::unordered_set<const Node*>> i_nodes_;
// outputs nodes
std::unordered_map<int, std::unordered_set<Node*>> o_nodes_;
// std::unordered_map<int, std::unordered_set<const Node*>> o_nodes_;
};
} // namespace subgraph
......
......@@ -29,9 +29,15 @@ DEFINE_string(model_dir, "", "model_dir");
namespace paddle {
namespace lite {
TEST(SubgraphTest, mobilenetv2) {
TEST(SubgraphTest, models) {
cpp::ProgramDesc program_desc;
auto scope = std::make_shared<Scope>();
// LoadModelPb(FLAGS_model_dir,
// FLAGS_model_dir + "/model",
// FLAGS_model_dir + "/params",
// scope.get(),
// &program_desc,
// true);
LoadModelPb(FLAGS_model_dir, "", "", scope.get(), &program_desc);
std::vector<Place> valid_places({
Place{TARGET(kHost), PRECISION(kFloat)},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册