diff --git a/lite/core/mir/mlu_postprocess_pass.cc b/lite/core/mir/mlu_postprocess_pass.cc
index c88d0ca8980096dcc0dbafb80f44990d5c3c6a3e..f61b8e1b253d044b25e90a0775bcd580028203d7 100644
--- a/lite/core/mir/mlu_postprocess_pass.cc
+++ b/lite/core/mir/mlu_postprocess_pass.cc
@@ -610,16 +610,19 @@ void MLUPostprocessPass::ModifyLayout(SSAGraph* graph) {
   }
 }
 
-std::string CheckInputAndInsert(cpp::BlockDesc* block_desc,
+std::string CheckInputAndInsert(Scope* scope,
+                                cpp::BlockDesc* block_desc,
                                 const std::string& input_name,
                                 const Type* tensor_type,
                                 const Type* subgraph_type) {
   auto cur_node = input_name;
   if (DataLayoutCompatible(*tensor_type, *subgraph_type)) {
     auto layout_op = block_desc->AddOp();
-    auto layout_arg_name = string_format("%s/layout", cur_node);
+    auto layout_arg_name = string_format("%s/layout", cur_node.c_str());
+    scope->Var(layout_arg_name);
+    VLOG(5) << "insert layout in subgraph, arg tensor name: " << layout_arg_name;
     layout_op->SetType("layout");
-    layout_op->SetInput("X", {cur_node});
+    layout_op->SetInput("Input", {cur_node});
     layout_op->SetOutput("Out", {layout_arg_name});
     cur_node = layout_arg_name;
   }
@@ -627,7 +630,9 @@ std::string CheckInputAndInsert(cpp::BlockDesc* block_desc,
   if (PrecisionCompatible(*tensor_type, *subgraph_type) &&
       tensor_type->precision() != PRECISION(kInt8)) {
     auto cast_op = block_desc->AddOp();
-    auto cast_arg_name = string_format("%s/cast", cur_node);
+    auto cast_arg_name = string_format("%s/cast", cur_node.c_str());
+    scope->Var(cast_arg_name);
+    VLOG(5) << "insert cast in subgraph, arg tensor name: " << cast_arg_name;
     cast_op->SetType("cast");
     cast_op->SetAttr("in_dtype", 4);   // FP32
     cast_op->SetAttr("out_dtype", 5);  // FP16
@@ -639,7 +644,8 @@ std::string CheckInputAndInsert(cpp::BlockDesc* block_desc,
   return cur_node;
 }
 
-std::string CheckOutputAndInsert(cpp::BlockDesc* block_desc,
+std::string CheckOutputAndInsert(Scope* scope,
+                                 cpp::BlockDesc* block_desc,
                                  const std::string& output_name,
                                  const Type* tensor_type,
                                  const Type* subgraph_type) {
@@ -655,16 +661,20 @@ std::string CheckOutputAndInsert(cpp::BlockDesc* block_desc,
   }
 
   if (DataLayoutCompatible(*tensor_type, *subgraph_type)) {
-    auto layout_arg_name = string_format("%s/layout", cur_node);
+    auto layout_arg_name = string_format("%s/layout", cur_node.c_str());
+    scope->Var(layout_arg_name);
+    VLOG(5) << "insert layout in subgraph, arg tensor name: " << layout_arg_name;
     layout_op = block_desc->AddOp();
     layout_op->SetType("layout");
-    layout_op->SetInput("X", {layout_arg_name});
+    layout_op->SetInput("Input", {layout_arg_name});
     layout_op->SetOutput("Out", {cur_node});
     cur_node = layout_arg_name;
   }
 
   if (cast_op) {
-    auto cast_arg_name = string_format("%s/cast", cur_node);
+    auto cast_arg_name = string_format("%s/cast", cur_node.c_str());
+    scope->Var(cast_arg_name);
+    VLOG(5) << "insert cast in subgraph, arg tensor name: " << cast_arg_name;
     cast_op->SetInput("X", {cast_arg_name});
     cast_op->SetOutput("Out", {cur_node});
     cur_node = cast_arg_name;
@@ -675,7 +685,7 @@ std::string CheckOutputAndInsert(cpp::BlockDesc* block_desc,
 
 // insert cast op on mlu, to avoid cast on cpu, invoke before first run
 void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
-                                        const Type* op_type) {
+                                        const Type* subgraph_type) {
   auto subgraph_op = subgraph_node->AsStmt().op();
   CHECK_EQ(subgraph_op->Type(), "subgraph");
   auto op = dynamic_cast<operators::SubgraphOp*>(subgraph_op.get());
@@ -700,7 +710,7 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
     if (!(input->AsArg().is_weight || input->AsArg().is_persist)) {
       i_names.emplace_back(input_name);
       node_replace[input_name] = CheckInputAndInsert(
-          new_block_desc, input_name, input->AsArg().type, op_type);
+          op->scope(), new_block_desc, input_name, input->AsArg().type, subgraph_type);
     }
   }
   for (auto& output : subgraph_node->outlinks) {
@@ -708,7 +718,7 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
     if (!(output->AsArg().is_weight || output->AsArg().is_persist)) {
       o_names.emplace_back(output_name);
       node_replace[output_name] = CheckOutputAndInsert(
-          block_desc, output_name, output->AsArg().type, op_type);
+          op->scope(), block_desc, output_name, output->AsArg().type, subgraph_type);
     }
   }
 
@@ -749,8 +759,6 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
     }
   }
   op->SetSubBlock(new_block_desc);
-  // set param to kernel
-  op->AttachKernel(op->GetKernel());
 }
 
 void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
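
Note on the `.c_str()` changes: they fix a real bug rather than a style nit. The added calls imply `string_format` is a printf-style varargs formatter, and passing a `std::string` object through C varargs to a `%s` specifier is undefined behavior. A minimal sketch of the failure mode, using a hypothetical stand-in for the helper (not lite's actual implementation):

#include <cstdarg>
#include <cstdio>
#include <string>

// Hypothetical stand-in for the string_format helper: a thin
// vsnprintf wrapper, which is what the `.c_str()` fix implies it is.
std::string string_format(const char* fmt, ...) {
  char buf[256];
  va_list args;
  va_start(args, fmt);
  vsnprintf(buf, sizeof(buf), fmt, args);
  va_end(args);
  return std::string(buf);
}

int main() {
  std::string cur_node = "conv2d_0.tmp_0";  // hypothetical tensor name
  // Undefined behavior: %s expects const char*, but a std::string
  // object would be pushed through the varargs. This is the pre-patch bug:
  // auto bad = string_format("%s/layout", cur_node);
  // Correct: hand the underlying C string to the formatter, as the patch does.
  auto good = string_format("%s/layout", cur_node.c_str());
  std::printf("%s\n", good.c_str());  // prints "conv2d_0.tmp_0/layout"
  return 0;
}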
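Similarly, the new `scope->Var(...)` calls register the intermediate tensors (`foo/layout`, `foo/cast`) that the inserted ops introduce: those names now appear in the block desc but would exist in no scope, so without registration before the first run the runtime could not resolve them. A simplified sketch of that get-or-create contract, with a hypothetical `Scope` (not lite's actual class):

#include <cassert>
#include <map>
#include <string>

struct Tensor { /* payload elided */ };

// Hypothetical, stripped-down scope: Var is get-or-create,
// FindVar only resolves names that were created beforehand.
class Scope {
 public:
  Tensor* Var(const std::string& name) { return &vars_[name]; }
  Tensor* FindVar(const std::string& name) {
    auto it = vars_.find(name);
    return it == vars_.end() ? nullptr : &it->second;
  }

 private:
  std::map<std::string, Tensor> vars_;
};

int main() {
  Scope scope;
  // Without registration, a runtime lookup of the new arg fails:
  assert(scope.FindVar("conv2d_0.tmp_0/layout") == nullptr);
  // What the patch does when it inserts the layout op:
  scope.Var("conv2d_0.tmp_0/layout");
  assert(scope.FindVar("conv2d_0.tmp_0/layout") != nullptr);
  return 0;
}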