diff --git a/dnn/src/arm_common/conv_bias/intrinsic_helper.h b/dnn/src/arm_common/conv_bias/intrinsic_helper.h index a43d47cb9c9d0cf84eae215e7879b5748dbbed0d..5909f7ad484ed51c142196afceb62799d4afbda9 100644 --- a/dnn/src/arm_common/conv_bias/intrinsic_helper.h +++ b/dnn/src/arm_common/conv_bias/intrinsic_helper.h @@ -1,4 +1,3 @@ -#pragma once /** * \file dnn/src/arm_common/conv_bias/intrinsic_helper.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") @@ -10,7 +9,9 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#include "src/arm_common/conv_bias/neon_struct.h" +#pragma once +#include "src/arm_common/intrinsic_helper.h" +#include "src/arm_common/neon_struct.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" #include "src/fallback/conv_bias/common.h" @@ -689,185 +690,8 @@ inline void init_ocx_ow4(T& c, const int32_t* bias_ptr, int oc_step) { InitOcxOw4::impl(c, bias_ptr, oc_step); } /////////////////////////////////////// -template -struct LoadHelper { - static void impl(T& weight, T2 ptr, int oc_offset, XT... args); -}; - -#define WEIGHT_CB(step) \ - src[step] = Func::impl(ptr + base_offset + step * ptr_step, args...); - -template -struct LoadHelper<1, base_offset, ptr_step, 0, Func, T, T2, XT...> { - static void impl(T& src, T2 ptr, int, XT... args) { - UNROLL_CALL_RAW(1, WEIGHT_CB); - } -}; -template -struct LoadHelper<2, base_offset, ptr_step, 0, Func, T, T2, XT...> { - static void impl(T& src, T2 ptr, int, XT... args) { - UNROLL_CALL_RAW(2, WEIGHT_CB); - } -}; - -template -struct LoadHelper<3, base_offset, ptr_step, 0, Func, T, T2, XT...> { - static void impl(T& src, T2 ptr, int, XT... args) { - UNROLL_CALL_RAW(3, WEIGHT_CB); - } -}; -template -struct LoadHelper<4, base_offset, ptr_step, 0, Func, T, T2, XT...> { - static void impl(T& src, T2 ptr, int, XT... args) { - UNROLL_CALL_RAW(4, WEIGHT_CB); - } -}; -template -struct LoadHelper<5, base_offset, ptr_step, 0, Func, T, T2, XT...> { - static void impl(T& src, T2 ptr, int, XT... args) { - UNROLL_CALL_RAW(5, WEIGHT_CB); - } -}; -template -struct LoadHelper<6, base_offset, ptr_step, 0, Func, T, T2, XT...> { - static void impl(T& src, T2 ptr, int, XT... args) { - UNROLL_CALL_RAW(6, WEIGHT_CB); - } -}; -template -struct LoadHelper<7, base_offset, ptr_step, 0, Func, T, T2, XT...> { - static void impl(T& src, T2 ptr, int, XT... args) { - UNROLL_CALL_RAW(7, WEIGHT_CB); - } -}; -template -struct LoadHelper<8, base_offset, ptr_step, 0, Func, T, T2, XT...> { - static void impl(T& src, T2 ptr, int, XT... args) { - UNROLL_CALL_RAW(8, WEIGHT_CB); - } -}; -#undef WEIGHT_CB - -#define WEIGHT_CB(step) \ - src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); -template -struct LoadHelper<1, base_offset, ptr_step, 1, Func, T, T2> { - static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(1, WEIGHT_CB); } -}; -template -struct LoadHelper<2, base_offset, ptr_step, 1, Func, T, T2> { - static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(2, WEIGHT_CB); } -}; - -template -struct LoadHelper<3, base_offset, ptr_step, 1, Func, T, T2> { - static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(3, WEIGHT_CB); } -}; -template -struct LoadHelper<4, base_offset, ptr_step, 1, Func, T, T2> { - static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(4, WEIGHT_CB); } -}; - -template -struct LoadHelper<5, base_offset, ptr_step, 1, Func, T, T2> { - static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(5, WEIGHT_CB); } -}; -template -struct LoadHelper<6, base_offset, ptr_step, 1, Func, T, T2> { - static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(6, WEIGHT_CB); } -}; - -template -struct LoadHelper<7, base_offset, ptr_step, 1, Func, T, T2> { - static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(7, WEIGHT_CB); } -}; - -template -struct LoadHelper<8, base_offset, ptr_step, 1, Func, T, T2> { - static void impl(T& src, T2 ptr, int) { UNROLL_CALL_RAW(8, WEIGHT_CB); } -}; - -#undef WEIGHT_CB - -#define WEIGHT_CB(step) \ - src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); \ - src[1][step] = Func::impl(ptr + base_offset + step * ptr_step + oc_offset); - -template -struct LoadHelper<1, base_offset, ptr_step, 2, Func, T, T2> { - static void impl(T& src, T2 ptr, int oc_offset) { - UNROLL_CALL_RAW(1, WEIGHT_CB); - } -}; -template -struct LoadHelper<2, base_offset, ptr_step, 2, Func, T, T2> { - static void impl(T& src, T2 ptr, int oc_offset) { - UNROLL_CALL_RAW(2, WEIGHT_CB); - } -}; - -template -struct LoadHelper<3, base_offset, ptr_step, 2, Func, T, T2> { - static void impl(T& src, T2 ptr, int oc_offset) { - UNROLL_CALL_RAW(3, WEIGHT_CB); - } -}; -template -struct LoadHelper<4, base_offset, ptr_step, 2, Func, T, T2> { - static void impl(T& src, T2 ptr, int oc_offset) { - UNROLL_CALL_RAW(4, WEIGHT_CB); - } -}; -template -struct LoadHelper<5, base_offset, ptr_step, 2, Func, T, T2> { - static void impl(T& src, T2 ptr, int oc_offset) { - UNROLL_CALL_RAW(5, WEIGHT_CB); - } -}; -template -struct LoadHelper<6, base_offset, ptr_step, 2, Func, T, T2> { - static void impl(T& src, T2 ptr, int oc_offset) { - UNROLL_CALL_RAW(6, WEIGHT_CB); - } -}; -template -struct LoadHelper<7, base_offset, ptr_step, 2, Func, T, T2> { - static void impl(T& src, T2 ptr, int oc_offset) { - UNROLL_CALL_RAW(7, WEIGHT_CB); - } -}; - -template -struct LoadHelper<8, base_offset, ptr_step, 2, Func, T, T2> { - static void impl(T& src, T2 ptr, int oc_offset) { - UNROLL_CALL_RAW(8, WEIGHT_CB); - } -}; - -#undef WEIGHT_CB - -template -inline void load_helper(T& weight, T2 ptr, int oc_offset) { - LoadHelper::impl( - weight, ptr, oc_offset); -} - -template -inline void load_helper_x(T& weight, T2 ptr, int oc_offset, XT... args) { - LoadHelper::impl(weight, ptr, oc_offset, args...); -} } // namespace } // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/intrinsic_helper.h b/dnn/src/arm_common/intrinsic_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..145430c22cad8a25c84eef057cb77b201cf4c73f --- /dev/null +++ b/dnn/src/arm_common/intrinsic_helper.h @@ -0,0 +1,126 @@ +/** + * \file dnn/src/arm_common/intrinsic_helper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include "src/arm_common/neon_struct.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/unroll_macro.h" +namespace megdnn { +namespace { + +template +struct LoadHelper { + static void impl(T& weight, T2 ptr, int oc_offset, XT... args); +}; + +#define WEIGHT_CB(step) \ + src[step] = Func::impl(ptr + base_offset + step * ptr_step, args...); + +#define LOAD_HELPER(step) \ + template \ + struct LoadHelper { \ + static void impl(T& src, T2 ptr, int, XT... args) { \ + UNROLL_CALL_RAW(step, WEIGHT_CB); \ + } \ + } + +LOAD_HELPER(1); +LOAD_HELPER(2); +LOAD_HELPER(3); +LOAD_HELPER(4); +LOAD_HELPER(5); +LOAD_HELPER(6); +LOAD_HELPER(7); +LOAD_HELPER(8); +LOAD_HELPER(9); +LOAD_HELPER(10); +LOAD_HELPER(11); +LOAD_HELPER(12); +LOAD_HELPER(13); +LOAD_HELPER(14); +LOAD_HELPER(15); +LOAD_HELPER(16); + +#undef LOAD_HELPER +#undef WEIGHT_CB + +///////////////////////////c_dim = 1///////////////////////// +#define WEIGHT_CB(step) \ + src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); + +#define LOAD_HELPER(step) \ + template \ + struct LoadHelper { \ + static void impl(T& src, T2 ptr, int) { \ + UNROLL_CALL_RAW(step, WEIGHT_CB); \ + } \ + } + +LOAD_HELPER(1); +LOAD_HELPER(2); +LOAD_HELPER(3); +LOAD_HELPER(4); +LOAD_HELPER(5); +LOAD_HELPER(6); +LOAD_HELPER(7); +LOAD_HELPER(8); +LOAD_HELPER(9); + +#undef LOAD_HELPER +#undef WEIGHT_CB + +/////////////////////////c_dim = 2/////////////////////////////// +#define WEIGHT_CB(step) \ + src[0][step] = Func::impl(ptr + base_offset + step * ptr_step); \ + src[1][step] = Func::impl(ptr + base_offset + step * ptr_step + oc_offset); + +#define LOAD_HELPER(step) \ + template \ + struct LoadHelper { \ + static void impl(T& src, T2 ptr, int oc_offset) { \ + UNROLL_CALL_RAW(step, WEIGHT_CB); \ + } \ + } + +LOAD_HELPER(1); +LOAD_HELPER(2); +LOAD_HELPER(3); +LOAD_HELPER(4); +LOAD_HELPER(5); +LOAD_HELPER(6); +LOAD_HELPER(7); +LOAD_HELPER(8); + +#undef LOAD_HELPER +#undef WEIGHT_CB + +template +inline void load_helper(T& weight, T2 ptr, int oc_offset) { + LoadHelper::impl( + weight, ptr, oc_offset); +} + +template +inline void load_helper_x(T& weight, T2 ptr, int oc_offset, XT... args) { + LoadHelper::impl(weight, ptr, oc_offset, args...); +} + +} // namespace +} // namespace megdnn + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/conv_bias/neon_struct.h b/dnn/src/arm_common/neon_struct.h similarity index 95% rename from dnn/src/arm_common/conv_bias/neon_struct.h rename to dnn/src/arm_common/neon_struct.h index 4303689bbbe739de8ddf20874831279dc4937a24..973fce7b7d41b2e668690bc481409ff8d245b95f 100644 --- a/dnn/src/arm_common/conv_bias/neon_struct.h +++ b/dnn/src/arm_common/neon_struct.h @@ -1,6 +1,5 @@ -#pragma once /** - * \file dnn/src/arm_common/conv_bias/neon_struct.h + * \file dnn/src/arm_common/neon_struct.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. @@ -10,6 +9,7 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ +#pragma once #include "src/arm_common/simd_macro/marm_neon.h" namespace megdnn { namespace { @@ -62,4 +62,6 @@ struct Vfmaq_laneq_f32 { }; } // namespace -} // namespace megdnn \ No newline at end of file +} // namespace megdnn + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/pooling/algo.h b/dnn/src/arm_common/pooling/algo.h index 2fde30355b846545144414f7c6dc2c1b193da133..aea62d625ac56acbb751c96998af894013175f0f 100644 --- a/dnn/src/arm_common/pooling/algo.h +++ b/dnn/src/arm_common/pooling/algo.h @@ -114,7 +114,13 @@ public: bool usable(const PoolingKernSizeParam& param) const override; void exec(const PoolingKernParam& param) const override; }; - +class PoolingImpl::AlgoFp32ModexStridexNCHW44 final : public AlgoBase { +public: + bool is_reproducible() const override { return true; } + const char* name() const override { return "ARM_POOLING_FP32_MODEX_STRIDEX_NCHW44"; } + bool usable(const PoolingKernSizeParam& param) const override; + void exec(const PoolingKernParam& param) const override; +}; WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param); WorkspaceBundle get_bundle_nchw44( diff --git a/dnn/src/arm_common/pooling/algo_fp32_pooling_nchw44.cpp b/dnn/src/arm_common/pooling/algo_fp32_pooling_nchw44.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8f9b2231e4cc4ba40ed490c424b9e9e84523a526 --- /dev/null +++ b/dnn/src/arm_common/pooling/algo_fp32_pooling_nchw44.cpp @@ -0,0 +1,126 @@ +/** + * \file dnn/src/arm_common/pooling/algo_fp32_pooling_nchw44.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ + +#include "megdnn/opr_param_defs.h" +#include "src/arm_common/pooling/algo.h" +#include "src/arm_common/pooling/kern_fp32_pooling_nchw44.h" + +#include "midout.h" + +MIDOUT_DECL(megdnn_arm_common_fp32_pooling_nchw44) + +namespace megdnn { +namespace arm_common { +bool PoolingImpl::AlgoFp32ModexStridexNCHW44::usable( + const PoolingKernSizeParam& param) const { + uint32_t sh = param.stride[0]; + uint32_t sw = param.stride[1]; + uint32_t fh = param.filter[0]; + uint32_t fw = param.filter[1]; + + bool avaible = param.src_type.enumv() == DTypeEnum::Float32 && + param.format == Param::Format::NCHW44 && + (param.mode == Mode::MAX || param.mode == Mode::AVERAGE) && + fh == fw && sh == sw && + (fh == 2 || fh == 3 || fh == 4 || fh == 5) && + (sh == 1 || sh == 2); + return avaible; +} + +void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec( + const PoolingKernParam& param) const { + int ih = param.isz[0]; + int iw = param.isz[1]; + int oh = param.osz[0]; + int ow = param.osz[1]; + int n = param.n; + int ic = param.ic; + int ph = param.padding[0]; + int pw = param.padding[1]; + int sh = param.stride[0]; + int fh = param.filter[0]; + + void* src_ptr = param.src_ptr; + void* dst_ptr = param.dst_ptr; + +#define DISPATCH_FUNC(filter, stride, mode) \ + MIDOUT_BEGIN(megdnn_arm_common_fp32_pooling_nchw44, midout_iv(0), \ + midout_iv(#filter #stride #mode##_hash)) { \ + auto run = [ih, iw, oh, ow, ph, pw, src_ptr, dst_ptr](size_t index, \ + size_t) { \ + const int c_idx = index; \ + pooling_fp32_nchw44( \ + static_cast(src_ptr) + c_idx * ih * iw * 4, \ + static_cast(dst_ptr) + c_idx * oh * ow * 4, ih, \ + iw, oh, ow, ph, pw); \ + }; \ + MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ + static_cast<::megdnn::naive::HandleImpl*>(param.handle), \ + n* ic, run); \ + } \ + MIDOUT_END(); + +#define DISPATCH_MODE(filter, stride) \ + switch (param.mode) { \ + case PoolingBase::Mode::MAX: \ + DISPATCH_FUNC(filter, stride, PoolingBase::Mode::MAX); \ + break; \ + case PoolingBase::Mode::AVERAGE: \ + DISPATCH_FUNC(filter, stride, PoolingBase::Mode::AVERAGE); \ + break; \ + default: \ + megdnn_assert(0, "invalid mode %u", \ + static_cast(param.mode)); \ + } + +#define DISPATCH_STRIDE(filter) \ + switch (sh) { \ + case 1: \ + DISPATCH_MODE(filter, 1); \ + break; \ + case 2: \ + DISPATCH_MODE(filter, 2); \ + break; \ + default: \ + megdnn_assert(0, "invalid stride %d", sh); \ + } + +#define DISPATCH_FILTER() \ + switch (fh) { \ + case 2: \ + DISPATCH_STRIDE(2); \ + break; \ + case 3: \ + DISPATCH_STRIDE(3); \ + break; \ + case 4: \ + DISPATCH_STRIDE(4); \ + break; \ + case 5: \ + DISPATCH_STRIDE(5); \ + break; \ + default: \ + megdnn_assert(0, "invalid filter %d", fh); \ + } + + DISPATCH_FILTER() + +#undef DISPATCH_FILTER +#undef DISPATCH_STRIDE +#undef DISPATCH_MODE +#undef DISPATCH_FUNC +} + +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/pooling/kern_fp32_pooling_nchw44.h b/dnn/src/arm_common/pooling/kern_fp32_pooling_nchw44.h new file mode 100644 index 0000000000000000000000000000000000000000..72d97d87d874398388498df25f21daa0db90468f --- /dev/null +++ b/dnn/src/arm_common/pooling/kern_fp32_pooling_nchw44.h @@ -0,0 +1,308 @@ +/** + * \file dnn/src/arm_common/pooling/kern_fp32_pooling_nchw44.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include +#include "megdnn/opr_param_defs.h" +#include "src/arm_common/intrinsic_helper.h" +#include "src/arm_common/neon_struct.h" +#include "src/arm_common/simd_macro/marm_neon.h" +#include "src/common/unroll_macro.h" + +namespace megdnn { +namespace arm_common { +namespace { + +template +struct CalXsXNchw44 { + static void impl(T1 result, T2 src); +}; + +template +void calculate_xsx_nchw44(T1 result, T2 src) { + CalXsXNchw44::impl(result, src); +}; + +#define CALCULATE_MAX_CB(step) \ + result[0] = vmaxq_f32(result[0], src[0 * stride + step]); \ + result[1] = vmaxq_f32(result[1], src[1 * stride + step]); \ + result[2] = vmaxq_f32(result[2], src[2 * stride + step]); \ + result[3] = vmaxq_f32(result[3], src[3 * stride + step]); + +#define CALCULATE_AVG_CB(step) \ + result[0] = vaddq_f32(result[0], src[0 * stride + step]); \ + result[1] = vaddq_f32(result[1], src[1 * stride + step]); \ + result[2] = vaddq_f32(result[2], src[2 * stride + step]); \ + result[3] = vaddq_f32(result[3], src[3 * stride + step]); + +#define INSTANCE_CAL(filter) \ + template \ + struct CalXsXNchw44 { \ + static void impl(T1 result, T2 src) { \ + UNROLL_CALL_RAW(filter, CALCULATE_MAX_CB); \ + } \ + }; \ + template \ + struct CalXsXNchw44 { \ + static void impl(T1 result, T2 src) { \ + UNROLL_CALL_RAW(filter, CALCULATE_AVG_CB); \ + } \ + }; + +INSTANCE_CAL(2) +INSTANCE_CAL(3) +INSTANCE_CAL(4) +INSTANCE_CAL(5) + +#undef INSTANCE_CAL +#undef CALCULATE_AVG_CB +#undef CALCULATE_MAX_CB + +template +struct KerPoolingFilterXStrideXNchw44 { + static void impl(const float32_t* src_ptr, float32_t* dst_ptr, size_t iw); +}; + +template +struct KerPoolingFilterXStrideXNchw44 { + static void impl(const float32_t* src_ptr, float32_t* dst_ptr, size_t iw) { + constexpr int src_reg_size = ow_step * stride + filter - stride; + constexpr int packed_ic = 4; + constexpr int simd_len = 4; + constexpr float default_float = std::numeric_limits::lowest(); + float32x4_t result[ow_step]; + float32x4_t src[src_reg_size]; + + result[0] = vdupq_n_f32(default_float); + result[1] = vdupq_n_f32(default_float); + result[2] = vdupq_n_f32(default_float); + result[3] = vdupq_n_f32(default_float); + + for (int fh_idx = 0; fh_idx < filter; ++fh_idx) { + load_helper( + src, src_ptr + fh_idx * iw * packed_ic, 0); + calculate_xsx_nchw44(result, src); + } + + vst1q_f32(dst_ptr + 0 * packed_ic, result[0]); + vst1q_f32(dst_ptr + 1 * packed_ic, result[1]); + vst1q_f32(dst_ptr + 2 * packed_ic, result[2]); + vst1q_f32(dst_ptr + 3 * packed_ic, result[3]); + } +}; + +template +struct KerPoolingFilterXStrideXNchw44 { + static void impl(const float32_t* src_ptr, float32_t* dst_ptr, size_t iw) { + constexpr int src_reg_size = ow_step * stride + filter - stride; + constexpr int packed_ic = 4; + constexpr int simd_len = 4; + constexpr float default_float = 0; + constexpr float div_filter_size = 1.f / (filter * filter); + const float32x4_t div_filter_size_vec = vdupq_n_f32(div_filter_size); + float32x4_t result[ow_step]; + float32x4_t src[src_reg_size]; + + result[0] = vdupq_n_f32(default_float); + result[1] = vdupq_n_f32(default_float); + result[2] = vdupq_n_f32(default_float); + result[3] = vdupq_n_f32(default_float); + + for (int fh_idx = 0; fh_idx < filter; ++fh_idx) { + load_helper( + src, src_ptr + fh_idx * iw * packed_ic, 0); + calculate_xsx_nchw44(result, src); + } + result[0] = vmulq_f32(result[0], div_filter_size_vec); + result[1] = vmulq_f32(result[1], div_filter_size_vec); + result[2] = vmulq_f32(result[2], div_filter_size_vec); + result[3] = vmulq_f32(result[3], div_filter_size_vec); + vst1q_f32(dst_ptr + 0 * packed_ic, result[0]); + vst1q_f32(dst_ptr + 1 * packed_ic, result[1]); + vst1q_f32(dst_ptr + 2 * packed_ic, result[2]); + vst1q_f32(dst_ptr + 3 * packed_ic, result[3]); + } +}; + +template +void ker_pooling_nchw44_remain_pad(const float32_t* src_ptr, float32_t* dst_ptr, + const int iw, const int pad_top, + const int pad_bottom, const int pad_left, + const int pad_right, const int filter); +template <> +void ker_pooling_nchw44_remain_pad( + const float32_t* src_ptr, float32_t* dst_ptr, const int iw, + const int pad_top, const int pad_bottom, const int pad_left, + const int pad_right, const int filter) { + constexpr int ic_step = 4; + const int ih_end = filter - pad_bottom; + const int iw_end = filter - pad_right; + float32x4_t result = vdupq_n_f32(std::numeric_limits::lowest()); + for (int ih_idx = pad_top; ih_idx < ih_end; ++ih_idx) { + for (int iw_idx = pad_left; iw_idx < iw_end; ++iw_idx) { + float32x4_t src = + vld1q_f32(src_ptr + (iw_idx - pad_left) * ic_step); + result = vmaxq_f32(result, src); + } + src_ptr += iw * ic_step; + } + vst1q_f32(dst_ptr, result); +} + +template <> +void ker_pooling_nchw44_remain_pad( + const float32_t* src_ptr, float32_t* dst_ptr, const int iw, + const int pad_top, const int pad_bottom, const int pad_left, + const int pad_right, const int filter) { + constexpr int ic_step = 4; + const int ih_end = filter - pad_bottom; + const int iw_end = filter - pad_right; + const float div_filter_size = 1.f / (filter * filter); + const float32x4_t div_filter_size_vec = vdupq_n_f32(div_filter_size); + float32x4_t result = vdupq_n_f32(0.f); + + for (int ih_idx = pad_top; ih_idx < ih_end; ++ih_idx) { + for (int iw_idx = pad_left; iw_idx < iw_end; ++iw_idx) { + float32x4_t src = + vld1q_f32(src_ptr + (iw_idx - pad_left) * ic_step); + result = vaddq_f32(result, src); + } + src_ptr += iw * ic_step; + } + result = vmulq_f32(result, div_filter_size_vec); + vst1q_f32(dst_ptr, result); +} + +template +static inline void kern_pooling_with_pad_nchw44( + const float32_t* src, float32_t* dst, const int filter, + const int ow_start, const int ow_end, const int iw, const int ow, + const int stride_w, const int pw, const int real_ih_idx, + const int oh_idx, const int pad_top, const int pad_bottom) { + constexpr int ic_step = 4; + constexpr int oc_step = 4; + for (int ow_idx = ow_start; ow_idx < ow_end; ++ow_idx) { + const int iw_idx = ow_idx * stride_w; + const int real_iw_idx = std::max(iw_idx - pw, 0); + const int pad_left = std::max(0, pw - iw_idx); + const int pad_right = std::max(0, iw_idx - pw + filter - iw); + const int src_offset = (real_ih_idx * iw + real_iw_idx) * ic_step; + const int dst_offset = (oh_idx * ow + ow_idx) * oc_step; + ker_pooling_nchw44_remain_pad(src + src_offset, dst + dst_offset, + iw, pad_top, pad_bottom, pad_left, + pad_right, filter); + } +} + +template +static inline void pooling_fp32_nchw44_pad(const float32_t* src, float32_t* dst, + int ih, int iw, int oh, int ow, + int ph, int pw) { + constexpr int stride_h = stride; + constexpr int stride_w = stride; + constexpr int ic_step = 4; + constexpr int oc_step = 4; + constexpr int ow_step = 4; + const int ow_pad_left_end = div_ceil(pw, stride_w); + const int ow_pad_right_end = (iw - filter + pw - 1) / stride_w; + const int ow_pad_right_step_end = + (ow_pad_right_end - ow_pad_left_end) / ow_step * ow_step + + ow_pad_left_end; + + rep(oh_idx, oh) { + const int ih_idx = oh_idx * stride_h; + const int real_ih_idx = std::max(ih_idx - ph, 0); + const int pad_top = std::max(0, ph - ih_idx); + const int pad_bottom = std::max(0, ih_idx - ph + filter - ih); + if (pad_top > 0 || pad_bottom > 0) { + kern_pooling_with_pad_nchw44(src, dst, filter, 0, ow, iw, ow, + stride_w, pw, real_ih_idx, + oh_idx, pad_top, pad_bottom); + + } else { + kern_pooling_with_pad_nchw44( + src, dst, filter, 0, ow_pad_left_end, iw, ow, stride_w, pw, + real_ih_idx, oh_idx, pad_top, pad_bottom); + for (int ow_idx = ow_pad_left_end; ow_idx < ow_pad_right_step_end; + ow_idx += ow_step) { + const int iw_idx = ow_idx * stride_w; + const int real_iw_idx = std::max(iw_idx - pw, 0); + const int src_offset = + (real_ih_idx * iw + real_iw_idx) * ic_step; + const int dst_offset = (oh_idx * ow + ow_idx) * oc_step; + KerPoolingFilterXStrideXNchw44::impl(src + src_offset, + dst + dst_offset, + iw); + } + kern_pooling_with_pad_nchw44( + src, dst, filter, ow_pad_right_step_end, ow, iw, ow, + stride_w, pw, real_ih_idx, oh_idx, pad_top, pad_bottom); + } + } +} + +template +static inline void pooling_fp32_nchw44_no_pad(const float32_t* src, + float32_t* dst, int, int iw, + int oh, int ow) { + constexpr int stride_h = stride; + constexpr int stride_w = stride; + constexpr int ic_step = 4; + constexpr int oc_step = 4; + constexpr int ow_step = 4; + const int ow_end = ow / ow_step * ow_step; + const int ow_remain = ow - ow_end; + + rep(oh_idx, oh) { + const int ih_idx = oh_idx * stride_h; + const int src_ih_offset = ih_idx * iw; + const int dst_oh_offset = oh_idx * ow; + for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) { + const int iw_idx = ow_idx * stride_w; + const int src_offset = (src_ih_offset + iw_idx) * ic_step; + const int dst_offset = (dst_oh_offset + ow_idx) * oc_step; + KerPoolingFilterXStrideXNchw44::impl( + src + src_offset, dst + dst_offset, iw); + } + if (ow_remain > 0) { + kern_pooling_with_pad_nchw44(src, dst, filter, ow_end, ow, iw, + ow, stride_w, 0, ih_idx, oh_idx, + 0, 0); + } + } +} + +template +static inline void pooling_fp32_nchw44(const float32_t* src, float32_t* dst, + int ih, int iw, int oh, int ow, int ph, + int pw) { + if (ph > 0 || pw > 0) { + pooling_fp32_nchw44_pad(src, dst, ih, iw, oh, ow, + ph, pw); + } else { + pooling_fp32_nchw44_no_pad(src, dst, ih, iw, oh, + ow); + } +} + +} // namespace +} // namespace arm_common +} // namespace megdnn + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/arm_common/pooling/opr_impl.cpp b/dnn/src/arm_common/pooling/opr_impl.cpp index f7165cd31bdcfc1c9275bb1033b017f2d99d6f64..46da10c8ab9c1161d39dfb5c598e236741bd6ba0 100644 --- a/dnn/src/arm_common/pooling/opr_impl.cpp +++ b/dnn/src/arm_common/pooling/opr_impl.cpp @@ -29,6 +29,7 @@ class PoolingImpl::AlgoPack : NonCopyableObj { AlgoFilter3ModexStridexNCHW44 algo_filter3_modex_stridex_nchw4; AlgoFilter4ModexStridexNCHW44 algo_filter4_modex_stridex_nchw4; AlgoFilter5ModexStridexNCHW44 algo_filter5_modex_stridex_nchw4; + AlgoFp32ModexStridexNCHW44 algo_fp32_modex_stridex_nchw44; public: AlgoPack() { @@ -44,6 +45,7 @@ public: all_algos.emplace_back(&algo_filter2_modex_stridex_nchw4); all_algos.emplace_back(&algo_filter4_modex_stridex_nchw4); all_algos.emplace_back(&algo_filter5_modex_stridex_nchw4); + all_algos.emplace_back(&algo_fp32_modex_stridex_nchw44); } SmallVector all_algos; }; diff --git a/dnn/src/arm_common/pooling/opr_impl.h b/dnn/src/arm_common/pooling/opr_impl.h index 92b6b46419afd9f9eb67880de3762ddc4bf23d7a..c3a5335ec227bf611f681df856ed84d7e45381e7 100644 --- a/dnn/src/arm_common/pooling/opr_impl.h +++ b/dnn/src/arm_common/pooling/opr_impl.h @@ -87,10 +87,10 @@ private: class AlgoFilter3ModexStridexNCHW44; class AlgoFilter4ModexStridexNCHW44; class AlgoFilter5ModexStridexNCHW44; + class AlgoFp32ModexStridexNCHW44; class AlgoPack; }; } // namespace arm_common } // namespace megdnn // vim: syntax=cpp.doxygen - diff --git a/dnn/src/common/unroll_macro.h b/dnn/src/common/unroll_macro.h index 1e6549bf0fd4a52f7255d6bab289aff7181baed4..d0d0f1c61130e11d436e9878816f7b00942fdcc6 100644 --- a/dnn/src/common/unroll_macro.h +++ b/dnn/src/common/unroll_macro.h @@ -6,7 +6,8 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #pragma once @@ -32,10 +33,33 @@ #define UNROLL_RAW9(cb, v0, a...) \ UNROLL_RAW8(cb, v0, ##a) \ cb(8, ##a) +#define UNROLL_RAW10(cb, v0, a...) \ + UNROLL_RAW9(cb, v0, ##a) \ + cb(9, ##a) +#define UNROLL_RAW11(cb, v0, a...) \ + UNROLL_RAW10(cb, v0, ##a) \ + cb(10, ##a) +#define UNROLL_RAW12(cb, v0, a...) \ + UNROLL_RAW11(cb, v0, ##a) \ + cb(11, ##a) +#define UNROLL_RAW13(cb, v0, a...) \ + UNROLL_RAW12(cb, v0, ##a) \ + cb(12, ##a) +#define UNROLL_RAW14(cb, v0, a...) \ + UNROLL_RAW13(cb, v0, ##a) \ + cb(13, ##a) +#define UNROLL_RAW15(cb, v0, a...) \ + UNROLL_RAW14(cb, v0, ##a) \ + cb(14, ##a) + +// clang-format off #define UNROLL_RAW16(cb, v0, a...) \ UNROLL_RAW8(cb, v0, ##a) \ cb(8, ##a) cb(9, ##a) cb(10, ##a) cb(11, ##a) cb(12, ##a) cb(13, ##a) \ cb(14, ##a) cb(15, ##a) +#define UNROLL_RAW17(cb, v0, a...) \ + UNROLL_RAW16(cb, v0, ##a) \ + cb(16, ##a) #define UNROLL_RAW24(cb, v0, a...) \ UNROLL_RAW16(cb, v0, ##a) \ cb(16, ##a) cb(17, ##a) cb(18, ##a) cb(19, ##a) cb(20, ##a) cb(21, ##a) \ @@ -130,4 +154,6 @@ #define UNROLL_CALL_NOWRAPPER_D2(step, step2, cb) \ UNROLL_CALL_RAW_D2(step, step2, cb) +// clang-format on + // vim: syntax=cpp.doxygen diff --git a/dnn/test/arm_common/pooling.cpp b/dnn/test/arm_common/pooling.cpp index 3b9bf57662014a8412e2eb2a9ce46802ff862904..e1210eeaa3363edae6ecd5efed208e4a104ddfaf 100644 --- a/dnn/test/arm_common/pooling.cpp +++ b/dnn/test/arm_common/pooling.cpp @@ -256,6 +256,73 @@ TEST_F(ARM_COMMON, POOLING_QUANTIZED) { } #if MEGDNN_WITH_BENCHMARK + +void benchmark_nchw44_fp32(Handle *handle) { + using Param = param::Pooling; + auto run = [&](size_t n, size_t c, size_t h, size_t w, size_t filter, + size_t stride, size_t pad, Param::Mode mode) { + Param param; + param.window_h = param.window_w = filter; + param.stride_h = param.stride_w = stride; + param.pad_h = param.pad_w = pad; + param.format = Param::Format::NCHW; + param.mode = mode; + TensorShape nchw_shape = {n, c, h, w}; + TensorShape nchw44_shape = {n, c / 4, h, w, 4}; + TensorLayout dst_layout; + auto opr = handle->create_operator(); + opr->param() = param; + opr->deduce_layout({nchw_shape, dtype::Float32()}, dst_layout); + float calc_amount = + dst_layout.total_nr_elems() * param.window_h * param.window_w; + + Benchmarker benchmarker_float_nchw(handle); + Benchmarker benchmarker_float_nchw44(handle); + Benchmarker benchmarker_int_nchw44(handle); + size_t RUN = 500; + auto t1 = benchmarker_float_nchw.set_display(false) + .set_times(RUN) + .set_param(param) + .exec({nchw_shape, {}}); + + param.format = Param::Format::NCHW44; + auto t2 = benchmarker_int_nchw44.set_display(false) + .set_times(RUN) + .set_param(param) + .execl({{nchw44_shape, dtype::QuantizedS8(1.0)}, + {{}, dtype::QuantizedS8(1.0)}}); + auto t3 = benchmarker_float_nchw44.set_display(false) + .set_times(RUN) + .set_param(param) + .exec({nchw44_shape, {}}); + + printf("{%zu %zu %zu %zu} filter = %zu, stride = %zu pad = %zu\n" + "nchw_fp32={%.3f ms, %.3f Mflops}, " + "nchw44_int={%.3f ms, %.3f Mflops}, " + "nchw44_fp32={%.3f ms, %.3f Mflops, speed_up %f}\n\n", + n, c, h, w, filter, stride, pad, t1 / RUN, + calc_amount / (t1 / RUN * 1000), t2 / RUN, + calc_amount / (t2 / RUN * 1000), t3 / RUN, + calc_amount / (t3 / RUN * 1000), t1 / t3); + }; + // Resnet50 + run(1, 64, 112, 112, 3, 2, 1, param::Pooling::Mode::MAX); + run(1, 2048, 7, 7, 7, 1, 0, param::Pooling::Mode::AVERAGE); + + // VGG16 + run(1, 64, 224, 224, 2, 2, 0, param::Pooling::Mode::MAX); + run(1, 128, 112, 112, 2, 2, 0, param::Pooling::Mode::MAX); + run(1, 256, 56, 56, 2, 2, 0, param::Pooling::Mode::MAX); + run(1, 512, 28, 28, 2, 2, 0, param::Pooling::Mode::MAX); + run(1, 512, 14, 14, 2, 2, 0, param::Pooling::Mode::MAX); +} + +TEST_F(ARM_COMMON, BENCHMARK_POOLING_NCHW44_FP32) { benchmark_nchw44_fp32(handle()); } + +TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_POOLING_NCHW44_FP32) { + benchmark_nchw44_fp32(handle()); +} + TEST_F(ARM_COMMON, BENCHMARK_POOLING_INT8_W3x3_S2x2) { using Param = param::Pooling; diff --git a/dnn/test/arm_common/pooling_multi_thread.cpp b/dnn/test/arm_common/pooling_multi_thread.cpp index 3f3cfe2e40f59b7dc7cfe9b8bb000157f78469c8..aa73370c392b857c75cf5635712a17f9cf0f1ff6 100644 --- a/dnn/test/arm_common/pooling_multi_thread.cpp +++ b/dnn/test/arm_common/pooling_multi_thread.cpp @@ -57,6 +57,65 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING) { } } +std::vector> get_nchw44_pool_args( + size_t filter, size_t stride) { + constexpr size_t ic_step = 4; + std::vector> args; + + for (size_t n : {1, 2}) + for (size_t c : {4, 8}) + for (size_t ih : {3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}) + for (size_t iw : {3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}) + for (size_t ph : {0, 1, 2}) + for (size_t pw : {0, 1, 2}) + for (auto mode : {param::Pooling::Mode::MAX, + param::Pooling::Mode::AVERAGE}) + if (ih + 2 * ph >= filter && + iw + 2 * pw >= filter && filter > ph && + filter > pw) { + param::Pooling param; + param.mode = mode; + param.format = + param::Pooling::Format::NCHW44; + param.pad_h = ph; + param.pad_w = pw; + param.stride_h = param.stride_w = stride; + param.window_h = param.window_w = filter; + args.emplace_back(std::make_pair( + param, + TensorShapeArray{{n, c / ic_step, + ih, iw, ic_step}, + {}})); + } + return args; +} + +void run_pooling_check( + Handle* handle, + std::vector> args, + bool is_int8) { + Checker checker(handle); + UniformIntRNG rng_int8{INT8_MIN >> 1, INT8_MAX >> 1}; + UniformIntRNG rng_fp32{-10, 10}; + if (is_int8) { + checker.set_dtype(0, dtype::QuantizedS8(1.1f)); + checker.set_rng(0, &rng_int8); + } else { + checker.set_rng(0, &rng_fp32); + } + for (auto arg : args) { + checker.set_param(arg.first).exec(arg.second); + } +} + +TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_NCHW44_FP32) { + for (auto filter : {2, 3, 4, 5}) + for (auto stride : {1, 2}) { + run_pooling_check(handle(), get_nchw44_pool_args(filter, stride), + false); + } +} + TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44) { // clang-format off