提交 f56f187f 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(mbg/gopt): fix nchw44-dot channel wise trans to nchw44

GitOrigin-RevId: aa2059a79601822131a29f8e23f6808847ca62c3
上级 af29fcb2
...@@ -168,10 +168,10 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( ...@@ -168,10 +168,10 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
for (auto nlmode : nonlinemode) for (auto nlmode : nonlinemode)
for (size_t n : {1, 2}) for (size_t n : {1, 2})
for (size_t kernel : kernel_vec) for (size_t kernel : kernel_vec)
for (size_t oc : {4, 12, 32}) for (size_t oc : {4, 32})
for (size_t ic : {1, 3, 4, 12, 32}) for (size_t ic : {1, 3, 4, 12, 32})
for (size_t h : {3, 5, 12}) for (size_t h : {3, 5, 12})
for (size_t w : {7, 16, 23}) { for (size_t w : {7, 23}) {
for (size_t group = 1; for (size_t group = 1;
group <= std::min(oc, ic); ++group) { group <= std::min(oc, ic); ++group) {
pack(n, oc, ic, h, w, kernel, stride, pack(n, oc, ic, h, w, kernel, stride,
...@@ -350,13 +350,18 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_SMALL_GROUP) { ...@@ -350,13 +350,18 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_SMALL_GROUP) {
get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false), get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
handle(), "F32DIRECT_SMALL_GROUP"); handle(), "F32DIRECT_SMALL_GROUP");
} }
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_1) { TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) {
check_conv_bias(get_nchw44_conv_bias_args({2, 7}, 1, false, false, false, check_conv_bias(get_nchw44_conv_bias_args({7}, 1, false, false, false,
false, false, false),
handle(), "F32_CONV_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) {
check_conv_bias(get_nchw44_conv_bias_args({2, 3}, 1, false, false, false,
false, true, true), false, true, true),
handle(), "F32_CONV_NCHW44_DIRECT"); handle(), "F32_CONV_NCHW44_DIRECT");
} }
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_2) { TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) {
check_conv_bias(get_nchw44_conv_bias_args({3, 5}, 1, false, false, false, check_conv_bias(get_nchw44_conv_bias_args({5}, 1, false, false, false,
false, true, true), false, true, true),
handle(), "F32_CONV_NCHW44_DIRECT"); handle(), "F32_CONV_NCHW44_DIRECT");
} }
...@@ -388,9 +393,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32) { ...@@ -388,9 +393,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32) {
false, true), false, true),
handle(), "F32_CONV_NCHW_NCHW44"); handle(), "F32_CONV_NCHW_NCHW44");
} }
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44) { TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_1) {
check_conv_bias( check_conv_bias(
get_nchw44_channel_wise_args({2, 3, 5}, 1, false, false, false), get_nchw44_channel_wise_args({2, 3}, 1, false, false, false),
handle(), "F32_CHANNEL_WISE_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_2) {
check_conv_bias(get_nchw44_channel_wise_args({5}, 1, false, false, false),
handle(), "F32_CHANNEL_WISE_NCHW44"); handle(), "F32_CHANNEL_WISE_NCHW44");
} }
......
...@@ -2050,23 +2050,27 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2050,23 +2050,27 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
//! First is whether the conv can trans to nchwxx, second is the filter //! First is whether the conv can trans to nchwxx, second is the filter
//! trans mode //! trans mode
using RelayoutMode = RelayoutPlaceholder::LayoutType; using RelayoutMode = RelayoutPlaceholder::LayoutType;
using TestTransResult = std::pair<TransType, RelayoutMode>; struct TestTransResult {
megdnn::param::ConvolutionV0::Format conv_dot_format = TransType trans_type;
megdnn::param::ConvBias::Format::NCHW44_DOT; RelayoutMode relayout_mod;
megdnn::param::ConvolutionV0::Format conv_format;
};
constexpr size_t pack_c_size = 4_z; constexpr size_t pack_c_size = 4_z;
auto test_trans_nchw44_dot = auto test_trans_nchw44_dot =
[](const megdnn::param::Convolution::Sparse conv_mode, [](const megdnn::param::Convolution::Sparse conv_mode,
const VarNode* filter) -> TestTransResult { const VarNode* filter) -> TestTransResult {
TestTransResult ret{TransType::TRANS_NONE, {}}; TestTransResult ret{TransType::TRANS_NONE, {}, {}};
if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
size_t IC = filter->shape()[1]; size_t IC = filter->shape()[1];
size_t OC = filter->shape()[0]; size_t OC = filter->shape()[0];
if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) { if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
ret.first = TransType::TRANS_PURE_NCHWXX; ret.trans_type = TransType::TRANS_PURE_NCHWXX;
ret.second = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE; ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
} else if (IC < pack_c_size && OC % pack_c_size == 0) { } else if (IC < pack_c_size && OC % pack_c_size == 0) {
ret.first = TransType::TRANS_HYBIRD_NCHWXX; ret.trans_type = TransType::TRANS_HYBIRD_NCHWXX;
ret.second = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44; ret.relayout_mod = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
} }
} else { } else {
mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP); mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
...@@ -2074,15 +2078,18 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2074,15 +2078,18 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
size_t ocpg = filter->shape()[1]; size_t ocpg = filter->shape()[1];
size_t icpg = filter->shape()[2]; size_t icpg = filter->shape()[2];
if (icpg == 1 && ocpg == 1 && (group % pack_c_size == 0)) { if (icpg == 1 && ocpg == 1 && (group % pack_c_size == 0)) {
ret.first = TransType::TRANS_NONE; ret.trans_type = TransType::TRANS_PURE_NCHWXX;
ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_CHAN;
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44;
} else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) { } else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) {
ret.first = TransType::TRANS_PURE_NCHWXX; ret.trans_type = TransType::TRANS_PURE_NCHWXX;
ret.second = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP; ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP;
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
} }
} }
return ret; return ret;
}; };
auto replace_conv_opr = [test_trans_nchw44_dot, conv_dot_format]( auto replace_conv_opr = [test_trans_nchw44_dot](
OperatorNodeBase* opr, OperatorNodeBase* opr,
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
...@@ -2094,7 +2101,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2094,7 +2101,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
auto is_trans = auto is_trans =
test_trans_nchw44_dot(conv_opr.param().sparse, new_inp[1]); test_trans_nchw44_dot(conv_opr.param().sparse, new_inp[1]);
//! can not trans to nchwxx //! can not trans to nchwxx
if (is_trans.first == TransType::TRANS_NONE) { if (is_trans.trans_type == TransType::TRANS_NONE) {
mgb_assert(new_inp[1]->shape().ndim == 4 || mgb_assert(new_inp[1]->shape().ndim == 4 ||
new_inp[1]->shape().ndim == 5, new_inp[1]->shape().ndim == 5,
"The origin filter is not NCHW mode"); "The origin filter is not NCHW mode");
...@@ -2108,14 +2115,14 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2108,14 +2115,14 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp, auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
opr->config()); opr->config());
return new_opr; return new_opr;
} else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) { } else if (is_trans.trans_type == TransType::TRANS_PURE_NCHWXX) {
//! filter trans to nchwxx mode //! filter trans to nchwxx mode
mgb_assert(new_inp[1]->shape().ndim == 4 || mgb_assert(new_inp[1]->shape().ndim == 4 ||
new_inp[1]->shape().ndim == 5, new_inp[1]->shape().ndim == 5,
"The origin filter is not NCHW mode"); "The origin filter is not NCHW mode");
VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1]; VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
auto new_filter = auto new_filter = RelayoutPlaceholder::make(new_inp[1],
RelayoutPlaceholder::make(new_inp[1], is_trans.second); is_trans.relayout_mod);
conv_filter = new_filter.node(); conv_filter = new_filter.node();
//! src trans to nchwxx mode //! src trans to nchwxx mode
if (new_inp[0]->shape().ndim != 5) { if (new_inp[0]->shape().ndim != 5) {
...@@ -2125,7 +2132,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2125,7 +2132,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
conv_src = new_src.node(); conv_src = new_src.node();
} }
auto new_param = conv_opr.param(); auto new_param = conv_opr.param();
new_param.format = conv_dot_format; new_param.format = is_trans.conv_format;
mgb_assert(conv_src->shape().ndim == 5 && mgb_assert(conv_src->shape().ndim == 5 &&
conv_filter->shape().ndim >= 6, conv_filter->shape().ndim >= 6,
"The conv src dim is not trans to nchwxx"); "The conv src dim is not trans to nchwxx");
...@@ -2137,16 +2144,16 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2137,16 +2144,16 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
"The conv dst dim is not trans to nchwxx"); "The conv dst dim is not trans to nchwxx");
return new_opr; return new_opr;
} else { } else {
mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX); mgb_assert(is_trans.trans_type == TransType::TRANS_HYBIRD_NCHWXX);
VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1]; VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
auto new_filter = auto new_filter = RelayoutPlaceholder::make(new_inp[1],
RelayoutPlaceholder::make(new_inp[1], is_trans.second); is_trans.relayout_mod);
conv_filter = new_filter.node(); conv_filter = new_filter.node();
mgb_assert(conv_src->shape().ndim == 4 && mgb_assert(conv_src->shape().ndim == 4 &&
conv_filter->shape().ndim == 5, conv_filter->shape().ndim == 5,
"The src and filter is OK"); "The src and filter is OK");
auto new_param = conv_opr.param(); auto new_param = conv_opr.param();
new_param.format = conv_dot_format; new_param.format = is_trans.conv_format;
auto new_conv_opr = opr::Convolution::make( auto new_conv_opr = opr::Convolution::make(
conv_src, conv_filter, new_param, conv_src, conv_filter, new_param,
conv_opr.execution_policy(), conv_opr.config()); conv_opr.execution_policy(), conv_opr.config());
...@@ -2157,7 +2164,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2157,7 +2164,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
} }
}; };
auto replace_conv_bias_opr = [test_trans_nchw44_dot, conv_dot_format]( auto replace_conv_bias_opr = [test_trans_nchw44_dot](
OperatorNodeBase* opr, OperatorNodeBase* opr,
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
...@@ -2168,7 +2175,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2168,7 +2175,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
auto is_trans = auto is_trans =
test_trans_nchw44_dot(conv_bias_opr.param().sparse, new_inp[1]); test_trans_nchw44_dot(conv_bias_opr.param().sparse, new_inp[1]);
//! can not trans to nchwxx //! can not trans to nchwxx
if (is_trans.first == TransType::TRANS_NONE) { if (is_trans.trans_type == TransType::TRANS_NONE) {
mgb_assert(new_inp[1]->shape().ndim == 4 || mgb_assert(new_inp[1]->shape().ndim == 4 ||
new_inp[1]->shape().ndim == 5, new_inp[1]->shape().ndim == 5,
"The origin filter is not NCHW mode"); "The origin filter is not NCHW mode");
...@@ -2188,15 +2195,15 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2188,15 +2195,15 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp, auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
opr->config()); opr->config());
return new_opr; return new_opr;
} else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) { } else if (is_trans.trans_type == TransType::TRANS_PURE_NCHWXX) {
VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1], VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
*conv_bias_bias = new_inp[2]; *conv_bias_bias = new_inp[2];
//! filter trans to nchwxx mode //! filter trans to nchwxx mode
mgb_assert(new_inp[1]->shape().ndim == 4 || mgb_assert(new_inp[1]->shape().ndim == 4 ||
new_inp[1]->shape().ndim == 5, new_inp[1]->shape().ndim == 5,
"The origin filter is not NCHW mode"); "The origin filter is not NCHW mode");
auto new_filter = auto new_filter = RelayoutPlaceholder::make(new_inp[1],
RelayoutPlaceholder::make(new_inp[1], is_trans.second); is_trans.relayout_mod);
conv_bias_filter = new_filter.node(); conv_bias_filter = new_filter.node();
//! src trans to nchwxx mode //! src trans to nchwxx mode
if (new_inp[0]->shape().ndim != 5) { if (new_inp[0]->shape().ndim != 5) {
...@@ -2213,7 +2220,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2213,7 +2220,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
} }
auto new_param = conv_bias_opr.param(); auto new_param = conv_bias_opr.param();
new_param.format = conv_dot_format; new_param.format = is_trans.conv_format;
mgb_assert(conv_bias_src->shape().ndim == 5 && mgb_assert(conv_bias_src->shape().ndim == 5 &&
conv_bias_filter->shape().ndim >= 6, conv_bias_filter->shape().ndim >= 6,
"The conv_bias src dim is not trans to nchwxx"); "The conv_bias src dim is not trans to nchwxx");
...@@ -2225,11 +2232,11 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2225,11 +2232,11 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
"The conv_bias dst dim is not trans to nchwxx"); "The conv_bias dst dim is not trans to nchwxx");
return new_opr; return new_opr;
} else { } else {
mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX); mgb_assert(is_trans.trans_type == TransType::TRANS_HYBIRD_NCHWXX);
VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1], VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
*conv_bias_bias = new_inp[2]; *conv_bias_bias = new_inp[2];
auto new_filter = auto new_filter = RelayoutPlaceholder::make(new_inp[1],
RelayoutPlaceholder::make(new_inp[1], is_trans.second); is_trans.relayout_mod);
conv_bias_filter = new_filter.node(); conv_bias_filter = new_filter.node();
//! bias trans to nchwxx mode, bias may be scale //! bias trans to nchwxx mode, bias may be scale
if (new_inp[2]->shape().ndim == 4) { if (new_inp[2]->shape().ndim == 4) {
...@@ -2242,7 +2249,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2242,7 +2249,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
mgb_assert((conv_bias_bias->shape().ndim == 5) || mgb_assert((conv_bias_bias->shape().ndim == 5) ||
conv_bias_bias->shape().is_scalar()); conv_bias_bias->shape().is_scalar());
auto new_param = conv_bias_opr.param(); auto new_param = conv_bias_opr.param();
new_param.format = conv_dot_format; new_param.format = is_trans.conv_format;
auto new_conv_bias_opr = opr::ConvBias::make( auto new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, conv_bias_bias, new_param, conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config()); conv_bias_opr.execution_policy(), conv_bias_opr.config());
......
...@@ -2694,7 +2694,7 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) { ...@@ -2694,7 +2694,7 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT, ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT,
find_opr<opr::Convolution>(y_opt).param().format); find_opr<opr::Convolution>(y_opt).param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW, ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
find_opr<opr::ConvBias>(y_opt).param().format); find_opr<opr::ConvBias>(y_opt).param().format);
graph->compile({{y_opt, {}}}) graph->compile({{y_opt, {}}})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册