提交 2ab5c53f 编写于 作者: M Megvii Engine Team 提交者: huangxinda

feat(mgb/gopt): support nhwc conv in tensor reformat pass

GitOrigin-RevId: 43e78d758ab352c9e47d9ca1bb5fe868d4443458
上级 009c90a2
......@@ -4618,6 +4618,11 @@ EnableNCHW64Pass::make_nchw64_converter() {
inps[i],
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW);
return ovar.node();
} else if (fmt == Format::NHWC) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW);
return ovar.node();
} else {
mgb_assert(fmt == Format::NCHW64);
auto ovar = RelayoutPlaceholder::make(
......@@ -4679,6 +4684,11 @@ EnableNCHW64Pass::make_nchw64_converter() {
inps[i],
RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4);
return ovar.node();
} else if (fmt == Format::NHWC) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW4);
return ovar.node();
} else {
mgb_assert(fmt == Format::NCHW64);
auto ovar = RelayoutPlaceholder::make(
......@@ -4741,6 +4751,11 @@ EnableNCHW64Pass::make_nchw64_converter() {
inps[i],
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32);
return ovar.node();
} else if (fmt == Format::NHWC) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW32);
return ovar.node();
} else {
mgb_assert(fmt == Format::NCHW64);
auto ovar = RelayoutPlaceholder::make(
......@@ -4800,6 +4815,11 @@ EnableNCHW64Pass::make_nchw64_converter() {
inps[i],
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64);
return ovar.node();
} else if (fmt == Format::NHWC) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW64);
return ovar.node();
} else {
mgb_assert(fmt == Format::NCHW32);
auto ovar = RelayoutPlaceholder::make(
......@@ -4818,10 +4838,75 @@ EnableNCHW64Pass::make_nchw64_converter() {
return ret;
};
// Tries to rebuild a ConvBias operator in NHWC format.
//
// Only the 4-bit quantized combinations are accepted (fint4 x wint4 and
// fuint4 x wint4): input 0 must be QuantizedS4 or Quantized4Asymm, input 1
// must be QuantizedS4, an optional input 2 must be QuantizedS32 and an
// optional input 3 must share the dtype of input 0. In addition, dims 0 and 1
// of input 1's shape (treated here as output / input channels) must both be
// multiples of 8.
//
// \param opr the original operator; cast to opr::ConvBiasForward below
// \param new_inp the already-replaced inputs, same count as opr->input()
// \return the output var of the new NHWC conv, or nullptr when any of the
//         checks above rejects the operator
auto try_transform_to_nhwc =
        [make_new_conv, &format_map](
                OperatorNodeBase* opr,
                const VarNodeArray& new_inp) -> VarNode* {
    using LayoutType = RelayoutPlaceholder::LayoutType;
    mgb_assert(opr->input().size()==new_inp.size());
    const size_t nr_inputs = opr->input().size();

    // dtype guards: bail out as soon as one input has an unsupported dtype
    const auto src_enumv = new_inp[0]->dtype().enumv();
    const bool src_is_4bit = src_enumv == DTypeEnum::QuantizedS4 ||
                             src_enumv == DTypeEnum::Quantized4Asymm;
    if (!src_is_4bit ||
        new_inp[1]->dtype().enumv() != DTypeEnum::QuantizedS4) {
        return nullptr;
    }
    if (nr_inputs >= 3 &&
        new_inp[2]->dtype().enumv() != DTypeEnum::QuantizedS32) {
        return nullptr;
    }
    if (nr_inputs >= 4 && new_inp[3]->dtype().enumv() != src_enumv) {
        return nullptr;
    }

    // channel guards: both leading dims of input 1 must be divisible by 8
    const size_t out_channels = opr->input(1)->shape()[0];
    const size_t in_channels = opr->input(1)->shape()[1];
    if (out_channels % 8 != 0 || in_channels % 8 != 0) {
        return nullptr;
    }

    // Bring every input to NHWC. An input whose producer is not recorded in
    // format_map is relayouted with NCHW_TO_NHWC; one already tagged NHWC is
    // passed through untouched.
    auto nhwc_inps = new_inp;
    for (size_t idx = 0; idx < nhwc_inps.size(); ++idx) {
        LayoutType relayout = LayoutType::NCHW_TO_NHWC;
        auto found = format_map.find(new_inp[idx]->owner_opr());
        if (found != format_map.end()) {
            const auto& fmt = found->second;
            switch (fmt) {
                case Format::NHWC:
                    // already in the target layout, keep the var as is
                    continue;
                case Format::NCHW4:
                    relayout = LayoutType::NCHW4_TO_NHWC;
                    break;
                case Format::NCHW32:
                    relayout = LayoutType::NCHW32_TO_NHWC;
                    break;
                default:
                    // the only remaining format handled here is NCHW64
                    mgb_assert(fmt == Format::NCHW64);
                    relayout = LayoutType::NCHW64_TO_NHWC;
                    break;
            }
        }
        nhwc_inps[idx] =
                RelayoutPlaceholder::make(nhwc_inps[idx], relayout).node();
    }

    // Build the NHWC conv and record the format of its owner operator.
    auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>();
    auto ret = make_new_conv(nhwc_inps, &conv_bias, Format::NHWC);
    format_map.insert(std::make_pair(ret->owner_opr(), Format::NHWC));
    return ret;
};
// replace rule for conv bias opr
auto replace_conv_bias_opr = [&format_map, try_transform_to_nchw4,
try_transform_to_nchw32,
try_transform_to_nchw64, try_transform_to_nchw](
try_transform_to_nchw64,
try_transform_to_nhwc, try_transform_to_nchw](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
using Param = megdnn::param::ConvBias;
......@@ -4833,7 +4918,8 @@ EnableNCHW64Pass::make_nchw64_converter() {
VarNode* new_var = nullptr;
if ((new_var = try_transform_to_nchw32(opr, new_inp)) ||
(new_var = try_transform_to_nchw4(opr, new_inp)) ||
(new_var = try_transform_to_nchw64(opr, new_inp))||
(new_var = try_transform_to_nchw64(opr, new_inp)) ||
(new_var = try_transform_to_nhwc(opr, new_inp)) ||
(new_var = try_transform_to_nchw(opr, new_inp))) {
return new_var->owner_opr();
} else {
......@@ -4891,6 +4977,12 @@ EnableNCHW64Pass::make_nchw64_converter() {
NCHW_TO_NCHW4)
.node();
break;
case Format::NHWC:
inps[1] = RelayoutPlaceholder::make(
inps[1], RelayoutPlaceholder::LayoutType::
NCHW_TO_NHWC)
.node();
break;
case Format::NCHW32:
inps[1] = RelayoutPlaceholder::make(
inps[1], RelayoutPlaceholder::LayoutType::
......@@ -4991,6 +5083,9 @@ EnableNCHW64Pass::make_nchw64_converter() {
cb(NCHW4, NCHW), cb(NCHW4, NCHW32), cb(NCHW4, NCHW64),
cb(NCHW32, NCHW), cb(NCHW32, NCHW4), cb(NCHW32, NCHW64),
cb(NCHW32, NCHW), cb(NCHW32, NCHW4), cb(NCHW32, NCHW64),
cb(NCHW, NHWC), cb(NCHW4, NHWC), cb(NCHW32, NHWC),
cb(NCHW64, NHWC), cb(NHWC, NCHW), cb(NHWC, NCHW4),
cb(NHWC, NCHW32), cb(NHWC, NCHW64),
#undef cb
};
auto inps = new_inp;
......@@ -5037,26 +5132,27 @@ EnableNCHW64Pass::make_nchw64_converter() {
case Format::NCHW:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
NCHW_TO_NCHW64)
NCHW_TO_NHWC)
.node();
break;
case Format::NCHW4:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
NCHW4_TO_NCHW64)
NCHW4_TO_NHWC)
.node();
break;
case Format::NCHW32:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
NCHW32_TO_NCHW64)
NCHW32_TO_NHWC)
.node();
break;
default:
mgb_assert(cur == Format::NCHW64);
mgb_assert(cur == Format::NCHW64 || cur == Format::NHWC);
}
auto target_format = cur == Format::NCHW64 ? cur : Format::NHWC;
auto param = warp.param();
param.format = Format::NCHW64;
param.format = target_format;
SymbolVar new_warp;
if (inps.size() == 3) {
new_warp = opr::WarpPerspectiveForward::make(
......@@ -5069,7 +5165,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
warp.config());
}
auto ret = new_warp.node()->owner_opr();
format_map.insert(std::make_pair(ret, Format::NCHW64));
format_map.insert(std::make_pair(ret, target_format));
return ret;
} else if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8) {
Format cur;
......@@ -5087,6 +5183,12 @@ EnableNCHW64Pass::make_nchw64_converter() {
NCHW_TO_NCHW4)
.node();
break;
case Format::NHWC:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
NHWC_TO_NCHW4)
.node();
break;
case Format::NCHW32:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
......@@ -5154,31 +5256,31 @@ EnableNCHW64Pass::make_nchw64_converter() {
case Format::NCHW:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
NCHW_TO_NCHW64)
NCHW_TO_NHWC)
.node();
break;
case Format::NCHW4:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
NCHW4_TO_NCHW64)
NCHW4_TO_NHWC)
.node();
break;
case Format::NCHW32:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
NCHW32_TO_NCHW64)
NCHW32_TO_NHWC)
.node();
break;
default:
mgb_assert(cur == Format::NCHW64);
mgb_assert(cur == Format::NCHW64 || cur == Format::NHWC);
}
auto target_format = cur == Format::NCHW64 ? cur : Format::NHWC;
auto param = pooling.param();
param.format = Format::NCHW64;
param.format = target_format;
auto new_pool =
opr::PoolingForward::make(inps[0], param, pooling.config());
auto ret = new_pool.node()->owner_opr();
format_map.insert(std::make_pair(ret, Format::NCHW64));
format_map.insert(std::make_pair(ret, target_format));
return ret;
} else if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8) {
Format cur;
......@@ -5188,12 +5290,12 @@ EnableNCHW64Pass::make_nchw64_converter() {
} else {
cur = iter->second;
}
size_t in_channels = new_inp[0]->shape()[1];
bool use_nchw32 = false;
auto inps = new_inp;
using LayoutType = RelayoutPlaceholder::LayoutType;
switch (cur) {
case Format::NCHW: {
size_t in_channels = new_inp[0]->shape()[1];
use_nchw32 = in_channels % 32 == 0;
auto layout_type = use_nchw32 ? LayoutType::NCHW_TO_NCHW32
: LayoutType::NCHW_TO_NCHW4;
......@@ -5201,6 +5303,15 @@ EnableNCHW64Pass::make_nchw64_converter() {
.node();
break;
}
case Format::NHWC: {
size_t in_channels = new_inp[0]->shape()[3];
use_nchw32 = in_channels % 32 == 0;
auto layout_type = use_nchw32 ? LayoutType::NHWC_TO_NCHW32
: LayoutType::NHWC_TO_NCHW4;
inps[0] = RelayoutPlaceholder::make(inps[0], layout_type)
.node();
break;
}
case Format::NCHW64:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
......@@ -5253,6 +5364,13 @@ EnableNCHW64Pass::make_nchw64_converter() {
auto fmt = iter != format_map.end()?iter->second:Format::NCHW;
if (iter != format_map.end()) {
switch (fmt) {
case Format::NHWC:
inps[i] = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::
NHWC_TO_NCHW)
.node();
break;
case Format::NCHW4:
inps[i] = RelayoutPlaceholder::make(
inps[i],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册