提交 02abc36e 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(mbg/arm_common): fix nchw44-dot misc issue

GitOrigin-RevId: f870ad964c075fe55fb1c7ca131680873bec61eb
上级 9ed3882a
......@@ -182,7 +182,7 @@ bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable(
bool ok_type = ((param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.filter_type.enumv() == DTypeEnum::QuantizedS8 &&
(param.dst_type.enumv() == DTypeEnum::QuantizedS8))) &&
(fm.format == param::Convolution::Format::NCHW44);
(fm.format == param::Convolution::Format::NCHW44_DOT);
bool ok_src_dst = (oc % 4 == 0 && oc >= 4 && ic < 4);
bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
(fh == 2 || fh == 3 || fh == 5 || fh == 7);
......
......@@ -55,7 +55,6 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoS8ChanWiseStride2NCHW44 s8_channel_wise_stride2_nchw44;
#if __ARM_FEATURE_DOTPROD
AlgoDotS8DirectNCHWNCHW44 ds8_direct_stride2_nchw_nchw44;
AlgoDotS8DirectStride1 ds8_direct_stride1_large_group{true};
AlgoDotS8DirectStride1 ds8_direct_stride1_small_group{false};
AlgoDotS8DirectStride2 ds8_direct_stride2_large_group{true};
......@@ -66,6 +65,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoDotU8DirectStride2 du8_direct_stride2_small_group{false};
AlgoDotS8Direct_NCHW44 ds8_direct_nchw44;
AlgoDotS8DirectNCHWNCHW44 ds8_direct_nchw_nchw44;
#endif
AlgoF32DirectNCHWNCHW44 f32_direct_stride2_nchw_nchw44;
......@@ -96,7 +96,6 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
public:
AlgoPack() {
#if __ARM_FEATURE_DOTPROD
direct_algos.emplace_back(&ds8_direct_stride2_nchw_nchw44);
direct_algos.emplace_back(&ds8_direct_stride1_large_group);
direct_algos.emplace_back(&ds8_direct_stride1_small_group);
direct_algos.emplace_back(&ds8_direct_stride2_large_group);
......@@ -107,6 +106,7 @@ public:
direct_algos.emplace_back(&du8_direct_stride2_small_group);
direct_algos.emplace_back(&ds8_direct_nchw44);
direct_algos.emplace_back(&ds8_direct_nchw_nchw44);
#endif
direct_algos.emplace_back(&qu8_direct_stride2_large_group);
direct_algos.emplace_back(&qu8_direct_stride2_small_group);
......
......@@ -582,14 +582,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) {
/****************************dot qint8 direct*************************/
#if __ARM_FEATURE_DOTPROD
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
checker_conv_bias_qint8x8x8(
get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false,
true),
handle(), "ARMDOTS8_NCHW_NCHW44");
checker_conv_bias_qint8x8x8(
get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false,
true),
handle(), "ARMDOTS8_NCHW_NCHW44");
auto args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, false,
true);
for (auto&& arg : args) {
arg.param.format = param::ConvBias::Format::NCHW44_DOT;
}
checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
args = get_nchw44_conv_bias_args({2, 3, 5, 7}, 1, false, false, false,
true);
for (auto&& arg : args) {
arg.param.format = param::ConvBias::Format::NCHW44_DOT;
}
checker_conv_bias_qint8x8x8(args, handle(), "ARMDOTS8_NCHW_NCHW44");
}
TEST_F(ARM_COMMON_MULTI_THREADS,
CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP) {
......
......@@ -987,6 +987,12 @@ Args Args::from_argv(int argc, char **argv) {
cb(nchw32);
cb(nhwcd4);
#undef cb
if (!strcmp(argv[i], "--enable-nchw44-dot")) {
mgb_log_warn("enable-nchw44-dot optimization");
graph_opt.graph_opt.enable_nchw44_dot();
continue;
}
if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) {
mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization");
graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity();
......
......@@ -94,7 +94,6 @@ namespace mgb {
m_param{param}, m_param_size{param_size}
{
}
//! build a blob representation to be used as cache key
PersistentCache::Blob build_blob() const;
};
......
......@@ -611,7 +611,6 @@ AlgoChooserProfileCache::Result AlgoChooser<Opr>::get_profile_result(
AlgoChooserProfileCache::Key cache_key{origin_layouts.data(),
origin_layouts.size(), &origin_param,
sizeof(origin_param)};
{
auto&& rst = cache.get(cache_key);
if (rst.valid())
......
......@@ -107,7 +107,8 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
src_shape[1] / group * 2;
return hybird_nchwx ? computation : computation * 8;
}
if (param.format == Param::Format::NCHW44) {
if (param.format == Param::Format::NCHW44 ||
param.format == Param::Format::NCHW44_DOT) {
//! if channel wise weight layout is {group/4, FH, FW, 1, 1, 4}
if (filter_shape[1] == 1 && filter_shape[2] == 1) {
group *= 4;
......@@ -145,6 +146,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
if (param.format == Param::Format::NCHW4 ||
param.format == Param::Format::NCHW88 ||
param.format == Param::Format::NCHW44 ||
param.format == Param::Format::NCHW44_DOT ||
param.format == Param::Format::NCHW32) {
return eval_conv_computation_nchwx();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册