Commit c096096f authored by dingminghui, committed by jackzhang235

fix(mlu): fix error while LITE_MLU_CAST is on

Parent 9553fab2
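Why this fixes the LITE_MLU_CAST crash: `string_format` is a printf-style varargs helper, and the old code handed it a `std::string` (`cur_node`) directly. Passing a non-trivially-copyable class through C varargs is undefined behavior, so these insertions could fault as soon as the cast path was enabled; the commit passes `cur_node.c_str()` instead. A minimal standalone sketch of the failure mode and the fix (`my_string_format` is a hypothetical stand-in for Lite's `string_format`):

```cpp
#include <cstdarg>
#include <cstdio>
#include <string>

// Hypothetical stand-in for Paddle-Lite's printf-style string_format helper.
std::string my_string_format(const char* fmt, ...) {
  char buf[256];
  va_list args;
  va_start(args, fmt);
  vsnprintf(buf, sizeof(buf), fmt, args);
  va_end(args);
  return std::string(buf);
}

int main() {
  std::string cur_node = "conv2d_0.tmp_0";
  // Undefined behavior: a std::string cannot travel through C varargs.
  // auto bad = my_string_format("%s/layout", cur_node);
  // The fix: pass the underlying C string instead.
  auto good = my_string_format("%s/layout", cur_node.c_str());
  std::printf("%s\n", good.c_str());  // prints conv2d_0.tmp_0/layout
  return 0;
}
```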
@@ -610,16 +610,19 @@ void MLUPostprocessPass::ModifyLayout(SSAGraph* graph) {
   }
 }
 
-std::string CheckInputAndInsert(cpp::BlockDesc* block_desc,
+std::string CheckInputAndInsert(Scope* scope,
+                                cpp::BlockDesc* block_desc,
                                 const std::string& input_name,
                                 const Type* tensor_type,
                                 const Type* subgraph_type) {
   auto cur_node = input_name;
   if (DataLayoutCompatible(*tensor_type, *subgraph_type)) {
     auto layout_op = block_desc->AddOp<cpp::OpDesc>();
-    auto layout_arg_name = string_format("%s/layout", cur_node);
+    auto layout_arg_name = string_format("%s/layout", cur_node.c_str());
+    scope->Var(layout_arg_name);
+    VLOG(5) << "insert layout in subgraph, arg tensor name: " << layout_arg_name;
     layout_op->SetType("layout");
-    layout_op->SetInput("X", {cur_node});
+    layout_op->SetInput("Input", {cur_node});
     layout_op->SetOutput("Out", {layout_arg_name});
     cur_node = layout_arg_name;
   }
@@ -627,7 +630,9 @@ std::string CheckInputAndInsert(cpp::BlockDesc* block_desc,
   if (PrecisionCompatible(*tensor_type, *subgraph_type) &&
       tensor_type->precision() != PRECISION(kInt8)) {
     auto cast_op = block_desc->AddOp<cpp::OpDesc>();
-    auto cast_arg_name = string_format("%s/cast", cur_node);
+    auto cast_arg_name = string_format("%s/cast", cur_node.c_str());
+    scope->Var(cast_arg_name);
+    VLOG(5) << "insert cast in subgraph, arg tensor name: " << cast_arg_name;
     cast_op->SetType("cast");
     cast_op->SetAttr<int>("in_dtype", 4);   // FP32
     cast_op->SetAttr<int>("out_dtype", 5);  // FP16
@@ -639,7 +644,8 @@ std::string CheckInputAndInsert(cpp::BlockDesc* block_desc,
   return cur_node;
 }
 
-std::string CheckOutputAndInsert(cpp::BlockDesc* block_desc,
+std::string CheckOutputAndInsert(Scope* scope,
+                                 cpp::BlockDesc* block_desc,
                                  const std::string& output_name,
                                  const Type* tensor_type,
                                  const Type* subgraph_type) {
@@ -655,16 +661,20 @@ std::string CheckOutputAndInsert(cpp::BlockDesc* block_desc,
   }
   if (DataLayoutCompatible(*tensor_type, *subgraph_type)) {
-    auto layout_arg_name = string_format("%s/layout", cur_node);
+    auto layout_arg_name = string_format("%s/layout", cur_node.c_str());
+    scope->Var(layout_arg_name);
+    VLOG(5) << "insert layout in subgraph, arg tensor name: " << layout_arg_name;
     layout_op = block_desc->AddOp<cpp::OpDesc>();
     layout_op->SetType("layout");
-    layout_op->SetInput("X", {layout_arg_name});
+    layout_op->SetInput("Input", {layout_arg_name});
     layout_op->SetOutput("Out", {cur_node});
     cur_node = layout_arg_name;
   }
   if (cast_op) {
-    auto cast_arg_name = string_format("%s/cast", cur_node);
+    auto cast_arg_name = string_format("%s/cast", cur_node.c_str());
+    scope->Var(cast_arg_name);
+    VLOG(5) << "insert cast in subgraph, arg tensor name: " << cast_arg_name;
     cast_op->SetInput("X", {cast_arg_name});
     cast_op->SetOutput("Out", {cur_node});
     cur_node = cast_arg_name;
@@ -675,7 +685,7 @@ std::string CheckOutputAndInsert(cpp::BlockDesc* block_desc,
 // insert cast op on mlu, to avoid cast on cpu, invoke before first run
 void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
-                                        const Type* op_type) {
+                                        const Type* subgraph_type) {
   auto subgraph_op = subgraph_node->AsStmt().op();
   CHECK_EQ(subgraph_op->Type(), "subgraph");
   auto op = dynamic_cast<operators::SubgraphOp*>(subgraph_op.get());
@@ -700,7 +710,7 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
     if (!(input->AsArg().is_weight || input->AsArg().is_persist)) {
       i_names.emplace_back(input_name);
       node_replace[input_name] = CheckInputAndInsert(
-          new_block_desc, input_name, input->AsArg().type, op_type);
+          op->scope(), new_block_desc, input_name, input->AsArg().type, subgraph_type);
     }
   }
   for (auto& output : subgraph_node->outlinks) {
@@ -708,7 +718,7 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
     if (!(output->AsArg().is_weight || output->AsArg().is_persist)) {
       o_names.emplace_back(output_name);
       node_replace[output_name] = CheckOutputAndInsert(
-          block_desc, output_name, output->AsArg().type, op_type);
+          op->scope(), block_desc, output_name, output->AsArg().type, subgraph_type);
     }
   }
@@ -749,8 +759,6 @@ void MLUPostprocessPass::AdjustSubgraph(Node* subgraph_node,
     }
   }
   op->SetSubBlock(new_block_desc);
-  // set param to kernel
-  op->AttachKernel(op->GetKernel());
 }
 
 void MLUPostprocessPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
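The other half of the fix is the new `Scope* scope` parameter: each freshly generated argument tensor name is registered with `scope->Var(...)` before the first run, so the inserted layout/cast ops find their variables at execution time instead of referencing names that exist only in the block desc. (The diff also renames the layout op's input slot from "X" to "Input", which appears to match the argument name Lite's layout op declares, and drops the premature `AttachKernel` call.) A simplified sketch of the register-before-run pattern, using toy `Scope`/`Tensor` stand-ins rather than Lite's real classes:

```cpp
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

// Toy stand-ins for illustration; not Paddle-Lite's real Scope/Tensor.
struct Tensor { /* payload omitted */ };

class Scope {
 public:
  // Var(): create the variable if it does not exist yet -- what the
  // fix calls for every newly inserted layout/cast argument.
  Tensor* Var(const std::string& name) { return &vars_[name]; }

  // FindVar(): lookup only; fails for names that were never registered.
  Tensor* FindVar(const std::string& name) {
    auto it = vars_.find(name);
    if (it == vars_.end()) {
      throw std::runtime_error("variable not found: " + name);
    }
    return &it->second;
  }

 private:
  std::map<std::string, Tensor> vars_;
};

int main() {
  Scope scope;
  const std::string cast_arg_name = "conv2d_0.tmp_0/cast";
  scope.Var(cast_arg_name);      // register while building the subgraph
  scope.FindVar(cast_arg_name);  // succeeds later, at execution time
  std::cout << "lookup ok\n";
  return 0;
}
```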