提交 2ab5c53f 编写于 作者: M Megvii Engine Team 提交者: huangxinda

feat(mgb/gopt): support nhwc conv in tensor reformat pass

GitOrigin-RevId: 43e78d758ab352c9e47d9ca1bb5fe868d4443458
上级 009c90a2
...@@ -4618,6 +4618,11 @@ EnableNCHW64Pass::make_nchw64_converter() { ...@@ -4618,6 +4618,11 @@ EnableNCHW64Pass::make_nchw64_converter() {
inps[i], inps[i],
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW); RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW);
return ovar.node(); return ovar.node();
} else if (fmt == Format::NHWC) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW);
return ovar.node();
} else { } else {
mgb_assert(fmt == Format::NCHW64); mgb_assert(fmt == Format::NCHW64);
auto ovar = RelayoutPlaceholder::make( auto ovar = RelayoutPlaceholder::make(
...@@ -4679,6 +4684,11 @@ EnableNCHW64Pass::make_nchw64_converter() { ...@@ -4679,6 +4684,11 @@ EnableNCHW64Pass::make_nchw64_converter() {
inps[i], inps[i],
RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4); RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4);
return ovar.node(); return ovar.node();
} else if (fmt == Format::NHWC) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW4);
return ovar.node();
} else { } else {
mgb_assert(fmt == Format::NCHW64); mgb_assert(fmt == Format::NCHW64);
auto ovar = RelayoutPlaceholder::make( auto ovar = RelayoutPlaceholder::make(
...@@ -4741,6 +4751,11 @@ EnableNCHW64Pass::make_nchw64_converter() { ...@@ -4741,6 +4751,11 @@ EnableNCHW64Pass::make_nchw64_converter() {
inps[i], inps[i],
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32); RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32);
return ovar.node(); return ovar.node();
} else if (fmt == Format::NHWC) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW32);
return ovar.node();
} else { } else {
mgb_assert(fmt == Format::NCHW64); mgb_assert(fmt == Format::NCHW64);
auto ovar = RelayoutPlaceholder::make( auto ovar = RelayoutPlaceholder::make(
...@@ -4800,6 +4815,11 @@ EnableNCHW64Pass::make_nchw64_converter() { ...@@ -4800,6 +4815,11 @@ EnableNCHW64Pass::make_nchw64_converter() {
inps[i], inps[i],
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64); RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64);
return ovar.node(); return ovar.node();
} else if (fmt == Format::NHWC) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW64);
return ovar.node();
} else { } else {
mgb_assert(fmt == Format::NCHW32); mgb_assert(fmt == Format::NCHW32);
auto ovar = RelayoutPlaceholder::make( auto ovar = RelayoutPlaceholder::make(
...@@ -4818,10 +4838,75 @@ EnableNCHW64Pass::make_nchw64_converter() { ...@@ -4818,10 +4838,75 @@ EnableNCHW64Pass::make_nchw64_converter() {
return ret; return ret;
}; };
auto try_transform_to_nhwc =
[make_new_conv, &format_map](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) -> VarNode* {
// fint4XWint4 and fuint4XWint4
mgb_assert(opr->input().size()==new_inp.size());
bool check_dtype =
(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 ||
new_inp[0]->dtype().enumv() ==
DTypeEnum::Quantized4Asymm) &&
new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS4;
if (opr->input().size() >= 3)
check_dtype &=
new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32;
if (opr->input().size() >= 4)
check_dtype &= new_inp[3]->dtype().enumv() ==
new_inp[0]->dtype().enumv();
if (!check_dtype)
return nullptr;
size_t out_channels = opr->input(1)->shape()[0];
size_t in_channels = opr->input(1)->shape()[1];
bool check_channels = out_channels % 8 == 0 && in_channels % 8 == 0;
if (!check_channels)
return nullptr;
auto inps = new_inp;
auto process = [&](size_t i) -> VarNode* {
auto iter = format_map.find(new_inp[i]->owner_opr());
if (iter == format_map.end()) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NCHW_TO_NHWC);
return ovar.node();
} else {
const auto& fmt = iter->second;
if (fmt == Format::NHWC) {
return inps[i];
} else if (fmt == Format::NCHW4) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NCHW4_TO_NHWC);
return ovar.node();
} else if (fmt == Format::NCHW32) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NCHW32_TO_NHWC);
return ovar.node();
} else {
mgb_assert(fmt == Format::NCHW64);
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NCHW64_TO_NHWC);
return ovar.node();
}
}
};
for (size_t i = 0; i < inps.size(); ++i) {
inps[i] = process(i);
}
auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>();
auto ret = make_new_conv(inps, &conv_bias, Format::NHWC);
format_map.insert(std::make_pair(ret->owner_opr(), Format::NHWC));
return ret;
};
// replace rule for conv bias opr // replace rule for conv bias opr
auto replace_conv_bias_opr = [&format_map, try_transform_to_nchw4, auto replace_conv_bias_opr = [&format_map, try_transform_to_nchw4,
try_transform_to_nchw32, try_transform_to_nchw32,
try_transform_to_nchw64, try_transform_to_nchw]( try_transform_to_nchw64,
try_transform_to_nhwc, try_transform_to_nchw](
OperatorNodeBase* opr, OperatorNodeBase* opr,
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
using Param = megdnn::param::ConvBias; using Param = megdnn::param::ConvBias;
...@@ -4833,7 +4918,8 @@ EnableNCHW64Pass::make_nchw64_converter() { ...@@ -4833,7 +4918,8 @@ EnableNCHW64Pass::make_nchw64_converter() {
VarNode* new_var = nullptr; VarNode* new_var = nullptr;
if ((new_var = try_transform_to_nchw32(opr, new_inp)) || if ((new_var = try_transform_to_nchw32(opr, new_inp)) ||
(new_var = try_transform_to_nchw4(opr, new_inp)) || (new_var = try_transform_to_nchw4(opr, new_inp)) ||
(new_var = try_transform_to_nchw64(opr, new_inp))|| (new_var = try_transform_to_nchw64(opr, new_inp)) ||
(new_var = try_transform_to_nhwc(opr, new_inp)) ||
(new_var = try_transform_to_nchw(opr, new_inp))) { (new_var = try_transform_to_nchw(opr, new_inp))) {
return new_var->owner_opr(); return new_var->owner_opr();
} else { } else {
...@@ -4891,6 +4977,12 @@ EnableNCHW64Pass::make_nchw64_converter() { ...@@ -4891,6 +4977,12 @@ EnableNCHW64Pass::make_nchw64_converter() {
NCHW_TO_NCHW4) NCHW_TO_NCHW4)
.node(); .node();
break; break;
case Format::NHWC:
inps[1] = RelayoutPlaceholder::make(
inps[1], RelayoutPlaceholder::LayoutType::
NCHW_TO_NHWC)
.node();
break;
case Format::NCHW32: case Format::NCHW32:
inps[1] = RelayoutPlaceholder::make( inps[1] = RelayoutPlaceholder::make(
inps[1], RelayoutPlaceholder::LayoutType:: inps[1], RelayoutPlaceholder::LayoutType::
...@@ -4991,6 +5083,9 @@ EnableNCHW64Pass::make_nchw64_converter() { ...@@ -4991,6 +5083,9 @@ EnableNCHW64Pass::make_nchw64_converter() {
cb(NCHW4, NCHW), cb(NCHW4, NCHW32), cb(NCHW4, NCHW64), cb(NCHW4, NCHW), cb(NCHW4, NCHW32), cb(NCHW4, NCHW64),
cb(NCHW32, NCHW), cb(NCHW32, NCHW4), cb(NCHW32, NCHW64), cb(NCHW32, NCHW), cb(NCHW32, NCHW4), cb(NCHW32, NCHW64),
cb(NCHW32, NCHW), cb(NCHW32, NCHW4), cb(NCHW32, NCHW64), cb(NCHW32, NCHW), cb(NCHW32, NCHW4), cb(NCHW32, NCHW64),
cb(NCHW, NHWC), cb(NCHW4, NHWC), cb(NCHW32, NHWC),
cb(NCHW64, NHWC), cb(NHWC, NCHW), cb(NHWC, NCHW4),
cb(NHWC, NCHW32), cb(NHWC, NCHW64),
#undef cb #undef cb
}; };
auto inps = new_inp; auto inps = new_inp;
...@@ -5037,26 +5132,27 @@ EnableNCHW64Pass::make_nchw64_converter() { ...@@ -5037,26 +5132,27 @@ EnableNCHW64Pass::make_nchw64_converter() {
case Format::NCHW: case Format::NCHW:
inps[0] = RelayoutPlaceholder::make( inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType:: inps[0], RelayoutPlaceholder::LayoutType::
NCHW_TO_NCHW64) NCHW_TO_NHWC)
.node(); .node();
break; break;
case Format::NCHW4: case Format::NCHW4:
inps[0] = RelayoutPlaceholder::make( inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType:: inps[0], RelayoutPlaceholder::LayoutType::
NCHW4_TO_NCHW64) NCHW4_TO_NHWC)
.node(); .node();
break; break;
case Format::NCHW32: case Format::NCHW32:
inps[0] = RelayoutPlaceholder::make( inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType:: inps[0], RelayoutPlaceholder::LayoutType::
NCHW32_TO_NCHW64) NCHW32_TO_NHWC)
.node(); .node();
break; break;
default: default:
mgb_assert(cur == Format::NCHW64); mgb_assert(cur == Format::NCHW64 || cur == Format::NHWC);
} }
auto target_format = cur == Format::NCHW64 ? cur : Format::NHWC;
auto param = warp.param(); auto param = warp.param();
param.format = Format::NCHW64; param.format = target_format;
SymbolVar new_warp; SymbolVar new_warp;
if (inps.size() == 3) { if (inps.size() == 3) {
new_warp = opr::WarpPerspectiveForward::make( new_warp = opr::WarpPerspectiveForward::make(
...@@ -5069,7 +5165,7 @@ EnableNCHW64Pass::make_nchw64_converter() { ...@@ -5069,7 +5165,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
warp.config()); warp.config());
} }
auto ret = new_warp.node()->owner_opr(); auto ret = new_warp.node()->owner_opr();
format_map.insert(std::make_pair(ret, Format::NCHW64)); format_map.insert(std::make_pair(ret, target_format));
return ret; return ret;
} else if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8) { } else if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8) {
Format cur; Format cur;
...@@ -5087,6 +5183,12 @@ EnableNCHW64Pass::make_nchw64_converter() { ...@@ -5087,6 +5183,12 @@ EnableNCHW64Pass::make_nchw64_converter() {
NCHW_TO_NCHW4) NCHW_TO_NCHW4)
.node(); .node();
break; break;
case Format::NHWC:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
NHWC_TO_NCHW4)
.node();
break;
case Format::NCHW32: case Format::NCHW32:
inps[0] = RelayoutPlaceholder::make( inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType:: inps[0], RelayoutPlaceholder::LayoutType::
...@@ -5154,31 +5256,31 @@ EnableNCHW64Pass::make_nchw64_converter() { ...@@ -5154,31 +5256,31 @@ EnableNCHW64Pass::make_nchw64_converter() {
case Format::NCHW: case Format::NCHW:
inps[0] = RelayoutPlaceholder::make( inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType:: inps[0], RelayoutPlaceholder::LayoutType::
NCHW_TO_NCHW64) NCHW_TO_NHWC)
.node(); .node();
break; break;
case Format::NCHW4: case Format::NCHW4:
inps[0] = RelayoutPlaceholder::make( inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType:: inps[0], RelayoutPlaceholder::LayoutType::
NCHW4_TO_NCHW64) NCHW4_TO_NHWC)
.node(); .node();
break; break;
case Format::NCHW32: case Format::NCHW32:
inps[0] = RelayoutPlaceholder::make( inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType:: inps[0], RelayoutPlaceholder::LayoutType::
NCHW32_TO_NCHW64) NCHW32_TO_NHWC)
.node(); .node();
break; break;
default: default:
mgb_assert(cur == Format::NCHW64); mgb_assert(cur == Format::NCHW64 || cur == Format::NHWC);
} }
auto target_format = cur == Format::NCHW64 ? cur : Format::NHWC;
auto param = pooling.param(); auto param = pooling.param();
param.format = Format::NCHW64; param.format = target_format;
auto new_pool = auto new_pool =
opr::PoolingForward::make(inps[0], param, pooling.config()); opr::PoolingForward::make(inps[0], param, pooling.config());
auto ret = new_pool.node()->owner_opr(); auto ret = new_pool.node()->owner_opr();
format_map.insert(std::make_pair(ret, Format::NCHW64)); format_map.insert(std::make_pair(ret, target_format));
return ret; return ret;
} else if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8) { } else if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8) {
Format cur; Format cur;
...@@ -5188,12 +5290,12 @@ EnableNCHW64Pass::make_nchw64_converter() { ...@@ -5188,12 +5290,12 @@ EnableNCHW64Pass::make_nchw64_converter() {
} else { } else {
cur = iter->second; cur = iter->second;
} }
size_t in_channels = new_inp[0]->shape()[1];
bool use_nchw32 = false; bool use_nchw32 = false;
auto inps = new_inp; auto inps = new_inp;
using LayoutType = RelayoutPlaceholder::LayoutType; using LayoutType = RelayoutPlaceholder::LayoutType;
switch (cur) { switch (cur) {
case Format::NCHW: { case Format::NCHW: {
size_t in_channels = new_inp[0]->shape()[1];
use_nchw32 = in_channels % 32 == 0; use_nchw32 = in_channels % 32 == 0;
auto layout_type = use_nchw32 ? LayoutType::NCHW_TO_NCHW32 auto layout_type = use_nchw32 ? LayoutType::NCHW_TO_NCHW32
: LayoutType::NCHW_TO_NCHW4; : LayoutType::NCHW_TO_NCHW4;
...@@ -5201,6 +5303,15 @@ EnableNCHW64Pass::make_nchw64_converter() { ...@@ -5201,6 +5303,15 @@ EnableNCHW64Pass::make_nchw64_converter() {
.node(); .node();
break; break;
} }
case Format::NHWC: {
size_t in_channels = new_inp[0]->shape()[3];
use_nchw32 = in_channels % 32 == 0;
auto layout_type = use_nchw32 ? LayoutType::NHWC_TO_NCHW32
: LayoutType::NHWC_TO_NCHW4;
inps[0] = RelayoutPlaceholder::make(inps[0], layout_type)
.node();
break;
}
case Format::NCHW64: case Format::NCHW64:
inps[0] = RelayoutPlaceholder::make( inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType:: inps[0], RelayoutPlaceholder::LayoutType::
...@@ -5253,6 +5364,13 @@ EnableNCHW64Pass::make_nchw64_converter() { ...@@ -5253,6 +5364,13 @@ EnableNCHW64Pass::make_nchw64_converter() {
auto fmt = iter != format_map.end()?iter->second:Format::NCHW; auto fmt = iter != format_map.end()?iter->second:Format::NCHW;
if (iter != format_map.end()) { if (iter != format_map.end()) {
switch (fmt) { switch (fmt) {
case Format::NHWC:
inps[i] = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::
NHWC_TO_NCHW)
.node();
break;
case Format::NCHW4: case Format::NCHW4:
inps[i] = RelayoutPlaceholder::make( inps[i] = RelayoutPlaceholder::make(
inps[i], inps[i],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册