From c90e0b54bea08b46b656da8f69aafac353e28279 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 27 Dec 2021 15:39:58 +0800 Subject: [PATCH] perf(arm): optimize arm uint16 relayout with n=4 GitOrigin-RevId: 5779c6b9c15aa52447e32f8d95d1b845c6d21e18 --- dnn/src/aarch64/relayout/opr_impl.cpp | 69 +++++++++++++++++++++++++++ dnn/test/aarch64/relayout.cpp | 12 +++++ 2 files changed, 81 insertions(+) diff --git a/dnn/src/aarch64/relayout/opr_impl.cpp b/dnn/src/aarch64/relayout/opr_impl.cpp index 6827f4fad..dfcf5036f 100644 --- a/dnn/src/aarch64/relayout/opr_impl.cpp +++ b/dnn/src/aarch64/relayout/opr_impl.cpp @@ -305,6 +305,64 @@ static inline void trans_8x8_u16( vst1q_u16(dst_ptr + 7 * dst_step, row_7); } +static inline void trans_8x4_u16( + const void* src, void* dst, const size_t src_step, const size_t dst_step) { + uint16_t* src_ptr = (uint16_t*)src; + uint16_t* dst_ptr = (uint16_t*)dst; + uint16x4_t src0 = vld1_u16(src_ptr + 0 * src_step); // A0A1A2A3 + uint16x4_t src1 = vld1_u16(src_ptr + 1 * src_step); // B0B1B2B3 + uint16x4_t src2 = vld1_u16(src_ptr + 2 * src_step); // C0C1C2C3 + uint16x4_t src3 = vld1_u16(src_ptr + 3 * src_step); // D0D1D2D3 + uint16x4_t src4 = vld1_u16(src_ptr + 4 * src_step); // E0E1E2E3 + uint16x4_t src5 = vld1_u16(src_ptr + 5 * src_step); // F0F1F2F3 + uint16x4_t src6 = vld1_u16(src_ptr + 6 * src_step); // G0G1G2G3 + uint16x4_t src7 = vld1_u16(src_ptr + 7 * src_step); // H0H1H2H3 + + uint16x4_t ab_low = vzip1_u16(src0, src1); // A0B0A1B1 + uint16x4_t ab_high = vzip2_u16(src0, src1); // A2B2A3B3 + uint16x4_t cd_low = vzip1_u16(src2, src3); // C0D0C1D1 + uint16x4_t cd_high = vzip2_u16(src2, src3); // C2D2C3D3 + uint16x4_t ef_low = vzip1_u16(src4, src5); // E0F0E1F1 + uint16x4_t ef_high = vzip2_u16(src4, src5); // E2F2E3F3 + uint16x4_t gh_low = vzip1_u16(src6, src7); // G0H0G1H1 + uint16x4_t gh_high = vzip2_u16(src6, src7); // G2H2G3H3 + + uint16x4_t abcd_0 = vreinterpret_u16_u32(vzip1_u32( + vreinterpret_u32_u16(ab_low), + vreinterpret_u32_u16(cd_low))); // A0B0C0D0 + uint16x4_t abcd_1 = vreinterpret_u16_u32(vzip2_u32( + vreinterpret_u32_u16(ab_low), + vreinterpret_u32_u16(cd_low))); // A1B1C1D1 + uint16x4_t abcd_2 = vreinterpret_u16_u32(vzip1_u32( + vreinterpret_u32_u16(ab_high), + vreinterpret_u32_u16(cd_high))); // A2B2C2D2 + uint16x4_t abcd_3 = vreinterpret_u16_u32(vzip2_u32( + vreinterpret_u32_u16(ab_high), + vreinterpret_u32_u16(cd_high))); // A3B3C3D3 + uint16x4_t efgh_0 = vreinterpret_u16_u32(vzip1_u32( + vreinterpret_u32_u16(ef_low), + vreinterpret_u32_u16(gh_low))); // E0F0G0H0 + uint16x4_t efgh_1 = vreinterpret_u16_u32(vzip2_u32( + vreinterpret_u32_u16(ef_low), + vreinterpret_u32_u16(gh_low))); // E1F1G1H1 + uint16x4_t efgh_2 = vreinterpret_u16_u32(vzip1_u32( + vreinterpret_u32_u16(ef_high), + vreinterpret_u32_u16(gh_high))); // E2F2G2H2 + uint16x4_t efgh_3 = vreinterpret_u16_u32(vzip2_u32( + vreinterpret_u32_u16(ef_high), + vreinterpret_u32_u16(gh_high))); // E3F3G3H3 + + uint16x8_t row_0 = vcombine_u16(abcd_0, efgh_0); + uint16x8_t row_1 = vcombine_u16(abcd_1, efgh_1); + uint16x8_t row_2 = vcombine_u16(abcd_2, efgh_2); + uint16x8_t row_3 = vcombine_u16(abcd_3, efgh_3); + + vst1q_u16(dst_ptr + 0 * dst_step, row_0); + vst1q_u16(dst_ptr + 1 * dst_step, row_1); + vst1q_u16(dst_ptr + 2 * dst_step, row_2); + vst1q_u16(dst_ptr + 3 * dst_step, row_3); +} + } // anonymous namespace namespace megdnn { @@ -346,6 +404,17 @@ void transpose_block( trans_8x8_u16(src, dst, src_stride, dst_stride); } +template <> +void transpose_block( + const Transpose2Byte* src, Transpose2Byte* dst, const size_t src_stride, + const size_t dst_stride, size_t block_h, size_t block_w) { + if (block_h == 8 && block_w == 4) { + trans_8x4_u16(src, dst, src_stride, dst_stride); + } else { + transpose_block_fallback(src, dst, src_stride, dst_stride, block_h, block_w); + } +} + } // namespace transpose_fallback } // namespace relayout } // namespace megdnn diff --git a/dnn/test/aarch64/relayout.cpp b/dnn/test/aarch64/relayout.cpp index 57dea5f44..3a6045013 100644 --- a/dnn/test/aarch64/relayout.cpp +++ b/dnn/test/aarch64/relayout.cpp @@ -67,6 +67,18 @@ TEST_F(AARCH64, RelayoutBig) { checker.execl({src, dst}); } +TEST_F(AARCH64, RelayoutSplict) { + Checker checker(handle()); + ConsecutiveRNG rng; + checker.set_rng(0, &rng); + int m = 4; + for (int n : {4, 28}) { + TensorLayout src({(size_t)m, (size_t)n}, {1, m}, dtype::Uint16()); + TensorLayout dst({(size_t)m, (size_t)n}, {n, 1}, dtype::Uint16()); + checker.execl({src, dst}); + } +} + TEST_F(AARCH64, RelayoutRecord) { TaskRecordChecker checker(0); std::vector<::megdnn::DType> dtype_vec; -- GitLab