diff --git a/lite/core/mir/subgraph/generate_npu_program_pass.cc b/lite/core/mir/subgraph/generate_npu_program_pass.cc index bbf9a21854fdccc2a48f9bfe4f6e76affcfd6965..e7322087afae9098abe61d3769a7312737e7d7eb 100644 --- a/lite/core/mir/subgraph/generate_npu_program_pass.cc +++ b/lite/core/mir/subgraph/generate_npu_program_pass.cc @@ -134,61 +134,6 @@ std::string GenerateNPUProgramPass::BuildNPUGraph( return model_name; } -cpp::OpDesc GenerateNPUProgramPass::GenGraphOpDesc( - const std::string& model_name, - const std::vector& in_var_names, - const std::vector& out_var_names) { - cpp::OpDesc op_desc; - op_desc.SetType("graph_op"); - op_desc.SetInput("Inputs", in_var_names); - op_desc.SetOutput("Outputs", out_var_names); - op_desc.SetAttr("model_name", model_name); - return op_desc; -} - -void GenerateNPUProgramPass::InsertNewNode( - const std::unique_ptr& graph, - const std::string& model_name, - Scope* scope, - const std::vector& valid_places, - std::unordered_set in_data_vars, - std::unordered_set in_wgt_vars, - std::unordered_set out_data_vars, - std::unordered_set out_unused_vars) { - std::vector in_var_names; - std::vector out_var_names; - for (auto i : in_data_vars) { - in_var_names.push_back(i->AsArg().name); - } - for (auto i : out_data_vars) { - out_var_names.push_back(i->AsArg().name); - } - - auto op_desc = GenGraphOpDesc(model_name, in_var_names, out_var_names); - - auto graph_op = LiteOpRegistry::Global().Create("graph_op"); - graph_op->Attach(op_desc, scope); - auto* new_op_node = graph->GraphCreateInstructNode(graph_op, valid_places); - - for (auto& in_var : in_data_vars) { - IR_NODE_LINK_TO(in_var, new_op_node); - } - for (auto& in_var : in_wgt_vars) { - IR_NODE_LINK_TO(in_var, new_op_node); - } - for (auto& out_var : out_data_vars) { - IR_OP_VAR_LINK(new_op_node, out_var); - } - for (auto& out_var : out_unused_vars) { - IR_OP_VAR_LINK(new_op_node, out_var); - } - - // assign context - auto& inst = new_op_node->AsStmt(); - inst.picked_kernel().SetContext( - ContextScheduler::Global().NewContext(inst.picked_kernel().target())); -} - void GenerateNPUProgramPass::GenNPUSubgraph( const std::unique_ptr& graph, const std::unordered_set& op_nodes, @@ -219,29 +164,8 @@ void GenerateNPUProgramPass::GenNPUSubgraph( GraphSafeRemoveNodes(graph.get(), nodes2rm); } -void GenerateNPUProgramPass::GenAllNPUSubgraph( - const std::unique_ptr& graph, int sub_num) { - std::unordered_map> all_op_nodes; - for (auto& item : graph->StmtTopologicalOrder()) { - if (!item->IsStmt()) continue; - auto& stmt = item->AsStmt(); - int sub_id = stmt.subgraph_id(); - if (sub_id < 1) continue; - if (all_op_nodes.count(sub_id) == 0) { - all_op_nodes[sub_id] = std::unordered_set(); - } - all_op_nodes.at(sub_id).insert(item); - } - - for (int id = 1; id <= sub_num; ++id) { - LOG(INFO) << "Converting subgraph_id:" << id; - GenNPUSubgraph(graph, all_op_nodes.at(id), id); - } -} - void GenerateNPUProgramPass::Apply(const std::unique_ptr& graph) { - VLOG(3) << "Before NPU Pass \n" << Visualize(graph.get()); - + LOG(INFO) << "Before NPU Pass \n" << Visualize(graph.get()); const auto& bridges = lite::npu::bridge::Factory::Instance(); const auto& op_map = bridges.AllFunctions(); std::vector supported_op_types; @@ -252,15 +176,22 @@ void GenerateNPUProgramPass::Apply(const std::unique_ptr& graph) { try { int num_subgraph = FuseSubgraph(graph, supported_op_types); - LOG(INFO) << "detected " << num_subgraph << " NPU subgraph"; InferOnce(graph); - GenAllNPUSubgraph(graph, num_subgraph); + auto op_nodes_all = ClassifySubgraph(graph); + CHECK_EQ(op_nodes_all.size(), num_subgraph); + int id = 1; + for (auto& op_nodes : op_nodes_all) { + LOG(INFO) << "Converting subgraph_id:" << id; + GenNPUSubgraph(graph, op_nodes.second, id); + LOG(INFO) << "After NPU Pass Subgraph " << id << "\n" + << Visualize(graph.get()); + id++; + } } catch (...) { LOG(WARNING) << "Build NPU graph failed"; throw std::runtime_error("Build NPU graph failed"); } - VLOG(3) << "After NPU Pass \n" << Visualize(graph.get()); for (auto& item : graph->StmtTopologicalOrder()) { if (item->IsStmt()) { auto& stmt = item->AsStmt(); diff --git a/lite/core/mir/subgraph/generate_npu_program_pass.h b/lite/core/mir/subgraph/generate_npu_program_pass.h index 45c04b0bfe226389c208e72f6bbb6f0037786b6a..9e030287cb7d91a06bc930f9c1daefb06b3d6965 100644 --- a/lite/core/mir/subgraph/generate_npu_program_pass.h +++ b/lite/core/mir/subgraph/generate_npu_program_pass.h @@ -51,25 +51,10 @@ class GenerateNPUProgramPass : public SubgraphProgramPass { const std::unordered_set& out_data_vars, int sub_id); - cpp::OpDesc GenGraphOpDesc(const std::string& model_name, - const std::vector& in_var_names, - const std::vector& out_var_names); - - void InsertNewNode(const std::unique_ptr& graph, - const std::string& model_name, - Scope* scope, - const std::vector& valid_places, - std::unordered_set in_data_vars, - std::unordered_set in_wgt_vars, - std::unordered_set out_data_vars, - std::unordered_set out_unused_vars); - void GenNPUSubgraph(const std::unique_ptr& graph, - const std::unordered_set& nodes_all, + const std::unordered_set& op_nodes, int sub_id); - void GenAllNPUSubgraph(const std::unique_ptr& graph, int sub_num); - private: std::vector insts_; }; diff --git a/lite/core/mir/subgraph/subgraph_program_pass.cc b/lite/core/mir/subgraph/subgraph_program_pass.cc index 3947d3b5828414014801618350b15c623d706f26..dddcdad7efc0b518e8c6396b2724808186adc2c2 100644 --- a/lite/core/mir/subgraph/subgraph_program_pass.cc +++ b/lite/core/mir/subgraph/subgraph_program_pass.cc @@ -26,6 +26,77 @@ namespace lite { namespace mir { namespace subgraph { +std::unordered_map> +SubgraphProgramPass::ClassifySubgraph(const std::unique_ptr& graph) { + std::unordered_map> op_nodes; + for (auto& item : graph->StmtTopologicalOrder()) { + if (!item->IsStmt()) continue; + auto& stmt = item->AsStmt(); + int sub_id = stmt.subgraph_id(); + if (sub_id < 1) continue; + if (!op_nodes.count(sub_id)) { + op_nodes[sub_id] = std::unordered_set(); + } + op_nodes.at(sub_id).insert(item); + } + return op_nodes; +} + +cpp::OpDesc SubgraphProgramPass::GenGraphOpDesc( + const std::string& model_name, + const std::vector& in_var_names, + const std::vector& out_var_names) { + cpp::OpDesc op_desc; + op_desc.SetType("graph_op"); + op_desc.SetInput("Inputs", in_var_names); + op_desc.SetOutput("Outputs", out_var_names); + op_desc.SetAttr("model_name", model_name); + return op_desc; +} + +void SubgraphProgramPass::InsertNewNode( + const std::unique_ptr& graph, + const std::string& model_name, + Scope* scope, + const std::vector& valid_places, + std::unordered_set in_data_vars, + std::unordered_set in_wgt_vars, + std::unordered_set out_data_vars, + std::unordered_set out_unused_vars) { + std::vector in_var_names; + std::vector out_var_names; + for (auto i : in_data_vars) { + in_var_names.push_back(i->AsArg().name); + } + for (auto i : out_data_vars) { + out_var_names.push_back(i->AsArg().name); + } + + auto op_desc = GenGraphOpDesc(model_name, in_var_names, out_var_names); + + auto graph_op = LiteOpRegistry::Global().Create("graph_op"); + graph_op->Attach(op_desc, scope); + auto* new_op_node = graph->GraphCreateInstructNode(graph_op, valid_places); + + for (auto& in_var : in_data_vars) { + IR_NODE_LINK_TO(in_var, new_op_node); + } + for (auto& in_var : in_wgt_vars) { + IR_NODE_LINK_TO(in_var, new_op_node); + } + for (auto& out_var : out_data_vars) { + IR_OP_VAR_LINK(new_op_node, out_var); + } + for (auto& out_var : out_unused_vars) { + IR_OP_VAR_LINK(new_op_node, out_var); + } + + // assign context + auto& inst = new_op_node->AsStmt(); + inst.picked_kernel().SetContext( + ContextScheduler::Global().NewContext(inst.picked_kernel().target())); +} + void SubgraphProgramPass::SortHelper( Node* node, const std::unordered_set& nodes_all, @@ -170,7 +241,6 @@ void SubgraphProgramPass::ChangeAllOutConnectedID(Node* node, auto& stmt = node->AsStmt(); if (stmt.subgraph_id() == from_id) { stmt.SetSubgraphID(to_id); - nodes2rm_[to_id].insert(node); for (auto& i : node->outlinks) { ChangeAllOutConnectedID(i, to_id, from_id); } @@ -191,22 +261,12 @@ void SubgraphProgramPass::ChangeAllOutConnectedID(Node* node, if (!all_out_op_supported) { return; } - nodes2rm_[to_id].insert(node); for (auto& i : node->outlinks) { CHECK(i->IsStmt()); auto& stmt = i->AsStmt(); if (stmt.subgraph_id() == from_id) { stmt.SetSubgraphID(to_id); - nodes2rm_[to_id].insert(i); for (auto& o : i->outlinks) { - for (auto& j : o->outlinks) { - if (j->IsStmt()) { - auto& Nstmt = j->AsStmt(); - if (Nstmt.subgraph_id() < from_id) { - o_nodes_[to_id].insert(o); - } - } - } ChangeAllOutConnectedID(o, to_id, from_id); } } @@ -230,47 +290,11 @@ int SubgraphProgramPass::FuseSubgraphID( } } } - if (inputvar == 1) { - for (auto& i : item->outlinks) i_nodes_[sub_id].insert(i); - } } if (stmt.subgraph_id() != 0) continue; ChangeAllOutConnectedID(item, sub_id); sub_id++; } - for (auto& i : nodes2rm_) { - for (auto& item : i.second) { - if (item->IsStmt()) { - auto& stmt = item->AsStmt(); - LOG(INFO) << "nodes2rm_:" << stmt.op_type(); - } else if (item->IsArg()) { - auto& arg = item->AsArg(); - LOG(INFO) << "nodes2rm_:" << arg.name; - } - } - } - for (auto& i : i_nodes_) { - for (auto& item : i.second) { - if (item->IsStmt()) { - auto& stmt = item->AsStmt(); - LOG(INFO) << "i_nodes_: " << i.first << " " << stmt.op_type(); - } else if (item->IsArg()) { - auto& arg = item->AsArg(); - LOG(INFO) << "i_nodes_: " << i.first << " " << arg.name; - } - } - } - for (auto& i : o_nodes_) { - for (auto& item : i.second) { - if (item->IsStmt()) { - auto& stmt = item->AsStmt(); - LOG(INFO) << "o_nodes_:" << i.first << " " << stmt.op_type(); - } else if (item->IsArg()) { - auto& arg = item->AsArg(); - LOG(INFO) << "o_nodes_: " << i.first << " " << arg.name; - } - } - } return sub_id - 1; } @@ -278,12 +302,7 @@ int SubgraphProgramPass::FuseSubgraph( const std::unique_ptr& graph, const std::vector& supported_op_types) { InitSubgraphID(graph, supported_op_types); - nodes2rm_.clear(); - i_nodes_.clear(); - o_nodes_.clear(); - int num_subgraph = FuseSubgraphID(graph); - LOG(INFO) << "detected " << num_subgraph << " subgraph"; - return num_subgraph; + return FuseSubgraphID(graph); } } // namespace subgraph } // namespace mir diff --git a/lite/core/mir/subgraph/subgraph_program_pass.h b/lite/core/mir/subgraph/subgraph_program_pass.h index 5bf477544d623a28bfa8ce617cbc52deb6b3779f..51e9367539caa6f0868138235bc7b0907c189df5 100644 --- a/lite/core/mir/subgraph/subgraph_program_pass.h +++ b/lite/core/mir/subgraph/subgraph_program_pass.h @@ -55,6 +55,24 @@ class SubgraphProgramPass : public ProgramPass { void ChangeAllOutConnectedID(Node* node, int to_id, int from_id = 0); // Below function cloud be useful in child classes // + // classify node by subgraph id + std::unordered_map> ClassifySubgraph( + const std::unique_ptr& graph); + + // generate the graph op desc + cpp::OpDesc GenGraphOpDesc(const std::string& model_name, + const std::vector& in_var_names, + const std::vector& out_var_names); + + // insert a new graph op node + void InsertNewNode(const std::unique_ptr& graph, + const std::string& model_name, + Scope* scope, + const std::vector& valid_places, + std::unordered_set in_data_vars, + std::unordered_set in_wgt_vars, + std::unordered_set out_data_vars, + std::unordered_set out_unused_vars); // Sort and return the topology order of nodes set std::vector GetTopologicalOrder( @@ -79,18 +97,6 @@ class SubgraphProgramPass : public ProgramPass { const std::unordered_set& nodes_all, std::unordered_set* visited_nodes, std::vector* ret); - - // {1: {nodes2rm_in_subgraph1, ...}, - // 2: {nodes2rm_in_subgraph2, ...}} - // delete nodes - std::unordered_map> nodes2rm_; - // std::unordered_map> nodes2rm_; - // inputs nodes - std::unordered_map> i_nodes_; - // std::unordered_map> i_nodes_; - // outputs nodes - std::unordered_map> o_nodes_; - // std::unordered_map> o_nodes_; }; } // namespace subgraph diff --git a/lite/core/mir/subgraph/subgraph_program_pass_test.cc b/lite/core/mir/subgraph/subgraph_program_pass_test.cc index 3d8afc0c05e7a7378d357a98cf4699cb77ffa134..de4acec91d3eacd5f880d6495367a4826eb90cfa 100644 --- a/lite/core/mir/subgraph/subgraph_program_pass_test.cc +++ b/lite/core/mir/subgraph/subgraph_program_pass_test.cc @@ -29,9 +29,15 @@ DEFINE_string(model_dir, "", "model_dir"); namespace paddle { namespace lite { -TEST(SubgraphTest, mobilenetv2) { +TEST(SubgraphTest, models) { cpp::ProgramDesc program_desc; auto scope = std::make_shared(); + // LoadModelPb(FLAGS_model_dir, + // FLAGS_model_dir + "/model", + // FLAGS_model_dir + "/params", + // scope.get(), + // &program_desc, + // true); LoadModelPb(FLAGS_model_dir, "", "", scope.get(), &program_desc); std::vector valid_places({ Place{TARGET(kHost), PRECISION(kFloat)},