diff --git a/dnn/src/arm_common/pooling/algo.cpp b/dnn/src/arm_common/pooling/algo.cpp index 286c32f5eb9dabe2be4140609b2c7116811715fb..980ee905484f60c62c3a72c579d241ebdd12a770 100644 --- a/dnn/src/arm_common/pooling/algo.cpp +++ b/dnn/src/arm_common/pooling/algo.cpp @@ -93,6 +93,7 @@ const int8_t* handle_padding(const int8_t* src, size_t IH, size_t IW, } return need_pad ? sptr_base : src; } + bool PoolingImpl::AlgoFilterxModexStride1::usable( const PoolingKernSizeParam& param) const { auto SH = param.stride[0]; @@ -104,7 +105,8 @@ bool PoolingImpl::AlgoFilterxModexStride1::usable( param.src_type.category() == DTypeCategory::QUANTIZED) && param.format == Param::Format::NCHW && SH == 1 && SW == 1 && FH == FW && (FH == 2 || FH == 3); - return avaible; + bool is_mode_ok = (param.mode == Mode::MAX || param.mode == Mode::AVERAGE); + return avaible && is_mode_ok; } void PoolingImpl::AlgoFilterxModexStride1::exec( @@ -202,7 +204,8 @@ bool PoolingImpl::AlgoFilter2ModexStride2::usable( param.src_type.category() == DTypeCategory::QUANTIZED) && param.format == Param::Format::NCHW && FH == FW && SH == SW && FH == 2 && SH == 2; - return avaible; + bool is_mode_ok = (param.mode == Mode::MAX || param.mode == Mode::AVERAGE); + return avaible && is_mode_ok; } void PoolingImpl::AlgoFilter2ModexStride2::exec( diff --git a/dnn/test/arm_common/pooling.cpp b/dnn/test/arm_common/pooling.cpp index 430264f09cc11202007845fe24205a995601748f..c6f489371e235552b264fb3c2d7cc467a8ef1421 100644 --- a/dnn/test/arm_common/pooling.cpp +++ b/dnn/test/arm_common/pooling.cpp @@ -53,6 +53,18 @@ TEST_F(ARM_COMMON, POOLING) if (ih + p * 2 >= 5 && iw + p * 2 >= 5) checker.set_param(param).exec({{2, 3, ih, iw}, {}}); } + for (size_t ih: {2, 3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t iw: {2, 3, 5, 7, 11, 13, 17, 19, 23, 24, 25, 26, 27, 28, 29, 30}) + for (size_t p: {1, 2}) + { + Param param; + param.mode = Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING; + param.window_h = param.window_w = 3; + param.stride_h = param.stride_w = 1; + param.pad_h = param.pad_w = p; + Checker checker(handle()); + checker.set_param(param).exec({{2, 3, ih, iw}, {}}); + } // clang-format on }