提交 28aca4a8 作者: tensor-tang 提交者: GitHub

[NPU] enable npu program rollback (#1906)

test=develop
上级 f3035827
......@@ -127,7 +127,8 @@ std::string GenerateNPUProgramPass::BuildNPUGraph(
std::string model_name("hiai_npu_client_" + std::to_string(sub_id) + ".om");
if (!npu::BuildNPUClient(inputs, outputs, model_name)) {
LOG(FATAL) << "Build NPU failed subgraph " << sub_id;
LOG(WARNING) << "Build NPU failed subgraph " << sub_id;
throw std::runtime_error("Build NPU failed subgraph.");
}
LOG(INFO) << "[NPU] Build NPU Client success subgraph " << sub_id;
return model_name;
......@@ -188,7 +189,7 @@ void GenerateNPUProgramPass::InsertNewNode(
ContextScheduler::Global().NewContext(inst.picked_kernel().target()));
}
void GenerateNPUProgramPass::GenNPUGraphOpNode(
void GenerateNPUProgramPass::GenNPUSubgraph(
const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& op_nodes,
int sub_id) {
......@@ -199,9 +200,6 @@ void GenerateNPUProgramPass::GenNPUGraphOpNode(
FindInputOutputVars(
op_nodes, &in_data_vars, &in_wgt_vars, &out_data_vars, &out_unused_vars);
auto nodes2rm = GetNode2rm(
op_nodes, {in_data_vars, in_wgt_vars, out_data_vars, out_unused_vars});
auto model_name =
BuildNPUGraph(op_nodes, in_data_vars, out_data_vars, sub_id);
......@@ -215,33 +213,34 @@ void GenerateNPUProgramPass::GenNPUGraphOpNode(
out_data_vars,
out_unused_vars);
auto nodes2rm = GetNode2rm(
op_nodes, {in_data_vars, in_wgt_vars, out_data_vars, out_unused_vars});
GraphSafeRemoveNodes(graph.get(), nodes2rm);
}
void GenerateNPUProgramPass::ConvertSubgraph(
void GenerateNPUProgramPass::GenAllNPUSubgraph(
const std::unique_ptr<SSAGraph>& graph, int sub_num) {
std::unordered_map<int, std::unordered_set<Node*>> nodes_all;
int ops_num = 0;
std::unordered_map<int, std::unordered_set<Node*>> all_op_nodes;
for (auto& item : graph->StmtTopologicalOrder()) {
if (!item->IsStmt()) continue;
ops_num++;
auto& stmt = item->AsStmt();
int sub_id = stmt.subgraph_id();
if (sub_id < 1) continue;
if (nodes_all.count(sub_id) == 0) {
nodes_all[sub_id] = std::unordered_set<Node*>();
if (all_op_nodes.count(sub_id) == 0) {
all_op_nodes[sub_id] = std::unordered_set<Node*>();
}
nodes_all.at(sub_id).insert(item);
all_op_nodes.at(sub_id).insert(item);
}
for (int id = 1; id <= sub_num; ++id) {
LOG(INFO) << "Converting subgraph_id:" << id;
GenNPUGraphOpNode(graph, nodes_all.at(id), id);
GenNPUSubgraph(graph, all_op_nodes.at(id), id);
}
}
void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
LOG(INFO) << "Before NPU Pass \n" << Visualize(graph.get());
VLOG(3) << "Before NPU Pass \n" << Visualize(graph.get());
const auto& bridges = lite::npu::bridge::Factory::Instance();
const auto& op_map = bridges.AllFunctions();
......@@ -254,16 +253,14 @@ void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
try {
int num_subgraph = FuseSubgraph(graph, supported_op_types);
LOG(INFO) << "detected " << num_subgraph << " NPU subgraph";
InferOnce(graph);
ConvertSubgraph(graph, num_subgraph);
GenAllNPUSubgraph(graph, num_subgraph);
} catch (...) {
// exception = true;
LOG(WARNING) << "Build NPU graph failed";
throw std::runtime_error("Build NPU graph failed");
}
LOG(INFO) << "After NPU Pass \n" << Visualize(graph.get());
VLOG(3) << "After NPU Pass \n" << Visualize(graph.get());
for (auto& item : graph->StmtTopologicalOrder()) {
if (item->IsStmt()) {
auto& stmt = item->AsStmt();
......
......@@ -64,11 +64,11 @@ class GenerateNPUProgramPass : public SubgraphProgramPass {
std::unordered_set<Node*> out_data_vars,
std::unordered_set<Node*> out_unused_vars);
void GenNPUGraphOpNode(const std::unique_ptr<SSAGraph>& graph,
void GenNPUSubgraph(const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& nodes_all,
int sub_id);
void ConvertSubgraph(const std::unique_ptr<SSAGraph>& graph, int sub_num);
void GenAllNPUSubgraph(const std::unique_ptr<SSAGraph>& graph, int sub_num);
private:
std::vector<Instruction> insts_;
......
......@@ -118,18 +118,16 @@ class Optimizer {
auto pass = mir::PassManager::Global()
.LookUp<mir::subgraph::GenerateNPUProgramPass>(
"generate_npu_program_pass");
try {
pass->Apply(graph_);
auto program = pass->GenProgram();
if (program) {
CHECK(exec_scope_);
program->set_exec_scope(exec_scope_);
return program;
} else {
LOG(WARNING) << "Build NPU graph failed.";
} catch (...) {
LOG(WARNING) << "Build NPU graph failed";
}
}
#endif
auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
"generate_program_pass");
......
......@@ -56,6 +56,7 @@ bool BuildNPUClient(const void* om_model_data,
if (ret != hiai::AI_SUCCESS) {
LOG(WARNING) << "[NPU] Failed building NPU client " << name
<< ", ret: " << ret;
throw std::runtime_error("");
return false;
}
......@@ -71,6 +72,7 @@ bool BuildNPUClient(const void* om_model_data,
model_desc.push_back(desc);
if (client->Load(model_desc) != hiai::AI_SUCCESS) {
LOG(WARNING) << "[NPU] Model Load Failed: " << desc->GetName();
throw std::runtime_error("");
return false;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册