Commit c096096f authored by dingminghui, committed by jackzhang235

fix(mlu): fix error while LITE_MLU_CAST is on

Parent 9553fab2
@@ -610,16 +610,19 @@ void MLUPostprocessPass::ModifyLayout(SSAGraph* graph) {
   }
 }
 
-std::string CheckInputAndInsert(cpp::BlockDesc* block_desc,
+std::string CheckInputAndInsert(Scope* scope,
+                                cpp::BlockDesc* block_desc,
                                 const std::string& input_name,
                                 const Type* tensor_type,
                                 const Type* subgraph_type) {
   auto cur_node = input_name;
   if (DataLayoutCompatible(*tensor_type, *subgraph_type)) {
     auto layout_op = block_desc->AddOp<cpp::OpDesc>();
-    auto layout_arg_name = string_format("%s/layout", cur_node);
+    auto layout_arg_name = string_format("%s/layout", cur_node.c_str());
+    scope->Var(layout_arg_name);
+    VLOG(5) << "insert layout in subgraph, arg tensor name: " << layout_arg_name;
     layout_op->SetType("layout");
-    layout_op->SetInput("X", {cur_node});
+    layout_op->SetInput("Input", {cur_node});
     layout_op->SetOutput("Out", {layout_arg_name});
     cur_node = layout_arg_name;
   }
@@ -627,7 +630,9 @@ std::string CheckInputAndInsert(cpp::BlockDesc* block_desc,
   if (PrecisionCompatible(*tensor_type, *subgraph_type) &&
       tensor_type->precision() != PRECISION(kInt8)) {
     auto cast_op = block_desc->AddOp<cpp::OpDesc>();
-    auto cast_arg_name = string_format("%s/cast", cur_node);
+    auto cast_arg_name = string_format("%s/cast", cur_node.c_str());
+    scope->Var(cast_arg_name);
+    VLOG(5) << "insert cast in subgraph, arg tensor name: " << cast_arg_name;
     cast_op->SetType("cast");
     cast_op->SetAttr<int>("in_dtype", 4);   // FP32
     cast_op->SetAttr<int>("out_dtype", 5);  // FP16
@@ -639,7 +644,8 @@ std::string CheckInputAndInsert(cpp::BlockDesc* block_desc,
   return cur_node;
 }
 
-std::string CheckOutputAndInsert(cpp::BlockDesc* block_desc,
+std::string CheckOutputAndInsert(Scope* scope,
+                                 cpp::BlockDesc* block_desc,
                                  const std::string& output_name,
                                  const Type* tensor_type,
                                  const Type* subgraph_type) {
@@ -655,16 +661,20 @@ std::string CheckOutputAndInsert(cpp::BlockDesc* block_desc,
   }
 
   if (DataLayoutCompatible(*tensor_type, *subgraph_type)) {
-    auto layout_arg_name = string_format("%s/layout", cur_node);
+    auto layout_arg_name = string_format("%s/layout", cur_node.c_str());
+    scope->Var(layout_arg_name);
+    VLOG(5) << "insert layout in subgraph, arg tensor name: " << layout_arg_name;
     layout_op = block_desc->AddOp<cpp::OpDesc>();
     layout_op->SetType("layout");
-    layout_op->SetInput("X", {layout_arg_name});
+    layout_op->SetInput("Input", {layout_arg_name});
     layout_op->SetOutput("Out", {cur_node});
     cur_node = layout_arg_name;
   }
 
   if (cast_op) {
-    auto cast_arg_name = string_format("%s/cast", cur_node);
+    auto cast_arg_name = string_format("%s/cast", cur_node.c_str());
+    scope->Var(cast_arg_name);
+    VLOG(5) << "insert cast in subgraph, arg tensor name: " << cast_arg_name;
     cast_op->SetInput("X", {cast_arg_name});
     cast_op->SetOutput("Out", {cur_node});
     cur_node = cast_arg_name;
@@ -675,7 +685,7 @@ std::string CheckOutputAndInsert(cpp::BlockDesc* block_desc,
 
 // insert cast op on mlu, to avoid cast on cpu, invoke before first run
 void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
-                                        const Type* op_type) {
+                                        const Type* subgraph_type) {
   auto subgraph_op = subgraph_node->AsStmt().op();
   CHECK_EQ(subgraph_op->Type(), "subgraph");
   auto op = dynamic_cast<operators::SubgraphOp*>(subgraph_op.get());
@@ -700,7 +710,7 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
     if (!(input->AsArg().is_weight || input->AsArg().is_persist)) {
       i_names.emplace_back(input_name);
       node_replace[input_name] = CheckInputAndInsert(
-          new_block_desc, input_name, input->AsArg().type, op_type);
+          op->scope(), new_block_desc, input_name, input->AsArg().type, subgraph_type);
     }
   }
   for (auto& output : subgraph_node->outlinks) {
@@ -708,7 +718,7 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
     if (!(output->AsArg().is_weight || output->AsArg().is_persist)) {
      o_names.emplace_back(output_name);
      node_replace[output_name] = CheckOutputAndInsert(
-          block_desc, output_name, output->AsArg().type, op_type);
+          op->scope(), block_desc, output_name, output->AsArg().type, subgraph_type);
     }
   }
 
@@ -749,8 +759,6 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
     }
   }
   op->SetSubBlock(new_block_desc);
-  // set param to kernel
-  op->AttachKernel(op->GetKernel());
 }
 
 void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
...
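
Why the .c_str() changes above matter: string_format here is a printf-style varargs helper, and passing a std::string such as cur_node through a "%s" varargs slot is undefined behavior; the formatter needs a C string. The sketch below shows that pattern with an illustrative vsnprintf-based stand-in for string_format (an assumption, not Paddle-Lite's actual implementation).

#include <cstdarg>
#include <cstdio>
#include <iostream>
#include <string>
#include <vector>

// Illustrative printf-style helper built on vsnprintf; assumed to mirror the
// behavior of the string_format used in this patch, not copied from it.
std::string string_format(const char* fmt, ...) {
  va_list args;
  va_start(args, fmt);
  va_list args_copy;
  va_copy(args_copy, args);
  const int len = std::vsnprintf(nullptr, 0, fmt, args);  // measure
  va_end(args);
  std::vector<char> buf(static_cast<size_t>(len) + 1);
  std::vsnprintf(buf.data(), buf.size(), fmt, args_copy);  // format
  va_end(args_copy);
  return std::string(buf.data(), static_cast<size_t>(len));
}

int main() {
  const std::string cur_node = "conv1_out";
  // OK: a C string is what "%s" expects in the varargs list.
  std::cout << string_format("%s/layout", cur_node.c_str()) << "\n";
  // Passing cur_node itself (a std::string) through the varargs would be
  // undefined behavior, which is what this commit fixes.
  return 0;
}

The same reasoning applies to every string_format call touched by the commit: the argument tensor names ("<name>/layout", "<name>/cast") are built from std::string values, so each call now passes .c_str().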
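The other half of the fix is the new Scope* parameter: each intermediate tensor name the pass invents now also gets a variable created in the subgraph's scope via scope->Var(...), so the program built from the rewritten block can resolve those arguments at runtime. The sketch below illustrates that registration/lookup relationship with a hypothetical, greatly simplified Scope; Paddle-Lite's real Scope has the same Var/FindVar spirit but a richer interface, and the exact failure seen with LITE_MLU_CAST on is assumed here, not confirmed, to be a missing-variable lookup.

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical, greatly simplified stand-ins; Paddle-Lite's real Scope and
// tensors are richer, but the Var/FindVar relationship is the point here.
struct Tensor {
  std::string name;
};

class Scope {
 public:
  // Creates the variable if it does not exist yet (idempotent), in the same
  // spirit as the scope->Var(...) calls added by this commit.
  Tensor* Var(const std::string& name) {
    auto& slot = vars_[name];
    if (!slot) slot.reset(new Tensor{name});
    return slot.get();
  }
  Tensor* FindVar(const std::string& name) const {
    auto it = vars_.find(name);
    return it == vars_.end() ? nullptr : it->second.get();
  }

 private:
  std::unordered_map<std::string, std::unique_ptr<Tensor>> vars_;
};

int main() {
  Scope scope;
  const std::string cur_node = "conv1_out";
  const std::string cast_arg_name = cur_node + "/cast";

  // The pass adds a cast op whose output is cast_arg_name; registering the
  // variable up front is what lets a later lookup succeed.
  scope.Var(cast_arg_name);

  std::cout << (scope.FindVar(cast_arg_name) ? "found " : "missing ")
            << cast_arg_name << "\n";
  return 0;
}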