diff --git a/dnn/src/aarch64/relayout/opr_impl.cpp b/dnn/src/aarch64/relayout/opr_impl.cpp index 9031d2c705a352f1246988986aaf472e8b4d1195..cbfece6ed240904523206d99ba7a3212de8b191a 100644 --- a/dnn/src/aarch64/relayout/opr_impl.cpp +++ b/dnn/src/aarch64/relayout/opr_impl.cpp @@ -14,6 +14,7 @@ #include "src/aarch64/handle.h" #include "src/aarch64/relayout/opr_impl.h" +#include "src/arm_common/simd_macro/marm_neon.h" using namespace megdnn; using namespace relayout; @@ -131,6 +132,179 @@ void trans_16x16_u8( "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31"); } +struct Transpose4Byte { + uint32_t v; +}; + +static inline void trans_8x8_u32( + const void* src, void* dst, const size_t src_step, const size_t dst_step) { + uint32_t* src_ptr = (uint32_t*)src; + uint32_t* dst_ptr = (uint32_t*)dst; + uint32x4x2_t src0 = vld1q_u32_x2(src_ptr + 0 * src_step); // A0A1A2A3 + uint32x4x2_t src1 = vld1q_u32_x2(src_ptr + 1 * src_step); // B0B1B2B3 + uint32x4x2_t src2 = vld1q_u32_x2(src_ptr + 2 * src_step); // C0C1C2C3 + uint32x4x2_t src3 = vld1q_u32_x2(src_ptr + 3 * src_step); // D0D1D2D3 + uint32x4x2_t src4 = vld1q_u32_x2(src_ptr + 4 * src_step); // E0E1E2E3 + uint32x4x2_t src5 = vld1q_u32_x2(src_ptr + 5 * src_step); // F0F1F2F3 + uint32x4x2_t src6 = vld1q_u32_x2(src_ptr + 6 * src_step); // G0G1G2G3 + uint32x4x2_t src7 = vld1q_u32_x2(src_ptr + 7 * src_step); // H0H1H2H3 + + uint32x4_t ab_low = vzip1q_u32(src0.val[0], src1.val[0]); // A0B0A1B1 + uint32x4_t ab_high = vzip2q_u32(src0.val[0], src1.val[0]); // A2B2A3B3 + uint32x4_t cd_low = vzip1q_u32(src2.val[0], src3.val[0]); // C0D0C1D1 + uint32x4_t cd_high = vzip2q_u32(src2.val[0], src3.val[0]); // C2D2C3D3 + uint32x4_t ef_low = vzip1q_u32(src4.val[0], src5.val[0]); // E0F0E1F1 + uint32x4_t ef_high = vzip2q_u32(src4.val[0], src5.val[0]); // E2F2E3F3 + uint32x4_t gh_low = vzip1q_u32(src6.val[0], src7.val[0]); // G0H0G1H1 + uint32x4_t gh_high = vzip2q_u32(src6.val[0], src7.val[0]); // G2H2G3H3 + + uint32x4_t abcd_0 = 
vreinterpretq_u32_u64(vzip1q_u64( + vreinterpretq_u64_u32(ab_low), vreinterpretq_u64_u32(cd_low))); // A0B0C0D0 + uint32x4_t abcd_1 = vreinterpretq_u32_u64(vzip2q_u64( + vreinterpretq_u64_u32(ab_low), vreinterpretq_u64_u32(cd_low))); // A1B1C1D1 + uint32x4_t abcd_2 = vreinterpretq_u32_u64(vzip1q_u64( + vreinterpretq_u64_u32(ab_high), + vreinterpretq_u64_u32(cd_high))); // A2B2C2D2 + uint32x4_t abcd_3 = vreinterpretq_u32_u64(vzip2q_u64( + vreinterpretq_u64_u32(ab_high), + vreinterpretq_u64_u32(cd_high))); // A3B3C3D3 + uint32x4_t efgh_0 = vreinterpretq_u32_u64(vzip1q_u64( + vreinterpretq_u64_u32(ef_low), vreinterpretq_u64_u32(gh_low))); // E0F0G0H0 + uint32x4_t efgh_1 = vreinterpretq_u32_u64(vzip2q_u64( + vreinterpretq_u64_u32(ef_low), vreinterpretq_u64_u32(gh_low))); // E1F1G1H1 + uint32x4_t efgh_2 = vreinterpretq_u32_u64(vzip1q_u64( + vreinterpretq_u64_u32(ef_high), + vreinterpretq_u64_u32(gh_high))); // E2F2G2H2 + uint32x4_t efgh_3 = vreinterpretq_u32_u64(vzip2q_u64( + vreinterpretq_u64_u32(ef_high), + vreinterpretq_u64_u32(gh_high))); // E3F3G3H3 + + vst1q_u32(dst_ptr + 0 * dst_step, abcd_0); + vst1q_u32(dst_ptr + 0 * dst_step + 4, efgh_0); + vst1q_u32(dst_ptr + 1 * dst_step, abcd_1); + vst1q_u32(dst_ptr + 1 * dst_step + 4, efgh_1); + vst1q_u32(dst_ptr + 2 * dst_step, abcd_2); + vst1q_u32(dst_ptr + 2 * dst_step + 4, efgh_2); + vst1q_u32(dst_ptr + 3 * dst_step, abcd_3); + vst1q_u32(dst_ptr + 3 * dst_step + 4, efgh_3); + + ab_low = vzip1q_u32(src0.val[1], src1.val[1]); // A0B0A1B1 + ab_high = vzip2q_u32(src0.val[1], src1.val[1]); // A2B2A3B3 + cd_low = vzip1q_u32(src2.val[1], src3.val[1]); // C0D0C1D1 + cd_high = vzip2q_u32(src2.val[1], src3.val[1]); // C2D2C3D3 + ef_low = vzip1q_u32(src4.val[1], src5.val[1]); // E0F0E1F1 + ef_high = vzip2q_u32(src4.val[1], src5.val[1]); // E2F2E3F3 + gh_low = vzip1q_u32(src6.val[1], src7.val[1]); // G0H0G1H1 + gh_high = vzip2q_u32(src6.val[1], src7.val[1]); // G2H2G3H3 + + abcd_0 = vreinterpretq_u32_u64(vzip1q_u64( + 
vreinterpretq_u64_u32(ab_low), vreinterpretq_u64_u32(cd_low))); // A0B0C0D0 + abcd_1 = vreinterpretq_u32_u64(vzip2q_u64( + vreinterpretq_u64_u32(ab_low), vreinterpretq_u64_u32(cd_low))); // A1B1C1D1 + abcd_2 = vreinterpretq_u32_u64(vzip1q_u64( + vreinterpretq_u64_u32(ab_high), + vreinterpretq_u64_u32(cd_high))); // A2B2C2D2 + abcd_3 = vreinterpretq_u32_u64(vzip2q_u64( + vreinterpretq_u64_u32(ab_high), + vreinterpretq_u64_u32(cd_high))); // A3B3C3D3 + efgh_0 = vreinterpretq_u32_u64(vzip1q_u64( + vreinterpretq_u64_u32(ef_low), vreinterpretq_u64_u32(gh_low))); // E0F0G0H0 + efgh_1 = vreinterpretq_u32_u64(vzip2q_u64( + vreinterpretq_u64_u32(ef_low), vreinterpretq_u64_u32(gh_low))); // E1F1G1H1 + efgh_2 = vreinterpretq_u32_u64(vzip1q_u64( + vreinterpretq_u64_u32(ef_high), + vreinterpretq_u64_u32(gh_high))); // E2F2G2H2 + efgh_3 = vreinterpretq_u32_u64(vzip2q_u64( + vreinterpretq_u64_u32(ef_high), + vreinterpretq_u64_u32(gh_high))); // E3F3G3H3 + + vst1q_u32(dst_ptr + 4 * dst_step, abcd_0); + vst1q_u32(dst_ptr + 4 * dst_step + 4, efgh_0); + vst1q_u32(dst_ptr + 5 * dst_step, abcd_1); + vst1q_u32(dst_ptr + 5 * dst_step + 4, efgh_1); + vst1q_u32(dst_ptr + 6 * dst_step, abcd_2); + vst1q_u32(dst_ptr + 6 * dst_step + 4, efgh_2); + vst1q_u32(dst_ptr + 7 * dst_step, abcd_3); + vst1q_u32(dst_ptr + 7 * dst_step + 4, efgh_3); +} + +struct Transpose2Byte { + uint16_t v; +}; +static inline void trans_8x8_u16( + const void* src, void* dst, const size_t src_step, const size_t dst_step) { + uint16_t* src_ptr = (uint16_t*)src; + uint16_t* dst_ptr = (uint16_t*)dst; + uint16x8_t src0 = vld1q_u16(src_ptr + 0 * src_step); // A0A1A2A3A4A5A6A7 + uint16x8_t src1 = vld1q_u16(src_ptr + 1 * src_step); // B0B1B2B3B4B5B6B7 + uint16x8_t src2 = vld1q_u16(src_ptr + 2 * src_step); // C0C1C2C3C4C5C6C7 + uint16x8_t src3 = vld1q_u16(src_ptr + 3 * src_step); // D0D1D2D3D4D5D6D7 + uint16x8_t src4 = vld1q_u16(src_ptr + 4 * src_step); // E0E1E2E3E4E5E6E7 + uint16x8_t src5 = vld1q_u16(src_ptr + 5 * src_step); 
// F0F1F2F3F4F5F6F7 + uint16x8_t src6 = vld1q_u16(src_ptr + 6 * src_step); // G0G1G2G3G4G5G6G7 + uint16x8_t src7 = vld1q_u16(src_ptr + 7 * src_step); // H0H1H2H3H4H5H6H7 + + uint16x8_t ab_low = vzip1q_u16(src0, src1); // A0B0A1B1A2B2A3B3 + uint16x8_t ab_high = vzip2q_u16(src0, src1); // A4B4A5B5A6B6A7B7 + uint16x8_t cd_low = vzip1q_u16(src2, src3); // C0D0C1D1C2D2C3D3 + uint16x8_t cd_high = vzip2q_u16(src2, src3); // C4D4C5D5C6D6C7D7 + uint16x8_t ef_low = vzip1q_u16(src4, src5); // E0F0E1F1E2F2E3F3 + uint16x8_t ef_high = vzip2q_u16(src4, src5); // E4F4E5F5E6F6E7F7 + uint16x8_t gh_low = vzip1q_u16(src6, src7); // G0H0G1H1G2H2G3H3 + uint16x8_t gh_high = vzip2q_u16(src6, src7); // G4H4G5H5G6H6G7H7 + + uint16x8_t abcd_0 = vreinterpretq_u16_u32(vzip1q_u32( + vreinterpretq_u32_u16(ab_low), + vreinterpretq_u32_u16(cd_low))); // A0B0C0D0A1B1C1D1 + uint16x8_t abcd_2 = vreinterpretq_u16_u32(vzip2q_u32( + vreinterpretq_u32_u16(ab_low), + vreinterpretq_u32_u16(cd_low))); // A2B2C2D2A3B3C3D3 + uint16x8_t abcd_4 = vreinterpretq_u16_u32(vzip1q_u32( + vreinterpretq_u32_u16(ab_high), + vreinterpretq_u32_u16(cd_high))); // A4B4C4D4A5B5C5D5 + uint16x8_t abcd_6 = vreinterpretq_u16_u32(vzip2q_u32( + vreinterpretq_u32_u16(ab_high), + vreinterpretq_u32_u16(cd_high))); // A6B6C6D6A7B7C7D7 + uint16x8_t efgh_0 = vreinterpretq_u16_u32(vzip1q_u32( + vreinterpretq_u32_u16(ef_low), + vreinterpretq_u32_u16(gh_low))); // E0F0G0H0E1F1G1H1 + uint16x8_t efgh_2 = vreinterpretq_u16_u32(vzip2q_u32( + vreinterpretq_u32_u16(ef_low), + vreinterpretq_u32_u16(gh_low))); // E2F2G2H2E3F3G3H3 + uint16x8_t efgh_4 = vreinterpretq_u16_u32(vzip1q_u32( + vreinterpretq_u32_u16(ef_high), + vreinterpretq_u32_u16(gh_high))); // E4F4G4H4E5F5G5H5 + uint16x8_t efgh_6 = vreinterpretq_u16_u32(vzip2q_u32( + vreinterpretq_u32_u16(ef_high), + vreinterpretq_u32_u16(gh_high))); // E6F6G6H6E7F7G7H7 + + uint16x8_t row_0 = vreinterpretq_u16_u64( + vzip1q_u64(vreinterpretq_u64_u16(abcd_0), vreinterpretq_u64_u16(efgh_0))); + 
uint16x8_t row_1 = vreinterpretq_u16_u64( + vzip2q_u64(vreinterpretq_u64_u16(abcd_0), vreinterpretq_u64_u16(efgh_0))); + uint16x8_t row_2 = vreinterpretq_u16_u64( + vzip1q_u64(vreinterpretq_u64_u16(abcd_2), vreinterpretq_u64_u16(efgh_2))); + uint16x8_t row_3 = vreinterpretq_u16_u64( + vzip2q_u64(vreinterpretq_u64_u16(abcd_2), vreinterpretq_u64_u16(efgh_2))); + uint16x8_t row_4 = vreinterpretq_u16_u64( + vzip1q_u64(vreinterpretq_u64_u16(abcd_4), vreinterpretq_u64_u16(efgh_4))); + uint16x8_t row_5 = vreinterpretq_u16_u64( + vzip2q_u64(vreinterpretq_u64_u16(abcd_4), vreinterpretq_u64_u16(efgh_4))); + uint16x8_t row_6 = vreinterpretq_u16_u64( + vzip1q_u64(vreinterpretq_u64_u16(abcd_6), vreinterpretq_u64_u16(efgh_6))); + uint16x8_t row_7 = vreinterpretq_u16_u64( + vzip2q_u64(vreinterpretq_u64_u16(abcd_6), vreinterpretq_u64_u16(efgh_6))); + + vst1q_u16(dst_ptr + 0 * dst_step, row_0); + vst1q_u16(dst_ptr + 1 * dst_step, row_1); + vst1q_u16(dst_ptr + 2 * dst_step, row_2); + vst1q_u16(dst_ptr + 3 * dst_step, row_3); + vst1q_u16(dst_ptr + 4 * dst_step, row_4); + vst1q_u16(dst_ptr + 5 * dst_step, row_5); + vst1q_u16(dst_ptr + 6 * dst_step, row_6); + vst1q_u16(dst_ptr + 7 * dst_step, row_7); +} + } // anonymous namespace namespace megdnn { @@ -148,6 +322,30 @@ void transpose_block( trans_16x16_u8(src, dst, src_stride, dst_stride); } +template <> +struct transpose_traits { + static constexpr size_t block_size = 8; +}; + +template <> +void transpose_block( + const Transpose4Byte* src, Transpose4Byte* dst, const size_t src_stride, + const size_t dst_stride) { + trans_8x8_u32(src, dst, src_stride, dst_stride); +} + +template <> +struct transpose_traits { + static constexpr size_t block_size = 8; +}; + +template <> +void transpose_block( + const Transpose2Byte* src, Transpose2Byte* dst, const size_t src_stride, + const size_t dst_stride) { + trans_8x8_u16(src, dst, src_stride, dst_stride); +} + } // namespace transpose_fallback } // namespace relayout } // namespace megdnn @@ 
-164,16 +362,33 @@ void aarch64::RelayoutForwardImpl::exec( fallback::RelayoutForwardImpl::exec(src0, dst0, src_handle); return; } - relayout::TransposeParam trans_param; - bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param); + bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param, true); if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 1) { auto sptr = static_cast(src.raw_ptr), dptr = static_cast(dst.raw_ptr); MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose( - trans_param.batch, trans_param.m, trans_param.n, sptr, dptr)); + trans_param.batch, trans_param.m, trans_param.n, sptr, dptr, + trans_param.stride_m)); + return; + } else if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 2) { + auto sptr = static_cast(src.raw_ptr), + dptr = static_cast(dst.raw_ptr); + + MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose( + trans_param.batch, trans_param.m, trans_param.n, sptr, dptr, + trans_param.stride_m)); + return; + } else if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 4) { + auto sptr = static_cast(src.raw_ptr), + dptr = static_cast(dst.raw_ptr); + + MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose( + trans_param.batch, trans_param.m, trans_param.n, sptr, dptr, + trans_param.stride_m)); return; } + exec_after_preprocess(src, dst, trans ? 
&trans_param : nullptr); } diff --git a/dnn/src/arm_common/simd_macro/marm_neon.h b/dnn/src/arm_common/simd_macro/marm_neon.h index 001749845cdd505a50105ad8184d4b037a3c91ef..c04bb85d6caa5d91be25a258080950694b2f8f77 100644 --- a/dnn/src/arm_common/simd_macro/marm_neon.h +++ b/dnn/src/arm_common/simd_macro/marm_neon.h @@ -321,6 +321,12 @@ __ai void vst1q_f32_x2(const float* p, float32x4x2_t v) { } #endif +#if !defined(vld1q_u32_x2) && (__GNUC__ < 8 || (__GNUC__ == 8 && __GNUC_MINOR__ < 3)) +__ai uint32x4x2_t vld1q_u32_x2(const uint32_t* p) { + return {{vld1q_u32(p), vld1q_u32(p + 4)}}; +} +#endif + __ai int8x16_t vtranslq_s8(int8x8_t a) { int8x16_t ret; #if MEGDNN_AARCH64 diff --git a/dnn/src/common/relayout.cpp b/dnn/src/common/relayout.cpp index eacf6662832dbb372b386eecbfd3912c42efe5e1..14a400f5fc646a36be940e0629043107254775ad 100644 --- a/dnn/src/common/relayout.cpp +++ b/dnn/src/common/relayout.cpp @@ -23,7 +23,8 @@ namespace { //! whether current shape is [b][n][m][c] and is a transpose of contig //! [b][m][n][c] -bool is_transpose_single(const TensorLayout& layout, TransposeParam& p) { +bool is_transpose_single( + const TensorLayout& layout, TransposeParam& p, bool allow_no_contig) { /* * assuming contig layout is: * shape: b, m, n, c @@ -42,8 +43,9 @@ bool is_transpose_single(const TensorLayout& layout, TransposeParam& p) { * * if b == 1 && c == 1: * shape: n, m - * stride: 1, n + * stride: 1, n(stride_m for no-contig) */ + p.stride_m = 0; auto strd = [&](size_t idx, ptrdiff_t v) { return layout.stride[idx] == v; }; if (layout.ndim == 4) { p.batch = layout[0]; @@ -80,7 +82,15 @@ bool is_transpose_single(const TensorLayout& layout, TransposeParam& p) { p.n = layout.shape[0]; p.m = layout.shape[1]; p.c = 1; - return strd(0, 1) && strd(1, p.n); + if (strd(0, 1) && strd(1, p.n)) { + return true; + } else if ( + strd(0, 1) && layout.stride[1] > 0 && + (size_t)(layout.stride[1]) >= p.n && allow_no_contig) { + //! 
stride_m used in no-contig mode, stride_m >= p.n + p.stride_m = layout.stride[1]; + return true; + } } return false; } @@ -98,15 +108,16 @@ void RelayoutForward::check_layout_and_canonize(TensorLayout& src, TensorLayout& } bool relayout::is_transpose( - const TensorLayout& src, const TensorLayout& dst, TransposeParam& p) { - if (is_contig(dst) && is_transpose_single(src, p)) { + const TensorLayout& src, const TensorLayout& dst, TransposeParam& p, + bool allow_non_contig) { + if (is_contig(dst) && is_transpose_single(src, p, allow_non_contig)) { // if the original intention is to transpose (m, n) to (n, m), // then we should use (n, m) as the contig dst and use a corrsponding // non-contig src with the same (n, m) shape (remember relayout is // defined on element correspondence on the logical view) return true; } - if (is_contig(src) && is_transpose_single(dst, p)) { + if (is_contig(src) && is_transpose_single(dst, p, allow_non_contig)) { std::swap(p.m, p.n); return true; } diff --git a/dnn/src/common/relayout_helper.h b/dnn/src/common/relayout_helper.h index 16ff10ce8ab4937ae587d00866ba6aa4b27bd73b..c03f9e53f931d08b4fa7a7959f1129f691b699a6 100644 --- a/dnn/src/common/relayout_helper.h +++ b/dnn/src/common/relayout_helper.h @@ -27,7 +27,7 @@ static inline bool is_contig(const TensorLayout& layout) { //! 
[b][m][n][c] to [b][n][m][c] struct TransposeParam { - size_t batch, m, n, c; + size_t batch, m, n, c, stride_m; }; /** @@ -36,7 +36,9 @@ struct TransposeParam { * Note that \p src and \p dst should have been processed by * RelayoutForward::check_layout_and_canonize */ -bool is_transpose(const TensorLayout& src, const TensorLayout& dst, TransposeParam& p); +bool is_transpose( + const TensorLayout& src, const TensorLayout& dst, TransposeParam& p, + bool allow_non_contig = false); namespace transpose_fallback { @@ -105,20 +107,23 @@ void transpose_block( * \brief transpose contiguous (batch, m, n) to (batch, n, m) */ template -void transpose(size_t batch, size_t m, size_t n, T* src, T* dst) { +void transpose(size_t batch, size_t m, size_t n, T* src, T* dst, size_t stride_m = 0) { + if (stride_m == 0) { + stride_m = n; + } auto batch_src = src; auto batch_dst = dst; constexpr size_t B = transpose_traits::block_size; - auto work_block = [m, n, &batch_src, &batch_dst]( + auto work_block = [m, stride_m, &batch_src, &batch_dst]( const size_t i, const size_t j, const size_t h, const size_t w) { - auto src = batch_src + i * n + j, dst = batch_dst + j * m + i; + auto src = batch_src + i * stride_m + j, dst = batch_dst + j * m + i; MIDOUT_BEGIN(transpose_fallback, midout_iv(0)) { if (h == B && w == B) { - transpose_block(src, dst, n, m); + transpose_block(src, dst, stride_m, m); } else { - transpose_block(src, dst, n, m, h, w); + transpose_block(src, dst, stride_m, m, h, w); } } MIDOUT_END(); @@ -141,7 +146,7 @@ void transpose(size_t batch, size_t m, size_t n, T* src, T* dst) { if (i < m) { work_row(i, m - i); } - batch_src += m * n; + batch_src += m * stride_m; batch_dst += m * n; } } diff --git a/dnn/src/fallback/relayout/opr_impl.cpp b/dnn/src/fallback/relayout/opr_impl.cpp index e4caaa9b2932b0881e0cee33cbd39b6300c65e15..d301acbc0f59b517af5884df3923f2bae1c9b8b5 100644 --- a/dnn/src/fallback/relayout/opr_impl.cpp +++ b/dnn/src/fallback/relayout/opr_impl.cpp @@ -48,10 +48,12 
@@ void memcpy_noncont2cont(void* cont, void* non_cont, size_t size) { } template -void call_transpose(size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst) { +void call_transpose( + size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst, + size_t stride_m) { megdnn_assert(ch == 1); relayout::transpose_fallback::transpose( - batch, m, n, static_cast(src), static_cast(dst)); + batch, m, n, static_cast(src), static_cast(dst), stride_m); } //! one operand contiguous, and the other non-contiguous @@ -186,7 +188,10 @@ void transpose_cv_row( } template -void transpose_cv(size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst) { +void transpose_cv( + size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst, + size_t stride_m) { + megdnn_assert(stride_m == 0); constexpr size_t B = BLOCK_SIZE; auto batch_src = static_cast(src); auto batch_dst = static_cast(dst); @@ -237,7 +242,7 @@ void RelayoutForwardImpl::exec( } relayout::TransposeParam trans_param; - bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param); + bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param, true); exec_after_preprocess(src, dst, trans ? 
&trans_param : nullptr); } @@ -245,7 +250,7 @@ void RelayoutForwardImpl::exec_after_preprocess( const TensorND& src, const TensorND& dst, relayout::TransposeParam* transpose) { if (transpose) { auto dsize = src.layout.dtype.size() * transpose->c; - void (*kptr)(size_t, size_t, size_t, size_t, void*, void*) = nullptr; + void (*kptr)(size_t, size_t, size_t, size_t, void*, void*, size_t) = nullptr; auto src_addr = reinterpret_cast(src.raw_ptr), dst_addr = reinterpret_cast(dst.raw_ptr); if (dsize == 1) { @@ -293,7 +298,9 @@ void RelayoutForwardImpl::exec_after_preprocess( if (kptr) { auto kern = [t = *transpose, sptr = src.raw_ptr, dptr = dst.raw_ptr, - kptr]() { kptr(t.batch, t.m, t.n, t.c, sptr, dptr); }; + kptr]() { + kptr(t.batch, t.m, t.n, t.c, sptr, dptr, t.stride_m); + }; static_cast(handle())->dispatch_kern(kern); return; } else { diff --git a/dnn/test/aarch64/fixture.cpp b/dnn/test/aarch64/fixture.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2357b9f94a97cadac4621ba35595bfc66151f9e6 --- /dev/null +++ b/dnn/test/aarch64/fixture.cpp @@ -0,0 +1,29 @@ +/** + * \file dnn/test/aarch64/fixture.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "test/aarch64/fixture.h" + +#include "test/common/memory_manager.h" +#include "test/common/random_state.h" +#include "test/common/utils.h" + +namespace megdnn { +namespace test { + +Handle* AARCH64::fallback_handle() { + if (!m_fallback_handle) { + m_fallback_handle = create_cpu_handle(1); + } + return m_fallback_handle.get(); +} + +} // namespace test +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/test/aarch64/fixture.h b/dnn/test/aarch64/fixture.h index 25dd3b3d40f01895bc604460c053217bcccb7d1d..5dc7677ae394d432da37fc28df56592c1178b7b2 100644 --- a/dnn/test/aarch64/fixture.h +++ b/dnn/test/aarch64/fixture.h @@ -19,7 +19,13 @@ namespace megdnn { namespace test { -class AARCH64 : public ARM_COMMON {}; +class AARCH64 : public ARM_COMMON { +public: + Handle* fallback_handle(); + +private: + std::unique_ptr m_handle, m_fallback_handle; +}; class AARCH64_MULTI_THREADS : public ARM_COMMON_MULTI_THREADS {}; diff --git a/dnn/test/aarch64/relayout.cpp b/dnn/test/aarch64/relayout.cpp new file mode 100644 index 0000000000000000000000000000000000000000..31e10eb02e3097fa4e9182c671242f49bf0aa80c --- /dev/null +++ b/dnn/test/aarch64/relayout.cpp @@ -0,0 +1,118 @@ +/** + * \file dnn/test/aarch64/relayout.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "test/aarch64/fixture.h" +#include "test/common/benchmarker.h" + +#include "test/common/checker.h" +#include "test/common/relayout.h" +#include "test/common/rng.h" + +namespace megdnn { +namespace test { + +namespace { +template +class AARCH64_RELAYOUT : public AARCH64 {}; +TYPED_TEST_CASE(AARCH64_RELAYOUT, relayout::test_types); +TYPED_TEST(AARCH64_RELAYOUT, run) { + relayout::run_test(this->handle()); +} +} // namespace + +TEST_F(AARCH64, Relayout) { + Checker checker(handle()); + std::vector<::megdnn::DType> dtype_vec; + dtype_vec.push_back(dtype::Float32()); + dtype_vec.push_back(dtype::Int16()); + dtype_vec.push_back(dtype::Uint16()); + dtype_vec.push_back(dtype::Int8()); + for (auto dtype : dtype_vec) { + TensorLayout src({1, 54, 112, 256}, {54, 1, 16384, 64}, dtype); + TensorLayout dst({1, 54, 112, 256}, {1548288, 28672, 256, 1}, dtype); + checker.execl({src, dst}); + } +} + +TEST_F(AARCH64, RelayoutBig) { + Checker checker(handle()); + ConsecutiveRNG rng; + checker.set_rng(0, &rng); + int m = 512; + int n = 512; + TensorLayout src({(size_t)m, (size_t)n}, {1, n}, dtype::Float32()); + TensorLayout dst({(size_t)m, (size_t)n}, {n, 1}, dtype::Float32()); + checker.execl({src, dst}); +} + +#if MEGDNN_WITH_BENCHMARK + +TEST_F(AARCH64, BENCHMARK_Relayout) { + constexpr size_t WARM_RUNS = 100; + constexpr size_t RUNS = 600; + auto dtype = dtype::Float32(); + Benchmarker benchmarker_relayout(handle()); + Benchmarker benchmarker_fbk_relayout(fallback_handle()); + benchmarker_relayout.set_times(WARM_RUNS); + benchmarker_fbk_relayout.set_times(WARM_RUNS); + int m = 512; + int n = 512; + TensorLayout src({(size_t)m, (size_t)n}, {1, n}, dtype); + TensorLayout dst({(size_t)m, (size_t)n}, {n, 1}, dtype); + TensorLayoutArray tensor_case; + tensor_case.push_back(src); + tensor_case.push_back(dst); + + benchmarker_relayout.exec(tensor_case); + benchmarker_fbk_relayout.exec(tensor_case); + benchmarker_relayout.set_times(RUNS); + 
benchmarker_fbk_relayout.set_times(RUNS); + + auto used = benchmarker_relayout.exec(tensor_case) / RUNS; + auto fbk_used = benchmarker_fbk_relayout.exec(tensor_case) / RUNS; + float bw = 2.f * m * n * 1e-6 / used * dtype.size(); + float fbk_bw = 2.f * m * n * 1e-6 / fbk_used * dtype.size(); + printf("run: %s -> %s , %f GB/s, fbk %f GB/s, speedup %f\n", + src.to_string().c_str(), dst.to_string().c_str(), bw, fbk_bw, bw / fbk_bw); +} + +TEST_F(AARCH64, BENCHMARK_Relayout_2) { + constexpr size_t WARM_RUNS = 100; + constexpr size_t RUNS = 600; + auto dtype = dtype::Float32(); + Benchmarker benchmarker_relayout(handle()); + Benchmarker benchmarker_fbk_relayout(fallback_handle()); + benchmarker_relayout.set_times(WARM_RUNS); + benchmarker_fbk_relayout.set_times(WARM_RUNS); + int m = 54; + int n = 28762; + TensorLayout src({1, 54, 112, 256}, {54, 1, 16384, 64}, dtype); + TensorLayout dst({1, 54, 112, 256}, {1548288, 28672, 256, 1}, dtype); + TensorLayoutArray tensor_case; + tensor_case.push_back(src); + tensor_case.push_back(dst); + + benchmarker_relayout.exec(tensor_case); + benchmarker_fbk_relayout.exec(tensor_case); + benchmarker_relayout.set_times(RUNS); + benchmarker_fbk_relayout.set_times(RUNS); + + auto used = benchmarker_relayout.exec(tensor_case) / RUNS; + auto fbk_used = benchmarker_fbk_relayout.exec(tensor_case) / RUNS; + float bw = 2.f * m * n * 1e-6 / used * dtype.size(); + float fbk_bw = 2.f * m * n * 1e-6 / fbk_used * dtype.size(); + printf("run: %s -> %s , %f GB/s, fbk %f GB/s, speedup %f\n", + src.to_string().c_str(), dst.to_string().c_str(), bw, fbk_bw, bw / fbk_bw); +} +#endif + +} // namespace test +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/test/common/relayout.cpp b/dnn/test/common/relayout.cpp index 0dd249f4a87891e921c15c7ff153a0dcab05150a..2b4f33af50e923021e3dee288b0de75b49eefba5 100644 --- a/dnn/test/common/relayout.cpp +++ b/dnn/test/common/relayout.cpp @@ -180,11 +180,11 @@ TEST(RELAYOUT, TRANSPOSE_DET) { 
ASSERT_EQ(p_get.c, p.c); } }; - run({2, 3}, {1, 0}, true, {1, 2, 3, 1}); - run({2, 3, 5}, {1, 0, 2}, true, {1, 2, 3, 5}); - run({2, 3, 5}, {0, 2, 1}, true, {2, 3, 5, 1}); - run({3, 2, 3, 5}, {0, 2, 1, 3}, true, {3, 2, 3, 5}); - run({3, 2, 3, 5}, {0, 1, 3, 2}, true, {6, 3, 5, 1}); + run({2, 3}, {1, 0}, true, {1, 2, 3, 1, 0}); + run({2, 3, 5}, {1, 0, 2}, true, {1, 2, 3, 5, 0}); + run({2, 3, 5}, {0, 2, 1}, true, {2, 3, 5, 1, 0}); + run({3, 2, 3, 5}, {0, 2, 1, 3}, true, {3, 2, 3, 5, 0}); + run({3, 2, 3, 5}, {0, 1, 3, 2}, true, {6, 3, 5, 1, 0}); run({2, 3, 5}, {2, 1, 0}, false); run({3, 2, 3, 5}, {3, 2, 1, 0}, false); }