提交 642e72b7 编写于 作者: Z zhupengyang 提交者: GitHub

[NPU] support generate multiple IO subgraph (#1828)

test=develop
上级 f7d498f3
......@@ -37,7 +37,7 @@ namespace lite {
namespace mir {
namespace subgraph {
void GenerateNPUProgramPass::NPUSortHelper(
void GenerateNPUProgramPass::SubgraphSortHelper(
Node* node,
const std::unordered_set<Node*>& nodes_all,
std::unordered_set<const Node*>* visited_nodes,
......@@ -46,7 +46,7 @@ void GenerateNPUProgramPass::NPUSortHelper(
if (var_node->inlinks.empty()) continue;
auto* op_node = var_node->inlinks.front();
if (nodes_all.count(op_node) && !visited_nodes->count(op_node)) {
NPUSortHelper(op_node, nodes_all, visited_nodes, ret);
SubgraphSortHelper(op_node, nodes_all, visited_nodes, ret);
}
}
ret->push_back(node);
......@@ -55,40 +55,68 @@ void GenerateNPUProgramPass::NPUSortHelper(
void GenerateNPUProgramPass::CvtOpNodes(
const std::vector<Node*>& nodes2cvt,
std::vector<std::string>* in_vars_name,
std::vector<std::string>* out_vars_name,
lite::npu::bridge::node_map_type* cvted_vars,
std::unordered_set<const Node*>* nodes2rm) {
lite::npu::bridge::node_map_type* cvted_vars) {
const auto& bridges = lite::npu::bridge::Factory::Instance();
const auto& cvtfunc_map = bridges.AllFunctions();
// record all converted vars
// op node's inputs must be found in cvted_vars
for (auto& node : nodes2cvt) {
lite::npu::bridge::node_map_type node_inputs;
auto& stmt = node->AsStmt();
for (auto& var_node : node->inlinks) {
auto& arg = var_node->AsArg();
if (arg.is_weight) continue;
auto var_name = arg.name;
if (!cvted_vars->count(var_name)) {
if (arg.is_weight) continue;
cvted_vars->insert(std::make_pair(
var_name,
lite::npu::bridge::CvtNode(var_node, stmt.op()->scope())));
in_vars_name->push_back(var_name);
}
node_inputs.insert(*cvted_vars->find(var_name));
}
auto node_outputs = cvtfunc_map.at(stmt.op_type())(stmt.op(), node_inputs);
cvted_vars->insert(node_outputs.begin(), node_outputs.end());
nodes2rm->insert(node);
for (auto& var_node : node->outlinks) {
for (auto& next_op_node : var_node->outlinks) {
if (std::find(nodes2cvt.begin(), nodes2cvt.end(), next_op_node) ==
nodes2cvt.end()) {
out_vars_name->push_back(var_node->AsArg().name);
break;
}
}
}
// Splits the var nodes touching a subgraph into boundary inputs, boundary
// outputs and internal vars, and collects every node that must be removed.
//
// nodes2cvt:      op nodes that make up the subgraph
// cvted_vars:     converted vars keyed by var name; every boundary var is
//                 looked up here with at()
// nodes2rm:       receives the internal var nodes and all nodes of nodes2cvt
// in_vars:        receives each non-weight input var that is not produced by
//                 an op of the subgraph (each var is appended once)
// out_vars:       receives each output var that is read by at least one op
//                 outside the subgraph
// in_cvted_vars:  receives the (name, converted var) entries of in_vars
// out_cvted_vars: receives the (name, converted var) entries of out_vars
void GenerateNPUProgramPass::GetIOVars(
    const std::vector<Node*>& nodes2cvt,
    const lite::npu::bridge::node_map_type& cvted_vars,
    std::unordered_set<const Node*>* nodes2rm,
    std::vector<Node*>* in_vars,
    std::vector<Node*>* out_vars,
    lite::npu::bridge::node_map_type* in_cvted_vars,
    lite::npu::bridge::node_map_type* out_cvted_vars) {
  std::unordered_set<Node*> op_nodes_all(nodes2cvt.begin(), nodes2cvt.end());

  // True when at least one consumer of var_node is not part of the subgraph.
  // Every consumer is inspected, not only the first one, so a var that is
  // read both inside and outside the subgraph is still seen as a boundary var.
  auto has_outside_consumer = [&op_nodes_all](Node* var_node) -> bool {
    for (auto& consumer : var_node->outlinks) {
      if (!op_nodes_all.count(consumer)) return true;
    }
    return false;
  };

  // Input vars already appended to in_vars; a var shared by several ops of
  // the subgraph must be reported (and later linked) only once.
  std::unordered_set<const Node*> seen_in_vars;

  for (auto& op_node : nodes2cvt) {
    for (auto& in_var : op_node->inlinks) {
      if (in_var->AsArg().is_weight) continue;
      // A var without a producer cannot be internal; calling front() on an
      // empty inlinks would be invalid, so check emptiness first.
      const bool produced_inside =
          !in_var->inlinks.empty() &&
          op_nodes_all.count(in_var->inlinks.front()) > 0;
      if (produced_inside) {
        // Only drop the var if nothing outside the subgraph still reads it;
        // otherwise it is a subgraph output and must stay in the graph.
        if (!has_outside_consumer(in_var)) {
          nodes2rm->insert(in_var);
        }
        continue;
      }
      if (!seen_in_vars.insert(in_var).second) continue;
      in_vars->push_back(in_var);
      const auto& arg_name = in_var->AsArg().name;
      in_cvted_vars->insert(std::make_pair(arg_name, cvted_vars.at(arg_name)));
    }
    for (auto& out_var : op_node->outlinks) {
      // Vars with no consumer at all, or consumed only inside the subgraph,
      // are internal and get removed together with the ops.
      if (!has_outside_consumer(out_var)) {
        nodes2rm->insert(out_var);
        continue;
      }
      out_vars->push_back(out_var);
      const auto& arg_name = out_var->AsArg().name;
      out_cvted_vars->insert(std::make_pair(arg_name, cvted_vars.at(arg_name)));
    }
  }
  // The op nodes themselves are always replaced by the generated graph op.
  nodes2rm->insert(nodes2cvt.begin(), nodes2cvt.end());
}
void GenerateNPUProgramPass::GenNPUGraphOpNode(
......@@ -100,23 +128,38 @@ void GenerateNPUProgramPass::GenNPUGraphOpNode(
for (auto& node : nodes_all) {
if (!node->IsStmt()) continue;
if (visited_nodes.count(node)) continue;
NPUSortHelper(node, nodes_all, &visited_nodes, &ret);
SubgraphSortHelper(node, nodes_all, &visited_nodes, &ret);
}
std::vector<std::string> in_vars_name;
std::vector<std::string> out_vars_name;
lite::npu::bridge::node_map_type cvted_vars;
CvtOpNodes(ret, &cvted_vars);
std::unordered_set<const Node*> nodes2rm;
CvtOpNodes(ret, &in_vars_name, &out_vars_name, &cvted_vars, &nodes2rm);
// insert new graph op node
std::vector<Node*> in_vars;
std::vector<Node*> out_vars;
lite::npu::bridge::node_map_type in_cvted_vars;
lite::npu::bridge::node_map_type out_cvted_vars;
GetIOVars(ret,
cvted_vars,
&nodes2rm,
&in_vars,
&out_vars,
&in_cvted_vars,
&out_cvted_vars);
std::vector<std::string> in_vars_name;
std::vector<std::string> out_vars_name;
std::vector<ge::Operator> inputs;
std::vector<ge::Operator> outputs;
for (auto i : in_vars_name) {
inputs.push_back(*cvted_vars.at(i));
for (auto i : in_cvted_vars) {
in_vars_name.push_back(i.first);
inputs.push_back(*i.second);
}
for (auto i : out_vars_name) {
outputs.push_back(*cvted_vars.at(i));
for (auto i : out_cvted_vars) {
out_vars_name.push_back(i.first);
outputs.push_back(*i.second);
}
std::string model_name("hiai_npu_client_" + std::to_string(sub_id) + ".om");
if (!npu::BuildNPUClient(inputs, outputs, model_name)) {
LOG(FATAL) << "Build NPU failed subgraph " << sub_id;
......@@ -125,27 +168,25 @@ void GenerateNPUProgramPass::GenNPUGraphOpNode(
cpp::OpDesc op_desc;
op_desc.SetType("graph_op");
std::vector<std::string> in_var_names;
op_desc.SetInput("Inputs", in_vars_name);
op_desc.SetOutput("Outputs", out_vars_name);
op_desc.SetAttr("model_name", model_name);
auto graph_op = LiteOpRegistry::Global().Create("graph_op");
// TODO(zpy): support multi inputs op
auto start_op = ret.front()->AsStmt().op();
auto* scope = start_op->scope();
auto any_op = ret.front()->AsStmt().op();
auto* scope = any_op->scope();
graph_op->Attach(op_desc, scope);
auto valid_places = start_op->valid_places();
auto valid_places = any_op->valid_places();
auto* new_op_node = graph->GraphCreateInstructNode(graph_op, valid_places);
for (auto& var_node : ret.front()->inlinks) {
auto& arg = var_node->AsArg();
if (arg.is_weight) continue;
IR_NODE_LINK_TO(var_node, new_op_node);
for (auto& in_var : in_vars) {
IR_NODE_LINK_TO(in_var, new_op_node);
}
for (auto& var_node : ret.back()->outlinks) {
auto& arg = var_node->AsArg();
if (arg.is_weight) continue;
IR_NODE_LINK_TO(var_node, new_op_node);
for (auto& out_var : out_vars) {
IR_OP_VAR_LINK(new_op_node, out_var);
}
// assign context
......@@ -159,8 +200,10 @@ void GenerateNPUProgramPass::GenNPUGraphOpNode(
void GenerateNPUProgramPass::ConvertSubgraph(
const std::unique_ptr<SSAGraph>& graph, int sub_num) {
std::unordered_map<int, std::unordered_set<Node*>> nodes_all;
int ops_num = 0;
for (auto& item : graph->StmtTopologicalOrder()) {
if (!item->IsStmt()) continue;
ops_num++;
auto& stmt = item->AsStmt();
int sub_id = stmt.subgraph_id();
if (sub_id < 1) continue;
......@@ -178,6 +221,7 @@ void GenerateNPUProgramPass::ConvertSubgraph(
void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
LOG(INFO) << "Before NPU Pass \n" << Visualize(graph.get());
const auto& bridges = lite::npu::bridge::Factory::Instance();
const auto& op_map = bridges.AllFunctions();
std::vector<std::string> supported_op_types;
......@@ -215,5 +259,3 @@ std::unique_ptr<RuntimeProgram> GenerateNPUProgramPass::GenProgram() {
REGISTER_MIR_PASS(generate_npu_program_pass,
paddle::lite::mir::subgraph::GenerateNPUProgramPass);
// USE_LITE_OP(graph_op);
......@@ -38,21 +38,27 @@ class GenerateNPUProgramPass : public SubgraphProgramPass {
std::unique_ptr<RuntimeProgram> GenProgram();
protected:
void NPUSortHelper(Node* node,
const std::unordered_set<Node*>& nodes_all,
std::unordered_set<const Node*>* visited_nodes,
std::vector<Node*>* ret);
// sort nodes to operational sequence
void SubgraphSortHelper(Node* node,
const std::unordered_set<Node*>& nodes_all,
std::unordered_set<const Node*>* visited_nodes,
std::vector<Node*>* ret);
// nodes2cvt: op nodes to convert
// in_vars_name: graph op's inputs var name
// out_vars_name: graph op's outputs var name
// vcted_vars:
// cvted_vars: converted var nodes
// nodes2rm: op nodes and var nodes that need to be removed
void CvtOpNodes(const std::vector<Node*>& nodes2cvt,
std::vector<std::string>* in_vars_name,
std::vector<std::string>* out_vars_name,
lite::npu::bridge::node_map_type* cvted_vars,
std::unordered_set<const Node*>* nodes2rm);
lite::npu::bridge::node_map_type* cvted_vars);
// achieve input and output vars/cvted_vars;
// achieve all nodes to remove
void GetIOVars(const std::vector<Node*>& nodes2cvt,
const lite::npu::bridge::node_map_type& cvted_vars,
std::unordered_set<const Node*>* nodes2rm,
std::vector<Node*>* in_vars,
std::vector<Node*>* out_vars,
lite::npu::bridge::node_map_type* in_cvted_vars,
lite::npu::bridge::node_map_type* out_cvted_vars);
void GenNPUGraphOpNode(const std::unique_ptr<SSAGraph>& graph,
int sub_id,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册