diff --git a/src/gopt/impl/tensor_reformat.cpp b/src/gopt/impl/tensor_reformat.cpp
index 7d4241f3e102c022c8f3a794f0ae3df3a039a209..75fb9451754aa46aee69b5287e226c32b39c8964 100644
--- a/src/gopt/impl/tensor_reformat.cpp
+++ b/src/gopt/impl/tensor_reformat.cpp
@@ -4330,14 +4330,20 @@ EnableNCHW64Pass::make_nchw64_converter() {
         bool check_dtype =
                 new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 &&
                 new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8;
-        if (opr->input().size() >= 3)
-            check_dtype &=
-                    new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32;
-        if (opr->input().size() >= 4)
-            check_dtype &=
-                    new_inp[3]->dtype().enumv() == DTypeEnum::QuantizedS8;
+        mgb_assert(opr->output().size() > 0);
+        bool dst_float = opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
+        if (opr->input().size() >= 3) {
+            auto dtype_expect = dst_float ? DTypeEnum::Float32
+                                          : DTypeEnum::QuantizedS32;
+            check_dtype &= new_inp[2]->dtype().enumv() == dtype_expect;
+        }
+        if (opr->input().size() >= 4) {
+            check_dtype &= new_inp[3]->dtype().enumv() ==
+                           opr->output(0)->dtype().enumv();
+        }
         if (!check_dtype)
             return nullptr;
+
         size_t out_channels = opr->input(1)->shape()[0];
         size_t in_channels = opr->input(1)->shape()[1];
         bool check_channels = out_channels % 4 == 0 && in_channels % 4 == 0;
@@ -4370,12 +4376,18 @@ EnableNCHW64Pass::make_nchw64_converter() {
                 }
             }
         };
+
         for (size_t i = 0; i < inps.size(); ++i) {
-            inps[i] = process(i);
+            // do not format bias and z when dst_float is true
+            bool skip = dst_float && i >= 2;
+            if (!skip) inps[i] = process(i);
         }
         auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>();
-        auto ret = make_new_conv(inps, &conv_bias, Format::NCHW4);
-        format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4));
+        auto ret = make_new_conv(
+                inps, &conv_bias,
+                dst_float ? Format::NCHW4_NCHW : Format::NCHW4);
+        if (!dst_float)
+            format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4));
         return ret;
     };
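
For context, a minimal standalone sketch of the dtype rule the patch encodes: when the fused output is Float32 the bias input is expected to be Float32 rather than QuantizedS32, and the z (residual) input must match the output dtype. The DType enum, the dtypes_ok helper, and the example values below are hypothetical stand-ins for illustration only, not MegEngine APIs.

// Standalone sketch (not MegEngine code): DType mirrors the role of
// megdnn::DTypeEnum for the purpose of this example.
#include <cassert>
#include <cstdio>
#include <vector>

enum class DType { QuantizedS8, QuantizedS32, Float32 };

// Returns true when (src, filter, bias, z) dtypes are compatible with the
// conversion, given the dtype of the fused output.
bool dtypes_ok(const std::vector<DType>& inp, DType out_dtype) {
    bool dst_float = out_dtype == DType::Float32;
    bool ok = inp.size() >= 2 && inp[0] == DType::QuantizedS8 &&
              inp[1] == DType::QuantizedS8;
    if (inp.size() >= 3) {
        // bias: float bias for a float output, otherwise accumulator int32
        DType expect = dst_float ? DType::Float32 : DType::QuantizedS32;
        ok &= inp[2] == expect;
    }
    if (inp.size() >= 4) {
        // z (residual add) must match the output dtype
        ok &= inp[3] == out_dtype;
    }
    return ok;
}

int main() {
    // quantized path: s8 src/filter, s32 bias, s8 z, s8 output
    assert(dtypes_ok({DType::QuantizedS8, DType::QuantizedS8,
                      DType::QuantizedS32, DType::QuantizedS8},
                     DType::QuantizedS8));
    // float-output path: s8 src/filter, float bias, float z, float output
    assert(dtypes_ok({DType::QuantizedS8, DType::QuantizedS8, DType::Float32,
                      DType::Float32},
                     DType::Float32));
    // mismatch: s32 bias with a float output is rejected
    assert(!dtypes_ok({DType::QuantizedS8, DType::QuantizedS8,
                       DType::QuantizedS32},
                      DType::Float32));
    std::puts("all checks passed");
}

This mirrors the branch structure in the diff: the same dst_float flag later drives the choice between Format::NCHW4 and Format::NCHW4_NCHW and the decision to skip reformatting bias and z.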