From 6d686ff26febc3a90748147749cdda49d5d88e1a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 16 Jun 2021 15:56:03 +0800 Subject: [PATCH] feat(gopt/inference): allow Float32 output dtype in EnableNCHW64Pass GitOrigin-RevId: 1891efb76f66a6abbd0a56820281b4fe91e70304 --- src/gopt/impl/tensor_reformat.cpp | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp index 7d4241f3e..75fb94517 100644 --- a/src/gopt/impl/tensor_reformat.cpp +++ b/src/gopt/impl/tensor_reformat.cpp @@ -4330,14 +4330,20 @@ EnableNCHW64Pass::make_nchw64_converter() { bool check_dtype = new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 && new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8; - if (opr->input().size() >= 3) - check_dtype &= - new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32; - if (opr->input().size() >= 4) - check_dtype &= - new_inp[3]->dtype().enumv() == DTypeEnum::QuantizedS8; + mgb_assert(opr->output().size() > 0); + bool dst_float = opr->output(0)->dtype().enumv() == DTypeEnum::Float32; + if (opr->input().size() >= 3) { + auto dtype_expect = dst_float ? DTypeEnum::Float32 + : DTypeEnum::QuantizedS32; + check_dtype &= new_inp[2]->dtype().enumv() == dtype_expect; + } + if (opr->input().size() >= 4) { + check_dtype &= new_inp[3]->dtype().enumv() == + opr->output(0)->dtype().enumv(); + } if (!check_dtype) return nullptr; + size_t out_channels = opr->input(1)->shape()[0]; size_t in_channels = opr->input(1)->shape()[1]; bool check_channels = out_channels % 4 == 0 && in_channels % 4 == 0; @@ -4370,12 +4376,18 @@ EnableNCHW64Pass::make_nchw64_converter() { } } }; + for (size_t i = 0; i < inps.size(); ++i) { - inps[i] = process(i); + // do not format bias and z when dst_float is true + bool skip = dst_float && i >= 2; + if (!skip) inps[i] = process(i); } auto& conv_bias = opr->cast_final_safe(); - auto ret = make_new_conv(inps, &conv_bias, Format::NCHW4); - format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4)); + auto ret = make_new_conv( + inps, &conv_bias, + dst_float ? Format::NCHW4_NCHW : Format::NCHW4); + if (!dst_float) + format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4)); return ret; }; -- GitLab