From 19d36fa03c477f6f6c926304015264479e5889aa Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 14 Jun 2022 18:15:07 +0800 Subject: [PATCH] feat(gi): make pooling apply gi class type GitOrigin-RevId: e60c6a2e7640cdfa75c1e7ad89f2ed78865847f5 --- .../gi/do_max_pooling_3x3_s2x2_float.cpp | 10 +- .../pooling/gi/kern_fp32_pooling_nchw44.h | 94 ++++++++++++------- dnn/src/fallback/pooling/gi/pooling_helper.h | 46 +++++---- 3 files changed, 91 insertions(+), 59 deletions(-) diff --git a/dnn/src/fallback/pooling/gi/do_max_pooling_3x3_s2x2_float.cpp b/dnn/src/fallback/pooling/gi/do_max_pooling_3x3_s2x2_float.cpp index 5c3b666a..40244395 100644 --- a/dnn/src/fallback/pooling/gi/do_max_pooling_3x3_s2x2_float.cpp +++ b/dnn/src/fallback/pooling/gi/do_max_pooling_3x3_s2x2_float.cpp @@ -8,11 +8,11 @@ namespace megdnn { namespace fallback { -#define GI_UZP(s0, s1, d0, d1) \ - do { \ - auto tmp__ = GiUzpqFloat32(s0, s1); \ - d0 = tmp__.val[0]; \ - d1 = tmp__.val[1]; \ +#define GI_UZP(s0, s1, d0, d1) \ + do { \ + auto tmp__ = GiUzpqFloat32(s0, s1); \ + d0 = GiGetSubVectorFloat32V2(tmp__, 0); \ + d1 = GiGetSubVectorFloat32V2(tmp__, 1); \ } while (0) void do_max_pooling_3x3_s2x2_float_gi( diff --git a/dnn/src/fallback/pooling/gi/kern_fp32_pooling_nchw44.h b/dnn/src/fallback/pooling/gi/kern_fp32_pooling_nchw44.h index 4bf4d516..472aebe9 100644 --- a/dnn/src/fallback/pooling/gi/kern_fp32_pooling_nchw44.h +++ b/dnn/src/fallback/pooling/gi/kern_fp32_pooling_nchw44.h @@ -29,17 +29,33 @@ void calculate_xsx_nchw44(T1 result, T2 src) { CalXsXNchw44::impl(result, src); }; -#define CALCULATE_MAX_CB(step) \ - result[0] = GiMaximumFloat32(result[0], src[0 * stride + step]); \ - result[1] = GiMaximumFloat32(result[1], src[1 * stride + step]); \ - result[2] = GiMaximumFloat32(result[2], src[2 * stride + step]); \ - result[3] = GiMaximumFloat32(result[3], src[3 * stride + step]); +#define CALCULATE_MAX_CB(step) \ + result[0] = GiFloat32Type2FixLenType(GiMaximumFloat32( \ + GiFixLenType2GiFloat32Type(result[0]), \ + GiFixLenType2GiFloat32Type(src[0 * stride + step]))); \ + result[1] = GiFloat32Type2FixLenType(GiMaximumFloat32( \ + GiFixLenType2GiFloat32Type(result[1]), \ + GiFixLenType2GiFloat32Type(src[1 * stride + step]))); \ + result[2] = GiFloat32Type2FixLenType(GiMaximumFloat32( \ + GiFixLenType2GiFloat32Type(result[2]), \ + GiFixLenType2GiFloat32Type(src[2 * stride + step]))); \ + result[3] = GiFloat32Type2FixLenType(GiMaximumFloat32( \ + GiFixLenType2GiFloat32Type(result[3]), \ + GiFixLenType2GiFloat32Type(src[3 * stride + step]))); -#define CALCULATE_AVG_CB(step) \ - result[0] = GiAddFloat32(result[0], src[0 * stride + step]); \ - result[1] = GiAddFloat32(result[1], src[1 * stride + step]); \ - result[2] = GiAddFloat32(result[2], src[2 * stride + step]); \ - result[3] = GiAddFloat32(result[3], src[3 * stride + step]); +#define CALCULATE_AVG_CB(step) \ + result[0] = GiFloat32Type2FixLenType(GiAddFloat32( \ + GiFixLenType2GiFloat32Type(result[0]), \ + GiFixLenType2GiFloat32Type(src[0 * stride + step]))); \ + result[1] = GiFloat32Type2FixLenType(GiAddFloat32( \ + GiFixLenType2GiFloat32Type(result[1]), \ + GiFixLenType2GiFloat32Type(src[1 * stride + step]))); \ + result[2] = GiFloat32Type2FixLenType(GiAddFloat32( \ + GiFixLenType2GiFloat32Type(result[2]), \ + GiFixLenType2GiFloat32Type(src[2 * stride + step]))); \ + result[3] = GiFloat32Type2FixLenType(GiAddFloat32( \ + GiFixLenType2GiFloat32Type(result[3]), \ + GiFixLenType2GiFloat32Type(src[3 * stride + step]))); #define INSTANCE_CAL(filter) \ template \ @@ -78,13 +94,13 @@ struct KerPoolingFilterXStrideXNchw44::lowest(); - GI_FLOAT32_t result[ow_step]; - GI_FLOAT32_t src[src_reg_size]; + GI_FLOAT32_FIXLEN_t result[ow_step]; + GI_FLOAT32_FIXLEN_t src[src_reg_size]; - result[0] = GiBroadcastFloat32(default_float); - result[1] = GiBroadcastFloat32(default_float); - result[2] = GiBroadcastFloat32(default_float); - result[3] = GiBroadcastFloat32(default_float); + result[0] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float)); + result[1] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float)); + result[2] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float)); + result[3] = GiFloat32Type2FixLenType(GiBroadcastFloat32(default_float)); for (int fh_idx = 0; fh_idx < filter; ++fh_idx) { load_helper( @@ -93,10 +109,10 @@ struct KerPoolingFilterXStrideXNchw44( src, src_ptr + fh_idx * iw * packed_ic, 0); calculate_xsx_nchw44( result, src); - } - result[0] = GiMultiplyFloat32(result[0], div_filter_size_vec); - result[1] = GiMultiplyFloat32(result[1], div_filter_size_vec); - result[2] = GiMultiplyFloat32(result[2], div_filter_size_vec); - result[3] = GiMultiplyFloat32(result[3], div_filter_size_vec); - GiStoreFloat32(dst_ptr + 0 * packed_ic, result[0]); - GiStoreFloat32(dst_ptr + 1 * packed_ic, result[1]); - GiStoreFloat32(dst_ptr + 2 * packed_ic, result[2]); - GiStoreFloat32(dst_ptr + 3 * packed_ic, result[3]); + }; + GiStoreFloat32( + dst_ptr + 0 * packed_ic, + GiMultiplyFloat32( + GiFixLenType2GiFloat32Type(result[0]), div_filter_size_vec)); + GiStoreFloat32( + dst_ptr + 1 * packed_ic, + GiMultiplyFloat32( + GiFixLenType2GiFloat32Type(result[1]), div_filter_size_vec)); + GiStoreFloat32( + dst_ptr + 2 * packed_ic, + GiMultiplyFloat32( + GiFixLenType2GiFloat32Type(result[2]), div_filter_size_vec)); + GiStoreFloat32( + dst_ptr + 3 * packed_ic, + GiMultiplyFloat32( + GiFixLenType2GiFloat32Type(result[3]), div_filter_size_vec)); } }; diff --git a/dnn/src/fallback/pooling/gi/pooling_helper.h b/dnn/src/fallback/pooling/gi/pooling_helper.h index 3bece705..03eb330f 100644 --- a/dnn/src/fallback/pooling/gi/pooling_helper.h +++ b/dnn/src/fallback/pooling/gi/pooling_helper.h @@ -56,18 +56,20 @@ struct GiMeanPooler { static constexpr int MIDOUT_CASE_NUM = 1; static constexpr int SIMD_WIDTH = 4; - static const GI_FLOAT32_t coef; - GI_FLOAT32_t res; - GiMeanPooler(DType) : res(GiBroadcastFloat32(0.0f)) {} - void feed(const float* val) { res = GiAddFloat32(res, GiLoadFloat32(val)); } + GI_FLOAT32_FIXLEN_t res, coef; + GiMeanPooler(DType) + : res(GiFloat32Type2FixLenType(GiBroadcastFloat32(0.0f))), + coef(GiFloat32Type2FixLenType(GiBroadcastFloat32(1.0f / area))) {} + void feed(const float* val) { + res = GiFloat32Type2FixLenType( + GiAddFloat32(GiFixLenType2GiFloat32Type(res), GiLoadFloat32(val))); + } void post(float* dst) { - res = GiMultiplyFloat32(res, coef); - GiStoreFloat32(dst, res); + res = GiFloat32Type2FixLenType(GiMultiplyFloat32( + GiFixLenType2GiFloat32Type(res), GiFixLenType2GiFloat32Type(coef))); + GiStoreFloat32(dst, GiFixLenType2GiFloat32Type(res)); } }; -template -const GI_FLOAT32_t GiMeanPooler::coef = - GiBroadcastFloat32(1.0f / area); /* ======================= MaxPooler ======================== */ @@ -96,10 +98,15 @@ struct GiMaxPooler { static constexpr int MIDOUT_CASE_NUM = 11; static constexpr int SIMD_WIDTH = 4; - GI_FLOAT32_t res; - GiMaxPooler(DType) : res(GiBroadcastFloat32(DTypeTrait::min())) {} - void feed(const float* val) { res = GiMaximumFloat32(res, GiLoadFloat32(val)); } - void post(float* dst) { GiStoreFloat32(dst, res); } + GI_FLOAT32_FIXLEN_t res; + GiMaxPooler(DType) + : res(GiFloat32Type2FixLenType( + GiBroadcastFloat32(DTypeTrait::min()))) {} + void feed(const float* val) { + res = GiFloat32Type2FixLenType( + GiMaximumFloat32(GiFixLenType2GiFloat32Type(res), GiLoadFloat32(val))); + } + void post(float* dst) { GiStoreFloat32(dst, GiFixLenType2GiFloat32Type(res)); } }; template @@ -137,7 +144,8 @@ struct do_pxl_2x2_pack_proxy< const int IW, const int OH, const int OW, const int PH, const int PW) { MEGDNN_MARK_USED_VAR(IH); MEGDNN_MARK_USED_VAR(OH); - static const auto avg_coef = GiBroadcastFloat32(0.25f); + static const auto avg_coef = + GiFloat32Type2FixLenType(GiBroadcastFloat32(0.25f)); int ih = -PH + 2 * oh; int iw = -PW + 2 * ow; auto i00 = GiLoadFloat32(src + (ih + 0) * IW + (iw + 0)), @@ -148,7 +156,7 @@ struct do_pxl_2x2_pack_proxy< auto vlow = GiPaddFloat32(GiGetLowFloat32(sum0), GiGetHighFloat32(sum0)); auto vhigh = GiPaddFloat32(GiGetLowFloat32(sum1), GiGetHighFloat32(sum1)); auto comb = GiCombineFloat32(vlow, vhigh); - auto result = GiMultiplyFloat32(comb, avg_coef); + auto result = GiMultiplyFloat32(comb, GiFixLenType2GiFloat32Type(avg_coef)); GiStoreFloat32(dst + oh * OW + ow, result); } }; @@ -327,8 +335,8 @@ void do_max_pooling_w5x5_s2x2_gi( auto s0 = GiLoadFloat32(sptr + iw + 0); auto s1 = GiLoadFloat32(sptr + iw + MEGDNN_SIMD_WIDTH); auto d = GiUzpqFloat32(s0, s1); - GiStoreFloat32(even + even_offset, d.val[0]); - GiStoreFloat32(odd + odd_offset, d.val[1]); + GiStoreFloat32(even + even_offset, GiGetSubVectorFloat32V2(d, 0)); + GiStoreFloat32(odd + odd_offset, GiGetSubVectorFloat32V2(d, 1)); even_offset += MEGDNN_SIMD_WIDTH; odd_offset += MEGDNN_SIMD_WIDTH; } @@ -464,8 +472,8 @@ void do_average_pooling_3x3_s2x2_gi( for (; iw + 2 * MEGDNN_SIMD_WIDTH <= IW; iw += 2 * MEGDNN_SIMD_WIDTH) { auto s0 = GiLd2qFloat32(sptr + iw); - GiStoreFloat32(even + even_offset, s0.val[0]); - GiStoreFloat32(odd + odd_offset, s0.val[1]); + GiStoreFloat32(even + even_offset, GiGetSubVectorFloat32V2(s0, 0)); + GiStoreFloat32(odd + odd_offset, GiGetSubVectorFloat32V2(s0, 1)); even_offset += MEGDNN_SIMD_WIDTH; odd_offset += MEGDNN_SIMD_WIDTH; } -- GitLab