提交 28aca4a8 编写于 作者: T tensor-tang 提交者: GitHub

[NPU] enable npu program rollback (#1906)

test=develop
上级 f3035827
...@@ -127,7 +127,8 @@ std::string GenerateNPUProgramPass::BuildNPUGraph( ...@@ -127,7 +127,8 @@ std::string GenerateNPUProgramPass::BuildNPUGraph(
std::string model_name("hiai_npu_client_" + std::to_string(sub_id) + ".om"); std::string model_name("hiai_npu_client_" + std::to_string(sub_id) + ".om");
if (!npu::BuildNPUClient(inputs, outputs, model_name)) { if (!npu::BuildNPUClient(inputs, outputs, model_name)) {
LOG(FATAL) << "Build NPU failed subgraph " << sub_id; LOG(WARNING) << "Build NPU failed subgraph " << sub_id;
throw std::runtime_error("Build NPU failed subgraph.");
} }
LOG(INFO) << "[NPU] Build NPU Client success subgraph " << sub_id; LOG(INFO) << "[NPU] Build NPU Client success subgraph " << sub_id;
return model_name; return model_name;
...@@ -188,7 +189,7 @@ void GenerateNPUProgramPass::InsertNewNode( ...@@ -188,7 +189,7 @@ void GenerateNPUProgramPass::InsertNewNode(
ContextScheduler::Global().NewContext(inst.picked_kernel().target())); ContextScheduler::Global().NewContext(inst.picked_kernel().target()));
} }
void GenerateNPUProgramPass::GenNPUGraphOpNode( void GenerateNPUProgramPass::GenNPUSubgraph(
const std::unique_ptr<SSAGraph>& graph, const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& op_nodes, const std::unordered_set<Node*>& op_nodes,
int sub_id) { int sub_id) {
...@@ -199,9 +200,6 @@ void GenerateNPUProgramPass::GenNPUGraphOpNode( ...@@ -199,9 +200,6 @@ void GenerateNPUProgramPass::GenNPUGraphOpNode(
FindInputOutputVars( FindInputOutputVars(
op_nodes, &in_data_vars, &in_wgt_vars, &out_data_vars, &out_unused_vars); op_nodes, &in_data_vars, &in_wgt_vars, &out_data_vars, &out_unused_vars);
auto nodes2rm = GetNode2rm(
op_nodes, {in_data_vars, in_wgt_vars, out_data_vars, out_unused_vars});
auto model_name = auto model_name =
BuildNPUGraph(op_nodes, in_data_vars, out_data_vars, sub_id); BuildNPUGraph(op_nodes, in_data_vars, out_data_vars, sub_id);
...@@ -215,33 +213,34 @@ void GenerateNPUProgramPass::GenNPUGraphOpNode( ...@@ -215,33 +213,34 @@ void GenerateNPUProgramPass::GenNPUGraphOpNode(
out_data_vars, out_data_vars,
out_unused_vars); out_unused_vars);
auto nodes2rm = GetNode2rm(
op_nodes, {in_data_vars, in_wgt_vars, out_data_vars, out_unused_vars});
GraphSafeRemoveNodes(graph.get(), nodes2rm); GraphSafeRemoveNodes(graph.get(), nodes2rm);
} }
void GenerateNPUProgramPass::ConvertSubgraph( void GenerateNPUProgramPass::GenAllNPUSubgraph(
const std::unique_ptr<SSAGraph>& graph, int sub_num) { const std::unique_ptr<SSAGraph>& graph, int sub_num) {
std::unordered_map<int, std::unordered_set<Node*>> nodes_all; std::unordered_map<int, std::unordered_set<Node*>> all_op_nodes;
int ops_num = 0;
for (auto& item : graph->StmtTopologicalOrder()) { for (auto& item : graph->StmtTopologicalOrder()) {
if (!item->IsStmt()) continue; if (!item->IsStmt()) continue;
ops_num++;
auto& stmt = item->AsStmt(); auto& stmt = item->AsStmt();
int sub_id = stmt.subgraph_id(); int sub_id = stmt.subgraph_id();
if (sub_id < 1) continue; if (sub_id < 1) continue;
if (nodes_all.count(sub_id) == 0) { if (all_op_nodes.count(sub_id) == 0) {
nodes_all[sub_id] = std::unordered_set<Node*>(); all_op_nodes[sub_id] = std::unordered_set<Node*>();
} }
nodes_all.at(sub_id).insert(item); all_op_nodes.at(sub_id).insert(item);
} }
for (int id = 1; id <= sub_num; ++id) { for (int id = 1; id <= sub_num; ++id) {
LOG(INFO) << "Converting subgraph_id:" << id; LOG(INFO) << "Converting subgraph_id:" << id;
GenNPUGraphOpNode(graph, nodes_all.at(id), id); GenNPUSubgraph(graph, all_op_nodes.at(id), id);
} }
} }
void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) { void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
LOG(INFO) << "Before NPU Pass \n" << Visualize(graph.get()); VLOG(3) << "Before NPU Pass \n" << Visualize(graph.get());
const auto& bridges = lite::npu::bridge::Factory::Instance(); const auto& bridges = lite::npu::bridge::Factory::Instance();
const auto& op_map = bridges.AllFunctions(); const auto& op_map = bridges.AllFunctions();
...@@ -254,16 +253,14 @@ void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -254,16 +253,14 @@ void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
try { try {
int num_subgraph = FuseSubgraph(graph, supported_op_types); int num_subgraph = FuseSubgraph(graph, supported_op_types);
LOG(INFO) << "detected " << num_subgraph << " NPU subgraph"; LOG(INFO) << "detected " << num_subgraph << " NPU subgraph";
InferOnce(graph); InferOnce(graph);
ConvertSubgraph(graph, num_subgraph); GenAllNPUSubgraph(graph, num_subgraph);
} catch (...) { } catch (...) {
// exception = true;
LOG(WARNING) << "Build NPU graph failed"; LOG(WARNING) << "Build NPU graph failed";
throw std::runtime_error("Build NPU graph failed");
} }
LOG(INFO) << "After NPU Pass \n" << Visualize(graph.get()); VLOG(3) << "After NPU Pass \n" << Visualize(graph.get());
for (auto& item : graph->StmtTopologicalOrder()) { for (auto& item : graph->StmtTopologicalOrder()) {
if (item->IsStmt()) { if (item->IsStmt()) {
auto& stmt = item->AsStmt(); auto& stmt = item->AsStmt();
......
...@@ -64,11 +64,11 @@ class GenerateNPUProgramPass : public SubgraphProgramPass { ...@@ -64,11 +64,11 @@ class GenerateNPUProgramPass : public SubgraphProgramPass {
std::unordered_set<Node*> out_data_vars, std::unordered_set<Node*> out_data_vars,
std::unordered_set<Node*> out_unused_vars); std::unordered_set<Node*> out_unused_vars);
void GenNPUGraphOpNode(const std::unique_ptr<SSAGraph>& graph, void GenNPUSubgraph(const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& nodes_all, const std::unordered_set<Node*>& nodes_all,
int sub_id); int sub_id);
void ConvertSubgraph(const std::unique_ptr<SSAGraph>& graph, int sub_num); void GenAllNPUSubgraph(const std::unique_ptr<SSAGraph>& graph, int sub_num);
private: private:
std::vector<Instruction> insts_; std::vector<Instruction> insts_;
......
...@@ -118,18 +118,16 @@ class Optimizer { ...@@ -118,18 +118,16 @@ class Optimizer {
auto pass = mir::PassManager::Global() auto pass = mir::PassManager::Global()
.LookUp<mir::subgraph::GenerateNPUProgramPass>( .LookUp<mir::subgraph::GenerateNPUProgramPass>(
"generate_npu_program_pass"); "generate_npu_program_pass");
try {
pass->Apply(graph_); pass->Apply(graph_);
auto program = pass->GenProgram(); auto program = pass->GenProgram();
if (program) {
CHECK(exec_scope_); CHECK(exec_scope_);
program->set_exec_scope(exec_scope_); program->set_exec_scope(exec_scope_);
return program; return program;
} else { } catch (...) {
LOG(WARNING) << "Build NPU graph failed."; LOG(WARNING) << "Build NPU graph failed";
} }
} }
#endif #endif
auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>( auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
"generate_program_pass"); "generate_program_pass");
......
...@@ -56,6 +56,7 @@ bool BuildNPUClient(const void* om_model_data, ...@@ -56,6 +56,7 @@ bool BuildNPUClient(const void* om_model_data,
if (ret != hiai::AI_SUCCESS) { if (ret != hiai::AI_SUCCESS) {
LOG(WARNING) << "[NPU] Failed building NPU client " << name LOG(WARNING) << "[NPU] Failed building NPU client " << name
<< ", ret: " << ret; << ", ret: " << ret;
throw std::runtime_error("");
return false; return false;
} }
...@@ -71,6 +72,7 @@ bool BuildNPUClient(const void* om_model_data, ...@@ -71,6 +72,7 @@ bool BuildNPUClient(const void* om_model_data,
model_desc.push_back(desc); model_desc.push_back(desc);
if (client->Load(model_desc) != hiai::AI_SUCCESS) { if (client->Load(model_desc) != hiai::AI_SUCCESS) {
LOG(WARNING) << "[NPU] Model Load Failed: " << desc->GetName(); LOG(WARNING) << "[NPU] Model Load Failed: " << desc->GetName();
throw std::runtime_error("");
return false; return false;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册