Commit 59aaf733 authored by J jackzhang235

Change the sequence of nodes inserted around a subgraph to: precision -> layout -> target -> subgraph -> target -> precision -> layout
Parent 3785aab8
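
For orientation, a sketch of the chains this commit now builds around an MLU subgraph, using only the node names from the commit message (the arg names are illustrative):

    head_arg -> [precision cast] -> [layout cast] -> [target cast] -> subgraph
    subgraph -> [target cast] -> [precision cast] -> [layout cast] -> tail_arg

Previously the layout node was inserted ahead of the precision cast on both sides, so layout transforms ran at the original precision.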
@@ -217,23 +217,23 @@ void MLUPostprocessPass::InsertBefore(SSAGraph* graph,
                             first_conv_nodes_.end(),
                             head_node->AsArg().name) != first_conv_nodes_.end();
-  // layout cast node
-  if (head_type->layout() != inst_type->layout()) {
+  // precision cast node
+  if (head_type->precision() != inst_type->precision() && !is_first_conv_head) {
     cur_node = InsertCastBefore(
-        "layout",
-        name_prefix + "layout",
+        "cast",
+        name_prefix + "cast",
         graph,
         cur_node,
         inst_node,
         LiteType::GetTensorTy(
-            head_type->target(), head_type->precision(), inst_type->layout()));
+            head_type->target(), inst_type->precision(), head_type->layout()));
   }
-  // precision cast node
-  if (head_type->precision() != inst_type->precision() && !is_first_conv_head) {
+  // layout cast node
+  if (head_type->layout() != inst_type->layout()) {
     cur_node = InsertCastBefore(
-        "cast",
-        name_prefix + "cast",
+        "layout",
+        name_prefix + "layout",
         graph,
         cur_node,
         inst_node,
@@ -366,23 +366,23 @@ void MLUPostprocessPass::InsertAfter(SSAGraph* graph,
   const auto name_prefix =
       tail_node->AsArg().name + string_format("_%p", inst_node) + "/trans_";
-  // layout cast node
-  if (tail_type->layout() != inst_type->layout()) {
+  // precision cast node
+  if (tail_type->precision() != inst_type->precision()) {
     cur_node = InsertCastAfter(
-        "layout",
-        name_prefix + "layout",
+        "cast",
+        name_prefix + "cast",
         graph,
         cur_node,
         inst_node,
         LiteType::GetTensorTy(
-            tail_type->target(), tail_type->precision(), inst_type->layout()));
+            tail_type->target(), inst_type->precision(), tail_type->layout()));
   }
-  // precision cast node
-  if (tail_type->precision() != inst_type->precision()) {
+  // layout cast node
+  if (tail_type->layout() != inst_type->layout()) {
     cur_node = InsertCastAfter(
-        "cast",
-        name_prefix + "cast",
+        "layout",
+        name_prefix + "layout",
         graph,
         cur_node,
         inst_node,
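
Both hunks apply the same reordering: the precision cast is created first, so the intermediate LiteType keeps the head/tail target and layout while switching to the instruction's precision, and the subsequent layout cast then operates on already-converted data. A minimal standalone sketch of that type propagation, with hypothetical enums standing in for Lite's type system (not the Paddle-Lite API):

// Standalone sketch of how a tensor's type evolves through InsertBefore's
// new chain; enum and field names only mirror the diff above.
#include <cstdio>

enum class Target { kHost, kMLU };
enum class Precision { kFloat, kFP16 };
enum class Layout { kNCHW, kNHWC };
struct Ty { Target target; Precision precision; Layout layout; };

int main() {
  Ty head{Target::kHost, Precision::kFloat, Layout::kNCHW};  // graph-side arg
  Ty inst{Target::kMLU,  Precision::kFP16,  Layout::kNHWC};  // subgraph input

  Ty cur = head;
  // Step 1, precision cast: move only the precision toward the subgraph,
  // keeping the head's target and layout (kHost / kFP16 / kNCHW).
  if (head.precision != inst.precision)
    cur = Ty{head.target, inst.precision, head.layout};
  // Step 2, layout cast: the tensor is already FP16 here, so the layout
  // kernel must be an FP16 variant -- hence the registration fix below.
  if (head.layout != inst.layout)
    cur = Ty{cur.target, cur.precision, inst.layout};
  // Step 3 would be the target cast (copy onto the MLU), then the subgraph.
  std::printf("layout step ran in FP16: %s\n",
              cur.precision == Precision::kFP16 ? "yes" : "no");
  return 0;
}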
@@ -48,11 +48,11 @@ REGISTER_LITE_KERNEL(
                      def_layout_nhwc2nchw_fp16)
     .BindInput("Input",
                {LiteType::GetTensorTy(TARGET(kHost),
-                                      PRECISION(kFloat),
+                                      PRECISION(kFP16),
                                       DATALAYOUT(kNHWC))})
     .BindOutput("Out",
                 {LiteType::GetTensorTy(TARGET(kHost),
-                                       PRECISION(kFloat),
+                                       PRECISION(kFP16),
                                        DATALAYOUT(kNCHW))})
     .Finalize();
@@ -82,11 +82,11 @@ REGISTER_LITE_KERNEL(
                      def_layout_nchw2nhwc_fp16)
     .BindInput("Input",
                {LiteType::GetTensorTy(TARGET(kHost),
-                                      PRECISION(kFloat),
+                                      PRECISION(kFP16),
                                       DATALAYOUT(kNCHW))})
     .BindOutput("Out",
                 {LiteType::GetTensorTy(TARGET(kHost),
-                                       PRECISION(kFloat),
+                                       PRECISION(kFP16),
                                        DATALAYOUT(kNHWC))})
     .Finalize();
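
The registration fix matters because kernel selection is type-driven: an fp16 layout kernel that still declares PRECISION(kFloat) could never be matched against the FP16 intermediate tensor produced by the reordered pass. A simplified, hypothetical sketch of that selection rule (not the actual Lite resolver):

// Hypothetical illustration: pick the registered variant whose declared
// input precision matches the tensor flowing into the layout node.
#include <iostream>
#include <string>
#include <vector>

struct KernelDecl { std::string name; std::string in_precision; };

std::string Pick(const std::vector<KernelDecl>& kernels,
                 const std::string& tensor_precision) {
  for (const auto& k : kernels)
    if (k.in_precision == tensor_precision) return k.name;
  return "<no match>";
}

int main() {
  std::vector<KernelDecl> layout_kernels = {
      {"def_layout_nhwc2nchw", "kFloat"},
      {"def_layout_nhwc2nchw_fp16", "kFP16"},  // declaration after this commit
  };
  // The precision cast now runs first, so the tensor is already kFP16
  // by the time the layout node is resolved.
  std::cout << Pick(layout_kernels, "kFP16") << "\n";  // def_layout_nhwc2nchw_fp16
  return 0;
}

With the declarations corrected, the layout nodes inserted after the precision casts resolve to the _fp16 kernel variants instead of failing to match.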