From c096096f63b0fdceb4b8112f223b8c0250264cfc Mon Sep 17 00:00:00 2001
From: dingminghui
Date: Wed, 13 May 2020 15:44:50 +0800
Subject: [PATCH] fix(mlu): fix error while LITE_MLU_CAST is on

---
 lite/core/mir/mlu_postprocess_pass.cc | 34 +++++++++++++++++----------
 1 file changed, 21 insertions(+), 13 deletions(-)

diff --git a/lite/core/mir/mlu_postprocess_pass.cc b/lite/core/mir/mlu_postprocess_pass.cc
index c88d0ca898..f61b8e1b25 100644
--- a/lite/core/mir/mlu_postprocess_pass.cc
+++ b/lite/core/mir/mlu_postprocess_pass.cc
@@ -610,16 +610,19 @@ void MLUPostprocessPass::ModifyLayout(SSAGraph* graph) {
   }
 }
 
-std::string CheckInputAndInsert(cpp::BlockDesc* block_desc,
+std::string CheckInputAndInsert(Scope* scope,
+                                cpp::BlockDesc* block_desc,
                                 const std::string& input_name,
                                 const Type* tensor_type,
                                 const Type* subgraph_type) {
   auto cur_node = input_name;
   if (DataLayoutCompatible(*tensor_type, *subgraph_type)) {
     auto layout_op = block_desc->AddOp<cpp::OpDesc>();
-    auto layout_arg_name = string_format("%s/layout", cur_node);
+    auto layout_arg_name = string_format("%s/layout", cur_node.c_str());
+    scope->Var(layout_arg_name);
+    VLOG(5) << "insert layout in subgraph, arg tensor name: " << layout_arg_name;
     layout_op->SetType("layout");
-    layout_op->SetInput("X", {cur_node});
+    layout_op->SetInput("Input", {cur_node});
     layout_op->SetOutput("Out", {layout_arg_name});
     cur_node = layout_arg_name;
   }
@@ -627,7 +630,9 @@ std::string CheckInputAndInsert(cpp::BlockDesc* block_desc,
   if (PrecisionCompatible(*tensor_type, *subgraph_type) &&
       tensor_type->precision() != PRECISION(kInt8)) {
     auto cast_op = block_desc->AddOp<cpp::OpDesc>();
-    auto cast_arg_name = string_format("%s/cast", cur_node);
+    auto cast_arg_name = string_format("%s/cast", cur_node.c_str());
+    scope->Var(cast_arg_name);
+    VLOG(5) << "insert cast in subgraph, arg tensor name: " << cast_arg_name;
     cast_op->SetType("cast");
     cast_op->SetAttr("in_dtype", 4);   // FP32
     cast_op->SetAttr("out_dtype", 5);  // FP16
@@ -639,7 +644,8 @@
   return cur_node;
 }
 
-std::string CheckOutputAndInsert(cpp::BlockDesc* block_desc,
+std::string CheckOutputAndInsert(Scope* scope,
+                                 cpp::BlockDesc* block_desc,
                                  const std::string& output_name,
                                  const Type* tensor_type,
                                  const Type* subgraph_type) {
@@ -655,16 +661,20 @@
   }
 
   if (DataLayoutCompatible(*tensor_type, *subgraph_type)) {
-    auto layout_arg_name = string_format("%s/layout", cur_node);
+    auto layout_arg_name = string_format("%s/layout", cur_node.c_str());
+    scope->Var(layout_arg_name);
+    VLOG(5) << "insert layout in subgraph, arg tensor name: " << layout_arg_name;
     layout_op = block_desc->AddOp<cpp::OpDesc>();
     layout_op->SetType("layout");
-    layout_op->SetInput("X", {layout_arg_name});
+    layout_op->SetInput("Input", {layout_arg_name});
     layout_op->SetOutput("Out", {cur_node});
     cur_node = layout_arg_name;
   }
 
   if (cast_op) {
-    auto cast_arg_name = string_format("%s/cast", cur_node);
+    auto cast_arg_name = string_format("%s/cast", cur_node.c_str());
+    scope->Var(cast_arg_name);
+    VLOG(5) << "insert cast in subgraph, arg tensor name: " << cast_arg_name;
     cast_op->SetInput("X", {cast_arg_name});
     cast_op->SetOutput("Out", {cur_node});
     cur_node = cast_arg_name;
@@ -675,7 +685,7 @@
   return cur_node;
 }
 
 // insert cast op on mlu, to avoid cast on cpu, invoke before first run
 void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
-                                        const Type* op_type) {
+                                        const Type* subgraph_type) {
   auto subgraph_op = subgraph_node->AsStmt().op();
   CHECK_EQ(subgraph_op->Type(), "subgraph");
   auto op = dynamic_cast<operators::SubgraphOp*>(subgraph_op.get());
@@ -700,7 +710,7 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
     if (!(input->AsArg().is_weight || input->AsArg().is_persist)) {
       i_names.emplace_back(input_name);
       node_replace[input_name] = CheckInputAndInsert(
-          new_block_desc, input_name, input->AsArg().type, op_type);
+          op->scope(), new_block_desc, input_name, input->AsArg().type, subgraph_type);
     }
   }
   for (auto& output : subgraph_node->outlinks) {
@@ -708,7 +718,7 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
     if (!(output->AsArg().is_weight || output->AsArg().is_persist)) {
       o_names.emplace_back(output_name);
       node_replace[output_name] = CheckOutputAndInsert(
-          block_desc, output_name, output->AsArg().type, op_type);
+          op->scope(), block_desc, output_name, output->AsArg().type, subgraph_type);
     }
   }
@@ -749,8 +759,6 @@
     }
   }
   op->SetSubBlock(new_block_desc);
-  // set param to kernel
-  op->AttachKernel(op->GetKernel());
 }
 
 void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
--
GitLab
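
Note on the fix (appended after the trailer, outside the diff): the patch does
three things. (1) It passes `cur_node.c_str()` instead of the `std::string`
object to string_format: the helper is printf-style and takes varargs, so
handing it a std::string where "%s" expects a `const char*` is undefined
behavior; with this code path enabled by LITE_MLU_CAST it typically yields
garbage argument names or a crash. (2) It registers each tensor it creates
with `scope->Var(...)` so the inserted layout/cast arguments can be resolved
at runtime rather than failing lookup as unregistered variables. (3) It
renames the layout op's input slot from "X" to "Input", presumably to match
the parameter name the layout kernel actually binds.

The sketch below is a minimal, self-contained illustration of pitfall (1).
The `string_format` stand-in and the tensor name are assumptions made for the
example, modeled on a generic printf-style formatter, not the library's exact
implementation:

    #include <cstdarg>
    #include <cstdio>
    #include <string>

    // Hypothetical stand-in for a printf-style string_format helper.
    static std::string string_format(const char* fmt, ...) {
      char buf[256];
      va_list args;
      va_start(args, fmt);
      vsnprintf(buf, sizeof(buf), fmt, args);  // renders "%s" from a C string
      va_end(args);
      return std::string(buf);
    }

    int main() {
      std::string cur_node = "conv2d_0.tmp_0";  // hypothetical tensor name
      // string_format("%s/layout", cur_node);  // UB: std::string through "..."
      auto ok = string_format("%s/layout", cur_node.c_str());  // the patch's fix
      std::printf("%s\n", ok.c_str());  // prints: conv2d_0.tmp_0/layout
      return 0;
    }

Most compilers flag the commented-out call with a -Wformat diagnostic
("format specifies type 'char *' but the argument has type 'std::string'"),
which is a cheap way to catch this class of bug before it ships.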