Commit f56f187f — Author: Megvii Engine Team · Committer: Xu Xinran

fix(mbg/gopt): fix nchw44-dot channel wise trans to nchw44

GitOrigin-RevId: aa2059a79601822131a29f8e23f6808847ca62c3
Parent commit: af29fcb2
......@@ -168,10 +168,10 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
for (auto nlmode : nonlinemode)
for (size_t n : {1, 2})
for (size_t kernel : kernel_vec)
for (size_t oc : {4, 12, 32})
for (size_t oc : {4, 32})
for (size_t ic : {1, 3, 4, 12, 32})
for (size_t h : {3, 5, 12})
for (size_t w : {7, 16, 23}) {
for (size_t w : {7, 23}) {
for (size_t group = 1;
group <= std::min(oc, ic); ++group) {
pack(n, oc, ic, h, w, kernel, stride,
......@@ -350,13 +350,18 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_SMALL_GROUP) {
get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
handle(), "F32DIRECT_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_1) {
check_conv_bias(get_nchw44_conv_bias_args({2, 7}, 1, false, false, false,
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) {
    //! Direct fp32 conv-bias in NCHW44 layout, stride 1, 7x7 kernel only.
    auto args = get_nchw44_conv_bias_args({7}, 1, false, false, false, false,
                                          false, false);
    check_conv_bias(args, handle(), "F32_CONV_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) {
    //! Direct fp32 conv-bias in NCHW44 layout, stride 1, 2x2 and 3x3 kernels.
    auto args = get_nchw44_conv_bias_args({2, 3}, 1, false, false, false,
                                          false, true, true);
    check_conv_bias(args, handle(), "F32_CONV_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_2) {
check_conv_bias(get_nchw44_conv_bias_args({3, 5}, 1, false, false, false,
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) {
    //! Direct fp32 conv-bias in NCHW44 layout, stride 1, 5x5 kernel only.
    auto args = get_nchw44_conv_bias_args({5}, 1, false, false, false, false,
                                          true, true);
    check_conv_bias(args, handle(), "F32_CONV_NCHW44_DIRECT");
}
......@@ -388,11 +393,15 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32) {
false, true),
handle(), "F32_CONV_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_1) {
check_conv_bias(
get_nchw44_channel_wise_args({2, 3, 5}, 1, false, false, false),
get_nchw44_channel_wise_args({2, 3}, 1, false, false, false),
handle(), "F32_CHANNEL_WISE_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE1_FP32_NCHW44_2) {
    //! Channel-wise fp32 conv-bias in NCHW44 layout, stride 1, 5x5 kernel.
    auto args = get_nchw44_channel_wise_args({5}, 1, false, false, false);
    check_conv_bias(args, handle(), "F32_CHANNEL_WISE_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CHANNEL_WISE_STRIDE2_FP32_NCHW44) {
check_conv_bias(
......
......@@ -2050,23 +2050,27 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
//! First is whether the conv can trans to nchwxx, second is the filter
//! trans mode
using RelayoutMode = RelayoutPlaceholder::LayoutType;
using TestTransResult = std::pair<TransType, RelayoutMode>;
megdnn::param::ConvolutionV0::Format conv_dot_format =
megdnn::param::ConvBias::Format::NCHW44_DOT;
//! Result of probing whether a conv/conv-bias can be translated to an
//! NCHWxx layout: the translation kind, the weight relayout mode to apply,
//! and the conv format the new operator should use (NCHW44 for the
//! channel-wise case, NCHW44_DOT otherwise — see the assignments below).
struct TestTransResult {
    TransType trans_type;  //!< TRANS_NONE / TRANS_PURE_NCHWXX / TRANS_HYBIRD_NCHWXX
    RelayoutMode relayout_mod;  //!< mode passed to RelayoutPlaceholder::make for the filter
    megdnn::param::ConvolutionV0::Format conv_format;  //!< format written into the new conv param
};
constexpr size_t pack_c_size = 4_z;
auto test_trans_nchw44_dot =
[](const megdnn::param::Convolution::Sparse conv_mode,
const VarNode* filter) -> TestTransResult {
TestTransResult ret{TransType::TRANS_NONE, {}};
TestTransResult ret{TransType::TRANS_NONE, {}, {}};
if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
size_t IC = filter->shape()[1];
size_t OC = filter->shape()[0];
if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
ret.first = TransType::TRANS_PURE_NCHWXX;
ret.second = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
ret.trans_type = TransType::TRANS_PURE_NCHWXX;
ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
} else if (IC < pack_c_size && OC % pack_c_size == 0) {
ret.first = TransType::TRANS_HYBIRD_NCHWXX;
ret.second = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
ret.trans_type = TransType::TRANS_HYBIRD_NCHWXX;
ret.relayout_mod = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
}
} else {
mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
......@@ -2074,15 +2078,18 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
size_t ocpg = filter->shape()[1];
size_t icpg = filter->shape()[2];
if (icpg == 1 && ocpg == 1 && (group % pack_c_size == 0)) {
ret.first = TransType::TRANS_NONE;
ret.trans_type = TransType::TRANS_PURE_NCHWXX;
ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_CHAN;
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44;
} else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) {
ret.first = TransType::TRANS_PURE_NCHWXX;
ret.second = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP;
ret.trans_type = TransType::TRANS_PURE_NCHWXX;
ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP;
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
}
}
return ret;
};
auto replace_conv_opr = [test_trans_nchw44_dot, conv_dot_format](
auto replace_conv_opr = [test_trans_nchw44_dot](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
......@@ -2094,7 +2101,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
auto is_trans =
test_trans_nchw44_dot(conv_opr.param().sparse, new_inp[1]);
//! can not trans to nchwxx
if (is_trans.first == TransType::TRANS_NONE) {
if (is_trans.trans_type == TransType::TRANS_NONE) {
mgb_assert(new_inp[1]->shape().ndim == 4 ||
new_inp[1]->shape().ndim == 5,
"The origin filter is not NCHW mode");
......@@ -2108,14 +2115,14 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
opr->config());
return new_opr;
} else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
} else if (is_trans.trans_type == TransType::TRANS_PURE_NCHWXX) {
//! filter trans to nchwxx mode
mgb_assert(new_inp[1]->shape().ndim == 4 ||
new_inp[1]->shape().ndim == 5,
"The origin filter is not NCHW mode");
VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
auto new_filter =
RelayoutPlaceholder::make(new_inp[1], is_trans.second);
auto new_filter = RelayoutPlaceholder::make(new_inp[1],
is_trans.relayout_mod);
conv_filter = new_filter.node();
//! src trans to nchwxx mode
if (new_inp[0]->shape().ndim != 5) {
......@@ -2125,7 +2132,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
conv_src = new_src.node();
}
auto new_param = conv_opr.param();
new_param.format = conv_dot_format;
new_param.format = is_trans.conv_format;
mgb_assert(conv_src->shape().ndim == 5 &&
conv_filter->shape().ndim >= 6,
"The conv src dim is not trans to nchwxx");
......@@ -2137,16 +2144,16 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
"The conv dst dim is not trans to nchwxx");
return new_opr;
} else {
mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
mgb_assert(is_trans.trans_type == TransType::TRANS_HYBIRD_NCHWXX);
VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
auto new_filter =
RelayoutPlaceholder::make(new_inp[1], is_trans.second);
auto new_filter = RelayoutPlaceholder::make(new_inp[1],
is_trans.relayout_mod);
conv_filter = new_filter.node();
mgb_assert(conv_src->shape().ndim == 4 &&
conv_filter->shape().ndim == 5,
"The src and filter is OK");
auto new_param = conv_opr.param();
new_param.format = conv_dot_format;
new_param.format = is_trans.conv_format;
auto new_conv_opr = opr::Convolution::make(
conv_src, conv_filter, new_param,
conv_opr.execution_policy(), conv_opr.config());
......@@ -2157,7 +2164,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
}
};
auto replace_conv_bias_opr = [test_trans_nchw44_dot, conv_dot_format](
auto replace_conv_bias_opr = [test_trans_nchw44_dot](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
......@@ -2168,7 +2175,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
auto is_trans =
test_trans_nchw44_dot(conv_bias_opr.param().sparse, new_inp[1]);
//! can not trans to nchwxx
if (is_trans.first == TransType::TRANS_NONE) {
if (is_trans.trans_type == TransType::TRANS_NONE) {
mgb_assert(new_inp[1]->shape().ndim == 4 ||
new_inp[1]->shape().ndim == 5,
"The origin filter is not NCHW mode");
......@@ -2188,15 +2195,15 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
opr->config());
return new_opr;
} else if (is_trans.first == TransType::TRANS_PURE_NCHWXX) {
} else if (is_trans.trans_type == TransType::TRANS_PURE_NCHWXX) {
VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
*conv_bias_bias = new_inp[2];
//! filter trans to nchwxx mode
mgb_assert(new_inp[1]->shape().ndim == 4 ||
new_inp[1]->shape().ndim == 5,
"The origin filter is not NCHW mode");
auto new_filter =
RelayoutPlaceholder::make(new_inp[1], is_trans.second);
auto new_filter = RelayoutPlaceholder::make(new_inp[1],
is_trans.relayout_mod);
conv_bias_filter = new_filter.node();
//! src trans to nchwxx mode
if (new_inp[0]->shape().ndim != 5) {
......@@ -2213,7 +2220,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
}
auto new_param = conv_bias_opr.param();
new_param.format = conv_dot_format;
new_param.format = is_trans.conv_format;
mgb_assert(conv_bias_src->shape().ndim == 5 &&
conv_bias_filter->shape().ndim >= 6,
"The conv_bias src dim is not trans to nchwxx");
......@@ -2225,11 +2232,11 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
"The conv_bias dst dim is not trans to nchwxx");
return new_opr;
} else {
mgb_assert(is_trans.first == TransType::TRANS_HYBIRD_NCHWXX);
mgb_assert(is_trans.trans_type == TransType::TRANS_HYBIRD_NCHWXX);
VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1],
*conv_bias_bias = new_inp[2];
auto new_filter =
RelayoutPlaceholder::make(new_inp[1], is_trans.second);
auto new_filter = RelayoutPlaceholder::make(new_inp[1],
is_trans.relayout_mod);
conv_bias_filter = new_filter.node();
//! bias trans to nchwxx mode, bias may be scale
if (new_inp[2]->shape().ndim == 4) {
......@@ -2242,7 +2249,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
mgb_assert((conv_bias_bias->shape().ndim == 5) ||
conv_bias_bias->shape().is_scalar());
auto new_param = conv_bias_opr.param();
new_param.format = conv_dot_format;
new_param.format = is_trans.conv_format;
auto new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
......
......@@ -2694,7 +2694,7 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT,
find_opr<opr::Convolution>(y_opt).param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
find_opr<opr::ConvBias>(y_opt).param().format);
graph->compile({{y_opt, {}}})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.