提交 6d686ff2 编写于 作者: M Megvii Engine Team 提交者: huangxinda

feat(gopt/inference): allow Float32 output dtype in EnableNCHW64Pass

GitOrigin-RevId: 1891efb76f66a6abbd0a56820281b4fe91e70304
上级 7d3df995
...@@ -4330,14 +4330,20 @@ EnableNCHW64Pass::make_nchw64_converter() { ...@@ -4330,14 +4330,20 @@ EnableNCHW64Pass::make_nchw64_converter() {
bool check_dtype = bool check_dtype =
new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 && new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 &&
new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8; new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8;
if (opr->input().size() >= 3) mgb_assert(opr->output().size() > 0);
check_dtype &= bool dst_float = opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32; if (opr->input().size() >= 3) {
if (opr->input().size() >= 4) auto dtype_expect = dst_float ? DTypeEnum::Float32
check_dtype &= : DTypeEnum::QuantizedS32;
new_inp[3]->dtype().enumv() == DTypeEnum::QuantizedS8; check_dtype &= new_inp[2]->dtype().enumv() == dtype_expect;
}
if (opr->input().size() >= 4) {
check_dtype &= new_inp[3]->dtype().enumv() ==
opr->output(0)->dtype().enumv();
}
if (!check_dtype) if (!check_dtype)
return nullptr; return nullptr;
size_t out_channels = opr->input(1)->shape()[0]; size_t out_channels = opr->input(1)->shape()[0];
size_t in_channels = opr->input(1)->shape()[1]; size_t in_channels = opr->input(1)->shape()[1];
bool check_channels = out_channels % 4 == 0 && in_channels % 4 == 0; bool check_channels = out_channels % 4 == 0 && in_channels % 4 == 0;
...@@ -4370,12 +4376,18 @@ EnableNCHW64Pass::make_nchw64_converter() { ...@@ -4370,12 +4376,18 @@ EnableNCHW64Pass::make_nchw64_converter() {
} }
} }
}; };
for (size_t i = 0; i < inps.size(); ++i) { for (size_t i = 0; i < inps.size(); ++i) {
inps[i] = process(i); // do not format bias and z when dst_float is true
bool skip = dst_float && i >= 2;
if (!skip) inps[i] = process(i);
} }
auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>(); auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>();
auto ret = make_new_conv(inps, &conv_bias, Format::NCHW4); auto ret = make_new_conv(
format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4)); inps, &conv_bias,
dst_float ? Format::NCHW4_NCHW : Format::NCHW4);
if (!dst_float)
format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4));
return ret; return ret;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册