未验证 提交 61836c46 编写于 作者: Z zhupengyang 提交者: GitHub

[NPU] Support generating subgraphs with multiple inputs and outputs (#1828)

test=develop
上级 5425ddca
...@@ -37,7 +37,7 @@ namespace lite { ...@@ -37,7 +37,7 @@ namespace lite {
namespace mir { namespace mir {
namespace subgraph { namespace subgraph {
void GenerateNPUProgramPass::NPUSortHelper( void GenerateNPUProgramPass::SubgraphSortHelper(
Node* node, Node* node,
const std::unordered_set<Node*>& nodes_all, const std::unordered_set<Node*>& nodes_all,
std::unordered_set<const Node*>* visited_nodes, std::unordered_set<const Node*>* visited_nodes,
...@@ -46,7 +46,7 @@ void GenerateNPUProgramPass::NPUSortHelper( ...@@ -46,7 +46,7 @@ void GenerateNPUProgramPass::NPUSortHelper(
if (var_node->inlinks.empty()) continue; if (var_node->inlinks.empty()) continue;
auto* op_node = var_node->inlinks.front(); auto* op_node = var_node->inlinks.front();
if (nodes_all.count(op_node) && !visited_nodes->count(op_node)) { if (nodes_all.count(op_node) && !visited_nodes->count(op_node)) {
NPUSortHelper(op_node, nodes_all, visited_nodes, ret); SubgraphSortHelper(op_node, nodes_all, visited_nodes, ret);
} }
} }
ret->push_back(node); ret->push_back(node);
...@@ -55,40 +55,68 @@ void GenerateNPUProgramPass::NPUSortHelper( ...@@ -55,40 +55,68 @@ void GenerateNPUProgramPass::NPUSortHelper(
void GenerateNPUProgramPass::CvtOpNodes( void GenerateNPUProgramPass::CvtOpNodes(
const std::vector<Node*>& nodes2cvt, const std::vector<Node*>& nodes2cvt,
std::vector<std::string>* in_vars_name, lite::npu::bridge::node_map_type* cvted_vars) {
std::vector<std::string>* out_vars_name,
lite::npu::bridge::node_map_type* cvted_vars,
std::unordered_set<const Node*>* nodes2rm) {
const auto& bridges = lite::npu::bridge::Factory::Instance(); const auto& bridges = lite::npu::bridge::Factory::Instance();
const auto& cvtfunc_map = bridges.AllFunctions(); const auto& cvtfunc_map = bridges.AllFunctions();
// record all converted vars
// op node's inputs must be found in cvted_vars
for (auto& node : nodes2cvt) { for (auto& node : nodes2cvt) {
lite::npu::bridge::node_map_type node_inputs; lite::npu::bridge::node_map_type node_inputs;
auto& stmt = node->AsStmt(); auto& stmt = node->AsStmt();
for (auto& var_node : node->inlinks) { for (auto& var_node : node->inlinks) {
auto& arg = var_node->AsArg(); auto& arg = var_node->AsArg();
if (arg.is_weight) continue;
auto var_name = arg.name; auto var_name = arg.name;
if (!cvted_vars->count(var_name)) { if (!cvted_vars->count(var_name)) {
if (arg.is_weight) continue;
cvted_vars->insert(std::make_pair( cvted_vars->insert(std::make_pair(
var_name, var_name,
lite::npu::bridge::CvtNode(var_node, stmt.op()->scope()))); lite::npu::bridge::CvtNode(var_node, stmt.op()->scope())));
in_vars_name->push_back(var_name);
} }
node_inputs.insert(*cvted_vars->find(var_name)); node_inputs.insert(*cvted_vars->find(var_name));
} }
auto node_outputs = cvtfunc_map.at(stmt.op_type())(stmt.op(), node_inputs); auto node_outputs = cvtfunc_map.at(stmt.op_type())(stmt.op(), node_inputs);
cvted_vars->insert(node_outputs.begin(), node_outputs.end()); cvted_vars->insert(node_outputs.begin(), node_outputs.end());
nodes2rm->insert(node);
for (auto& var_node : node->outlinks) {
for (auto& next_op_node : var_node->outlinks) {
if (std::find(nodes2cvt.begin(), nodes2cvt.end(), next_op_node) ==
nodes2cvt.end()) {
out_vars_name->push_back(var_node->AsArg().name);
break;
} }
}
// Classifies the var nodes attached to the op nodes in `nodes2cvt` into
// subgraph inputs, subgraph outputs, and nodes to remove.
//
// nodes2cvt:      op nodes that make up the subgraph.
// cvted_vars:     var name -> converted node; read with .at(), so every
//                 boundary var is expected to already have an entry.
// nodes2rm:       receives var nodes classified as internal or unused, plus
//                 every op node in `nodes2cvt`.
// in_vars:        receives non-weight input vars whose producer op is not in
//                 `nodes2cvt`.
// out_vars:       receives output vars whose first consumer op is not in
//                 `nodes2cvt`.
// in_cvted_vars:  receives the `cvted_vars` entries for `in_vars`, by name.
// out_cvted_vars: receives the `cvted_vars` entries for `out_vars`, by name.
void GenerateNPUProgramPass::GetIOVars(
const std::vector<Node*>& nodes2cvt,
const lite::npu::bridge::node_map_type& cvted_vars,
std::unordered_set<const Node*>* nodes2rm,
std::vector<Node*>* in_vars,
std::vector<Node*>* out_vars,
lite::npu::bridge::node_map_type* in_cvted_vars,
lite::npu::bridge::node_map_type* out_cvted_vars) {
// Set copy of nodes2cvt so membership tests below are a single lookup.
std::unordered_set<Node*> op_nodes_all(nodes2cvt.begin(), nodes2cvt.end());
for (auto& op_node : nodes2cvt) {
for (auto& in_var : op_node->inlinks) {
// Weight vars are never treated as subgraph inputs.
if (in_var->AsArg().is_weight) continue;
// NOTE(review): front() is called without checking inlinks.empty();
// the sort helper earlier in this file guards the same access. A
// non-weight var with no producer op would hit this -- confirm that
// cannot happen.
auto* pre_op_node = in_var->inlinks.front();
// Produced inside the subgraph: internal var, scheduled for removal.
if (op_nodes_all.count(pre_op_node)) {
nodes2rm->insert(in_var);
continue;
}
// NOTE(review): a var consumed by several ops of the subgraph is pushed
// once per consumer, so in_vars may contain duplicates -- confirm callers
// tolerate that.
in_vars->push_back(in_var);
auto arg_name = in_var->AsArg().name;
in_cvted_vars->insert(std::make_pair(arg_name, cvted_vars.at(arg_name)));
} }
for (auto& out_var : op_node->outlinks) {
// An output var with no consumer is removed, not exported as an output.
if (out_var->outlinks.empty()) {
nodes2rm->insert(out_var);
continue;
}
// NOTE(review): only the first consumer is inspected. A var read both
// inside and outside the subgraph is classified by whichever consumer
// comes first in outlinks -- confirm this is intended.
auto* next_op_node = out_var->outlinks.front();
if (op_nodes_all.count(next_op_node)) {
nodes2rm->insert(out_var);
continue;
}
out_vars->push_back(out_var);
auto arg_name = out_var->AsArg().name;
out_cvted_vars->insert(std::make_pair(arg_name, cvted_vars.at(arg_name)));
} }
} }
// The converted op nodes themselves are always scheduled for removal.
nodes2rm->insert(nodes2cvt.begin(), nodes2cvt.end());
} }
void GenerateNPUProgramPass::GenNPUGraphOpNode( void GenerateNPUProgramPass::GenNPUGraphOpNode(
...@@ -100,23 +128,38 @@ void GenerateNPUProgramPass::GenNPUGraphOpNode( ...@@ -100,23 +128,38 @@ void GenerateNPUProgramPass::GenNPUGraphOpNode(
for (auto& node : nodes_all) { for (auto& node : nodes_all) {
if (!node->IsStmt()) continue; if (!node->IsStmt()) continue;
if (visited_nodes.count(node)) continue; if (visited_nodes.count(node)) continue;
NPUSortHelper(node, nodes_all, &visited_nodes, &ret); SubgraphSortHelper(node, nodes_all, &visited_nodes, &ret);
} }
std::vector<std::string> in_vars_name;
std::vector<std::string> out_vars_name;
lite::npu::bridge::node_map_type cvted_vars; lite::npu::bridge::node_map_type cvted_vars;
CvtOpNodes(ret, &cvted_vars);
std::unordered_set<const Node*> nodes2rm; std::unordered_set<const Node*> nodes2rm;
CvtOpNodes(ret, &in_vars_name, &out_vars_name, &cvted_vars, &nodes2rm); std::vector<Node*> in_vars;
// insert new graph op node std::vector<Node*> out_vars;
lite::npu::bridge::node_map_type in_cvted_vars;
lite::npu::bridge::node_map_type out_cvted_vars;
GetIOVars(ret,
cvted_vars,
&nodes2rm,
&in_vars,
&out_vars,
&in_cvted_vars,
&out_cvted_vars);
std::vector<std::string> in_vars_name;
std::vector<std::string> out_vars_name;
std::vector<ge::Operator> inputs; std::vector<ge::Operator> inputs;
std::vector<ge::Operator> outputs; std::vector<ge::Operator> outputs;
for (auto i : in_vars_name) { for (auto i : in_cvted_vars) {
inputs.push_back(*cvted_vars.at(i)); in_vars_name.push_back(i.first);
inputs.push_back(*i.second);
} }
for (auto i : out_vars_name) { for (auto i : out_cvted_vars) {
outputs.push_back(*cvted_vars.at(i)); out_vars_name.push_back(i.first);
outputs.push_back(*i.second);
} }
std::string model_name("hiai_npu_client_" + std::to_string(sub_id) + ".om"); std::string model_name("hiai_npu_client_" + std::to_string(sub_id) + ".om");
if (!npu::BuildNPUClient(inputs, outputs, model_name)) { if (!npu::BuildNPUClient(inputs, outputs, model_name)) {
LOG(FATAL) << "Build NPU failed subgraph " << sub_id; LOG(FATAL) << "Build NPU failed subgraph " << sub_id;
...@@ -125,27 +168,25 @@ void GenerateNPUProgramPass::GenNPUGraphOpNode( ...@@ -125,27 +168,25 @@ void GenerateNPUProgramPass::GenNPUGraphOpNode(
cpp::OpDesc op_desc; cpp::OpDesc op_desc;
op_desc.SetType("graph_op"); op_desc.SetType("graph_op");
std::vector<std::string> in_var_names;
op_desc.SetInput("Inputs", in_vars_name); op_desc.SetInput("Inputs", in_vars_name);
op_desc.SetOutput("Outputs", out_vars_name); op_desc.SetOutput("Outputs", out_vars_name);
op_desc.SetAttr("model_name", model_name); op_desc.SetAttr("model_name", model_name);
auto graph_op = LiteOpRegistry::Global().Create("graph_op"); auto graph_op = LiteOpRegistry::Global().Create("graph_op");
// TODO(zpy): support multi inputs op
auto start_op = ret.front()->AsStmt().op(); auto any_op = ret.front()->AsStmt().op();
auto* scope = start_op->scope(); auto* scope = any_op->scope();
graph_op->Attach(op_desc, scope); graph_op->Attach(op_desc, scope);
auto valid_places = start_op->valid_places(); auto valid_places = any_op->valid_places();
auto* new_op_node = graph->GraphCreateInstructNode(graph_op, valid_places); auto* new_op_node = graph->GraphCreateInstructNode(graph_op, valid_places);
for (auto& var_node : ret.front()->inlinks) { for (auto& in_var : in_vars) {
auto& arg = var_node->AsArg(); IR_NODE_LINK_TO(in_var, new_op_node);
if (arg.is_weight) continue;
IR_NODE_LINK_TO(var_node, new_op_node);
} }
for (auto& var_node : ret.back()->outlinks) { for (auto& out_var : out_vars) {
auto& arg = var_node->AsArg(); IR_OP_VAR_LINK(new_op_node, out_var);
if (arg.is_weight) continue;
IR_NODE_LINK_TO(var_node, new_op_node);
} }
// assign context // assign context
...@@ -159,8 +200,10 @@ void GenerateNPUProgramPass::GenNPUGraphOpNode( ...@@ -159,8 +200,10 @@ void GenerateNPUProgramPass::GenNPUGraphOpNode(
void GenerateNPUProgramPass::ConvertSubgraph( void GenerateNPUProgramPass::ConvertSubgraph(
const std::unique_ptr<SSAGraph>& graph, int sub_num) { const std::unique_ptr<SSAGraph>& graph, int sub_num) {
std::unordered_map<int, std::unordered_set<Node*>> nodes_all; std::unordered_map<int, std::unordered_set<Node*>> nodes_all;
int ops_num = 0;
for (auto& item : graph->StmtTopologicalOrder()) { for (auto& item : graph->StmtTopologicalOrder()) {
if (!item->IsStmt()) continue; if (!item->IsStmt()) continue;
ops_num++;
auto& stmt = item->AsStmt(); auto& stmt = item->AsStmt();
int sub_id = stmt.subgraph_id(); int sub_id = stmt.subgraph_id();
if (sub_id < 1) continue; if (sub_id < 1) continue;
...@@ -178,6 +221,7 @@ void GenerateNPUProgramPass::ConvertSubgraph( ...@@ -178,6 +221,7 @@ void GenerateNPUProgramPass::ConvertSubgraph(
void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) { void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
LOG(INFO) << "Before NPU Pass \n" << Visualize(graph.get()); LOG(INFO) << "Before NPU Pass \n" << Visualize(graph.get());
const auto& bridges = lite::npu::bridge::Factory::Instance(); const auto& bridges = lite::npu::bridge::Factory::Instance();
const auto& op_map = bridges.AllFunctions(); const auto& op_map = bridges.AllFunctions();
std::vector<std::string> supported_op_types; std::vector<std::string> supported_op_types;
...@@ -215,5 +259,3 @@ std::unique_ptr<RuntimeProgram> GenerateNPUProgramPass::GenProgram() { ...@@ -215,5 +259,3 @@ std::unique_ptr<RuntimeProgram> GenerateNPUProgramPass::GenProgram() {
REGISTER_MIR_PASS(generate_npu_program_pass, REGISTER_MIR_PASS(generate_npu_program_pass,
paddle::lite::mir::subgraph::GenerateNPUProgramPass); paddle::lite::mir::subgraph::GenerateNPUProgramPass);
// USE_LITE_OP(graph_op);
...@@ -38,21 +38,27 @@ class GenerateNPUProgramPass : public SubgraphProgramPass { ...@@ -38,21 +38,27 @@ class GenerateNPUProgramPass : public SubgraphProgramPass {
std::unique_ptr<RuntimeProgram> GenProgram(); std::unique_ptr<RuntimeProgram> GenProgram();
protected: protected:
void NPUSortHelper(Node* node, // sort nodes to operational sequence
void SubgraphSortHelper(Node* node,
const std::unordered_set<Node*>& nodes_all, const std::unordered_set<Node*>& nodes_all,
std::unordered_set<const Node*>* visited_nodes, std::unordered_set<const Node*>* visited_nodes,
std::vector<Node*>* ret); std::vector<Node*>* ret);
// nodes2cvt: op nodes to convert // nodes2cvt: op nodes to convert
// in_vars_name: graph op's inputs var name // cvted_vars: converted var nodes
// out_vars_name: graph op's outputs var name
// cvted_vars:
// nodes2rm: op nodes and var nodes that need to be removed // nodes2rm: op nodes and var nodes that need to be removed
void CvtOpNodes(const std::vector<Node*>& nodes2cvt, void CvtOpNodes(const std::vector<Node*>& nodes2cvt,
std::vector<std::string>* in_vars_name, lite::npu::bridge::node_map_type* cvted_vars);
std::vector<std::string>* out_vars_name,
lite::npu::bridge::node_map_type* cvted_vars, // collect input and output vars/cvted_vars;
std::unordered_set<const Node*>* nodes2rm); // collect all nodes to remove
void GetIOVars(const std::vector<Node*>& nodes2cvt,
const lite::npu::bridge::node_map_type& cvted_vars,
std::unordered_set<const Node*>* nodes2rm,
std::vector<Node*>* in_vars,
std::vector<Node*>* out_vars,
lite::npu::bridge::node_map_type* in_cvted_vars,
lite::npu::bridge::node_map_type* out_cvted_vars);
void GenNPUGraphOpNode(const std::unique_ptr<SSAGraph>& graph, void GenNPUGraphOpNode(const std::unique_ptr<SSAGraph>& graph,
int sub_id, int sub_id,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册