未验证 提交 1f01c4fd 编写于 作者: Z zhaoying9105 提交者: GitHub

(bugfix): emplace_back may reallocate, make cast_op ptr detached (#100)

* (bugfix): emplace_back may reallocate, make cast_op ptr detached

* fix: wrong op sequence will cause node in graph not added
Co-authored-by: N--get <zhaoying@cambricon.com>
Co-authored-by: Ndingminghui <dingminghui@cambricon.com>
上级 5bd0dfbb
...@@ -580,12 +580,12 @@ void MLUPostprocessPass::ModifyInputOutputDataType(SSAGraph* graph) { ...@@ -580,12 +580,12 @@ void MLUPostprocessPass::ModifyInputOutputDataType(SSAGraph* graph) {
out_arg.type = LiteType::GetTensorTy(TARGET(kHost), out_arg.type = LiteType::GetTensorTy(TARGET(kHost),
subgraph_arg_type->precision(), subgraph_arg_type->precision(),
DATALAYOUT(kNHWC)); DATALAYOUT(kNHWC));
VLOG(5) << "unused output node type: " << out_arg.name VLOG(4) << "unused output node type: " << out_arg.name
<< out_node_type->name(); << out_node_type->name();
} else { } else {
out_arg.type = LiteType::GetTensorTy( out_arg.type = LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)); TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
VLOG(5) << "output node type: " << out_arg.name VLOG(4) << "output node type: " << out_arg.name
<< out_node_type->name(); << out_node_type->name();
} }
} }
...@@ -665,7 +665,7 @@ std::pair<bool, std::string> CheckInputAndInsert(Scope* scope, ...@@ -665,7 +665,7 @@ std::pair<bool, std::string> CheckInputAndInsert(Scope* scope,
auto layout_op = block_desc->AddOp<cpp::OpDesc>(); auto layout_op = block_desc->AddOp<cpp::OpDesc>();
auto layout_arg_name = string_format("%s/layout", cur_node.c_str()); auto layout_arg_name = string_format("%s/layout", cur_node.c_str());
scope->Var(layout_arg_name); scope->Var(layout_arg_name);
VLOG(5) << "insert layout for subgraph input, arg tensor name: " VLOG(4) << "insert layout for subgraph input, arg tensor name: "
<< layout_arg_name; << layout_arg_name;
layout_op->SetType("layout"); layout_op->SetType("layout");
layout_op->SetInput("Input", {cur_node}); layout_op->SetInput("Input", {cur_node});
...@@ -680,7 +680,7 @@ std::pair<bool, std::string> CheckInputAndInsert(Scope* scope, ...@@ -680,7 +680,7 @@ std::pair<bool, std::string> CheckInputAndInsert(Scope* scope,
auto cast_op = block_desc->AddOp<cpp::OpDesc>(); auto cast_op = block_desc->AddOp<cpp::OpDesc>();
auto cast_arg_name = string_format("%s/cast", cur_node.c_str()); auto cast_arg_name = string_format("%s/cast", cur_node.c_str());
scope->Var(cast_arg_name); scope->Var(cast_arg_name);
VLOG(5) << "insert cast for subgraph input, arg tensor name: " VLOG(4) << "insert cast for subgraph input, arg tensor name: "
<< cast_arg_name; << cast_arg_name;
cast_op->SetType("cast"); cast_op->SetType("cast");
cast_op->SetAttr<int>("in_dtype", 5); // FP32 cast_op->SetAttr<int>("in_dtype", 5); // FP32
...@@ -703,10 +703,13 @@ std::pair<bool, std::string> CheckOutputAndInsert( ...@@ -703,10 +703,13 @@ std::pair<bool, std::string> CheckOutputAndInsert(
auto cur_node = output_name; auto cur_node = output_name;
bool do_insert = false; bool do_insert = false;
cpp::OpDesc *layout_op = nullptr, *cast_op = nullptr; cpp::OpDesc *layout_op = nullptr, *cast_op = nullptr;
size_t cast_idx = 0;
// subgraph -> cast -> layout -> output // subgraph -> cast -> layout -> output
if (!PrecisionCompatible(*tensor_type, *subgraph_type)) { if (!PrecisionCompatible(*tensor_type, *subgraph_type)) {
cast_op = block_desc->AddOp<cpp::OpDesc>(); cast_op = block_desc->AddOp<cpp::OpDesc>();
cast_idx = block_desc->OpsSize() - 1;
CHECK_EQ(cast_op, block_desc->GetOp<cpp::OpDesc>(cast_idx));
cast_op->SetType("cast"); cast_op->SetType("cast");
cast_op->SetAttr<int>("in_dtype", 4); // FP16 cast_op->SetAttr<int>("in_dtype", 4); // FP16
cast_op->SetAttr<int>("out_dtype", 5); // FP32 cast_op->SetAttr<int>("out_dtype", 5); // FP32
...@@ -716,7 +719,7 @@ std::pair<bool, std::string> CheckOutputAndInsert( ...@@ -716,7 +719,7 @@ std::pair<bool, std::string> CheckOutputAndInsert(
if (!DataLayoutCompatible(*tensor_type, *subgraph_type)) { if (!DataLayoutCompatible(*tensor_type, *subgraph_type)) {
auto layout_arg_name = string_format("%s/layout", cur_node.c_str()); auto layout_arg_name = string_format("%s/layout", cur_node.c_str());
scope->Var(layout_arg_name); scope->Var(layout_arg_name);
VLOG(5) << "insert layout for subgraph output, arg tensor name: " VLOG(4) << "insert layout for subgraph output, arg tensor name: "
<< layout_arg_name; << layout_arg_name;
layout_op = block_desc->AddOp<cpp::OpDesc>(); layout_op = block_desc->AddOp<cpp::OpDesc>();
layout_op->SetType("layout"); layout_op->SetType("layout");
...@@ -727,9 +730,10 @@ std::pair<bool, std::string> CheckOutputAndInsert( ...@@ -727,9 +730,10 @@ std::pair<bool, std::string> CheckOutputAndInsert(
} }
if (cast_op) { if (cast_op) {
cast_op = block_desc->GetOp<cpp::OpDesc>(cast_idx);
auto cast_arg_name = string_format("%s/cast", cur_node.c_str()); auto cast_arg_name = string_format("%s/cast", cur_node.c_str());
scope->Var(cast_arg_name); scope->Var(cast_arg_name);
VLOG(5) << "insert cast for subgraph output, arg tensor name: " VLOG(4) << "insert cast for subgraph output, arg tensor name: "
<< cast_arg_name; << cast_arg_name;
cast_op->SetInput("X", {cast_arg_name}); cast_op->SetInput("X", {cast_arg_name});
cast_op->SetOutput("Out", {cur_node}); cast_op->SetOutput("Out", {cur_node});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册