提交 056fd6bc 编写于 作者: M Megvii Engine Team

feat(dnn/arm64): support stride_m in arm64 relayout

GitOrigin-RevId: c74193a23dfb3b0eecd4e56c0ec52a54690fbf71
上级 ec75cd86
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "src/aarch64/handle.h" #include "src/aarch64/handle.h"
#include "src/aarch64/relayout/opr_impl.h" #include "src/aarch64/relayout/opr_impl.h"
#include "src/arm_common/simd_macro/marm_neon.h"
using namespace megdnn; using namespace megdnn;
using namespace relayout; using namespace relayout;
...@@ -131,6 +132,179 @@ void trans_16x16_u8( ...@@ -131,6 +132,179 @@ void trans_16x16_u8(
"d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31"); "d22", "d23", "d24", "d25", "d26", "d27", "d28", "d29", "d30", "d31");
} }
struct Transpose4Byte {
uint32_t v;
};
static inline void trans_8x8_u32(
const void* src, void* dst, const size_t src_step, const size_t dst_step) {
uint32_t* src_ptr = (uint32_t*)src;
uint32_t* dst_ptr = (uint32_t*)dst;
uint32x4x2_t src0 = vld1q_u32_x2(src_ptr + 0 * src_step); // A0A1A2A3
uint32x4x2_t src1 = vld1q_u32_x2(src_ptr + 1 * src_step); // B0B1B2B3
uint32x4x2_t src2 = vld1q_u32_x2(src_ptr + 2 * src_step); // C0C1C2C3
uint32x4x2_t src3 = vld1q_u32_x2(src_ptr + 3 * src_step); // D0D1D2D3
uint32x4x2_t src4 = vld1q_u32_x2(src_ptr + 4 * src_step); // E0E1E2E3
uint32x4x2_t src5 = vld1q_u32_x2(src_ptr + 5 * src_step); // F0F1F2F3
uint32x4x2_t src6 = vld1q_u32_x2(src_ptr + 6 * src_step); // G0G1G2G3
uint32x4x2_t src7 = vld1q_u32_x2(src_ptr + 7 * src_step); // H0H1H2H3
uint32x4_t ab_low = vzip1q_u32(src0.val[0], src1.val[0]); // A0B0A1B1
uint32x4_t ab_high = vzip2q_u32(src0.val[0], src1.val[0]); // A2B2A3B3
uint32x4_t cd_low = vzip1q_u32(src2.val[0], src3.val[0]); // C0D0C1D1
uint32x4_t cd_high = vzip2q_u32(src2.val[0], src3.val[0]); // C2D2C3D3
uint32x4_t ef_low = vzip1q_u32(src4.val[0], src5.val[0]); // E0F0E1F1
uint32x4_t ef_high = vzip2q_u32(src4.val[0], src5.val[0]); // E2F2E3F3
uint32x4_t gh_low = vzip1q_u32(src6.val[0], src7.val[0]); // G0H0G1H1
uint32x4_t gh_high = vzip2q_u32(src6.val[0], src7.val[0]); // G2H2G3H3
uint32x4_t abcd_0 = vreinterpretq_u32_u64(vzip1q_u64(
vreinterpretq_u64_u32(ab_low), vreinterpretq_u64_u32(cd_low))); // A0B0C0D0
uint32x4_t abcd_1 = vreinterpretq_u32_u64(vzip2q_u64(
vreinterpretq_u64_u32(ab_low), vreinterpretq_u64_u32(cd_low))); // A1B1C1D1
uint32x4_t abcd_2 = vreinterpretq_u32_u64(vzip1q_u64(
vreinterpretq_u64_u32(ab_high),
vreinterpretq_u64_u32(cd_high))); // A2B2C2D2
uint32x4_t abcd_3 = vreinterpretq_u32_u64(vzip2q_u64(
vreinterpretq_u64_u32(ab_high),
vreinterpretq_u64_u32(cd_high))); // A3B3C3D3
uint32x4_t efgh_0 = vreinterpretq_u32_u64(vzip1q_u64(
vreinterpretq_u64_u32(ef_low), vreinterpretq_u64_u32(gh_low))); // E0F0G0H0
uint32x4_t efgh_1 = vreinterpretq_u32_u64(vzip2q_u64(
vreinterpretq_u64_u32(ef_low), vreinterpretq_u64_u32(gh_low))); // E1F1G1H1
uint32x4_t efgh_2 = vreinterpretq_u32_u64(vzip1q_u64(
vreinterpretq_u64_u32(ef_high),
vreinterpretq_u64_u32(gh_high))); // E2F2G2H2
uint32x4_t efgh_3 = vreinterpretq_u32_u64(vzip2q_u64(
vreinterpretq_u64_u32(ef_high),
vreinterpretq_u64_u32(gh_high))); // E3F3G3H3
vst1q_u32(dst_ptr + 0 * dst_step, abcd_0);
vst1q_u32(dst_ptr + 0 * dst_step + 4, efgh_0);
vst1q_u32(dst_ptr + 1 * dst_step, abcd_1);
vst1q_u32(dst_ptr + 1 * dst_step + 4, efgh_1);
vst1q_u32(dst_ptr + 2 * dst_step, abcd_2);
vst1q_u32(dst_ptr + 2 * dst_step + 4, efgh_2);
vst1q_u32(dst_ptr + 3 * dst_step, abcd_3);
vst1q_u32(dst_ptr + 3 * dst_step + 4, efgh_3);
ab_low = vzip1q_u32(src0.val[1], src1.val[1]); // A0B0A1B1
ab_high = vzip2q_u32(src0.val[1], src1.val[1]); // A2B2A3B3
cd_low = vzip1q_u32(src2.val[1], src3.val[1]); // C0D0C1D1
cd_high = vzip2q_u32(src2.val[1], src3.val[1]); // C2D2C3D3
ef_low = vzip1q_u32(src4.val[1], src5.val[1]); // E0F0E1F1
ef_high = vzip2q_u32(src4.val[1], src5.val[1]); // E2F2E3F3
gh_low = vzip1q_u32(src6.val[1], src7.val[1]); // G0H0G1H1
gh_high = vzip2q_u32(src6.val[1], src7.val[1]); // G2H2G3H3
abcd_0 = vreinterpretq_u32_u64(vzip1q_u64(
vreinterpretq_u64_u32(ab_low), vreinterpretq_u64_u32(cd_low))); // A0B0C0D0
abcd_1 = vreinterpretq_u32_u64(vzip2q_u64(
vreinterpretq_u64_u32(ab_low), vreinterpretq_u64_u32(cd_low))); // A1B1C1D1
abcd_2 = vreinterpretq_u32_u64(vzip1q_u64(
vreinterpretq_u64_u32(ab_high),
vreinterpretq_u64_u32(cd_high))); // A2B2C2D2
abcd_3 = vreinterpretq_u32_u64(vzip2q_u64(
vreinterpretq_u64_u32(ab_high),
vreinterpretq_u64_u32(cd_high))); // A3B3C3D3
efgh_0 = vreinterpretq_u32_u64(vzip1q_u64(
vreinterpretq_u64_u32(ef_low), vreinterpretq_u64_u32(gh_low))); // E0F0G0H0
efgh_1 = vreinterpretq_u32_u64(vzip2q_u64(
vreinterpretq_u64_u32(ef_low), vreinterpretq_u64_u32(gh_low))); // E1F1G1H1
efgh_2 = vreinterpretq_u32_u64(vzip1q_u64(
vreinterpretq_u64_u32(ef_high),
vreinterpretq_u64_u32(gh_high))); // E2F2G2H2
efgh_3 = vreinterpretq_u32_u64(vzip2q_u64(
vreinterpretq_u64_u32(ef_high),
vreinterpretq_u64_u32(gh_high))); // E3F3G3H3
vst1q_u32(dst_ptr + 4 * dst_step, abcd_0);
vst1q_u32(dst_ptr + 4 * dst_step + 4, efgh_0);
vst1q_u32(dst_ptr + 5 * dst_step, abcd_1);
vst1q_u32(dst_ptr + 5 * dst_step + 4, efgh_1);
vst1q_u32(dst_ptr + 6 * dst_step, abcd_2);
vst1q_u32(dst_ptr + 6 * dst_step + 4, efgh_2);
vst1q_u32(dst_ptr + 7 * dst_step, abcd_3);
vst1q_u32(dst_ptr + 7 * dst_step + 4, efgh_3);
}
struct Transpose2Byte {
uint16_t v;
};
static inline void trans_8x8_u16(
const void* src, void* dst, const size_t src_step, const size_t dst_step) {
uint16_t* src_ptr = (uint16_t*)src;
uint16_t* dst_ptr = (uint16_t*)dst;
uint16x8_t src0 = vld1q_u16(src_ptr + 0 * src_step); // A0A1A2A3A4A5A6A7
uint16x8_t src1 = vld1q_u16(src_ptr + 1 * src_step); // B0B1B2B3B4B5B6B7
uint16x8_t src2 = vld1q_u16(src_ptr + 2 * src_step); // C0C1C2C3C4C5C6C7
uint16x8_t src3 = vld1q_u16(src_ptr + 3 * src_step); // D0D1D2D3D4D5D6D7
uint16x8_t src4 = vld1q_u16(src_ptr + 4 * src_step); // E0E1E2E3E4E5E6E7
uint16x8_t src5 = vld1q_u16(src_ptr + 5 * src_step); // F0F1F2F3F4F5F6F7
uint16x8_t src6 = vld1q_u16(src_ptr + 6 * src_step); // G0G1G2G3G4G5G6G7
uint16x8_t src7 = vld1q_u16(src_ptr + 7 * src_step); // H0H1H2H3H4H5H6H7
uint16x8_t ab_low = vzip1q_u16(src0, src1); // A0B0A1B1A2B2A3B3
uint16x8_t ab_high = vzip2q_u16(src0, src1); // A4B4A5B5A6B6A7B7
uint16x8_t cd_low = vzip1q_u16(src2, src3); // C0D0C1D1C2D2C3D3
uint16x8_t cd_high = vzip2q_u16(src2, src3); // C4D4C5D5C6D6C7D7
uint16x8_t ef_low = vzip1q_u16(src4, src5); // E0F0E1F1E2F2E3F3
uint16x8_t ef_high = vzip2q_u16(src4, src5); // E4F4E5F5E6F6E7F7
uint16x8_t gh_low = vzip1q_u16(src6, src7); // G0H0G1H1G2H2G3H3
uint16x8_t gh_high = vzip2q_u16(src6, src7); // G4H4G5H5G6H6G7H7
uint16x8_t abcd_0 = vreinterpretq_u16_u32(vzip1q_u32(
vreinterpretq_u32_u16(ab_low),
vreinterpretq_u32_u16(cd_low))); // A0B0C0D0A1B1C1D1
uint16x8_t abcd_2 = vreinterpretq_u16_u32(vzip2q_u32(
vreinterpretq_u32_u16(ab_low),
vreinterpretq_u32_u16(cd_low))); // A2B2C2D2A3B3C3D3
uint16x8_t abcd_4 = vreinterpretq_u16_u32(vzip1q_u32(
vreinterpretq_u32_u16(ab_high),
vreinterpretq_u32_u16(cd_high))); // A4B4C4D4A5B5C5D5
uint16x8_t abcd_6 = vreinterpretq_u16_u32(vzip2q_u32(
vreinterpretq_u32_u16(ab_high),
vreinterpretq_u32_u16(cd_high))); // A6B6C6D6A7B7C7D7
uint16x8_t efgh_0 = vreinterpretq_u16_u32(vzip1q_u32(
vreinterpretq_u32_u16(ef_low),
vreinterpretq_u32_u16(gh_low))); // E0F0G0H0E1F1G1H1
uint16x8_t efgh_2 = vreinterpretq_u16_u32(vzip2q_u32(
vreinterpretq_u32_u16(ef_low),
vreinterpretq_u32_u16(gh_low))); // E2F2G2H2E3F3G3H3
uint16x8_t efgh_4 = vreinterpretq_u16_u32(vzip1q_u32(
vreinterpretq_u32_u16(ef_high),
vreinterpretq_u32_u16(gh_high))); // E4F4G4H4E5F5G5H5
uint16x8_t efgh_6 = vreinterpretq_u16_u32(vzip2q_u32(
vreinterpretq_u32_u16(ef_high),
vreinterpretq_u32_u16(gh_high))); // E6F6G6H6E7F7G7H7
uint16x8_t row_0 = vreinterpretq_u16_u64(
vzip1q_u64(vreinterpretq_u64_u16(abcd_0), vreinterpretq_u64_u16(efgh_0)));
uint16x8_t row_1 = vreinterpretq_u16_u64(
vzip2q_u64(vreinterpretq_u64_u16(abcd_0), vreinterpretq_u64_u16(efgh_0)));
uint16x8_t row_2 = vreinterpretq_u16_u64(
vzip1q_u64(vreinterpretq_u64_u16(abcd_2), vreinterpretq_u64_u16(efgh_2)));
uint16x8_t row_3 = vreinterpretq_u16_u64(
vzip2q_u64(vreinterpretq_u64_u16(abcd_2), vreinterpretq_u64_u16(efgh_2)));
uint16x8_t row_4 = vreinterpretq_u16_u64(
vzip1q_u64(vreinterpretq_u64_u16(abcd_4), vreinterpretq_u64_u16(efgh_4)));
uint16x8_t row_5 = vreinterpretq_u16_u64(
vzip2q_u64(vreinterpretq_u64_u16(abcd_4), vreinterpretq_u64_u16(efgh_4)));
uint16x8_t row_6 = vreinterpretq_u16_u64(
vzip1q_u64(vreinterpretq_u64_u16(abcd_6), vreinterpretq_u64_u16(efgh_6)));
uint16x8_t row_7 = vreinterpretq_u16_u64(
vzip2q_u64(vreinterpretq_u64_u16(abcd_6), vreinterpretq_u64_u16(efgh_6)));
vst1q_u16(dst_ptr + 0 * dst_step, row_0);
vst1q_u16(dst_ptr + 1 * dst_step, row_1);
vst1q_u16(dst_ptr + 2 * dst_step, row_2);
vst1q_u16(dst_ptr + 3 * dst_step, row_3);
vst1q_u16(dst_ptr + 4 * dst_step, row_4);
vst1q_u16(dst_ptr + 5 * dst_step, row_5);
vst1q_u16(dst_ptr + 6 * dst_step, row_6);
vst1q_u16(dst_ptr + 7 * dst_step, row_7);
}
} // anonymous namespace } // anonymous namespace
namespace megdnn { namespace megdnn {
...@@ -148,6 +322,30 @@ void transpose_block<TransposeByte>( ...@@ -148,6 +322,30 @@ void transpose_block<TransposeByte>(
trans_16x16_u8(src, dst, src_stride, dst_stride); trans_16x16_u8(src, dst, src_stride, dst_stride);
} }
template <>
struct transpose_traits<Transpose4Byte> {
static constexpr size_t block_size = 8;
};
template <>
void transpose_block<Transpose4Byte>(
const Transpose4Byte* src, Transpose4Byte* dst, const size_t src_stride,
const size_t dst_stride) {
trans_8x8_u32(src, dst, src_stride, dst_stride);
}
template <>
struct transpose_traits<Transpose2Byte> {
static constexpr size_t block_size = 8;
};
template <>
void transpose_block<Transpose2Byte>(
const Transpose2Byte* src, Transpose2Byte* dst, const size_t src_stride,
const size_t dst_stride) {
trans_8x8_u16(src, dst, src_stride, dst_stride);
}
} // namespace transpose_fallback } // namespace transpose_fallback
} // namespace relayout } // namespace relayout
} // namespace megdnn } // namespace megdnn
...@@ -164,16 +362,33 @@ void aarch64::RelayoutForwardImpl::exec( ...@@ -164,16 +362,33 @@ void aarch64::RelayoutForwardImpl::exec(
fallback::RelayoutForwardImpl::exec(src0, dst0, src_handle); fallback::RelayoutForwardImpl::exec(src0, dst0, src_handle);
return; return;
} }
relayout::TransposeParam trans_param; relayout::TransposeParam trans_param;
bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param); bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param, true);
if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 1) { if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 1) {
auto sptr = static_cast<TransposeByte*>(src.raw_ptr), auto sptr = static_cast<TransposeByte*>(src.raw_ptr),
dptr = static_cast<TransposeByte*>(dst.raw_ptr); dptr = static_cast<TransposeByte*>(dst.raw_ptr);
MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<TransposeByte>( MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<TransposeByte>(
trans_param.batch, trans_param.m, trans_param.n, sptr, dptr)); trans_param.batch, trans_param.m, trans_param.n, sptr, dptr,
trans_param.stride_m));
return;
} else if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 2) {
auto sptr = static_cast<Transpose2Byte*>(src.raw_ptr),
dptr = static_cast<Transpose2Byte*>(dst.raw_ptr);
MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<Transpose2Byte>(
trans_param.batch, trans_param.m, trans_param.n, sptr, dptr,
trans_param.stride_m));
return;
} else if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 4) {
auto sptr = static_cast<Transpose4Byte*>(src.raw_ptr),
dptr = static_cast<Transpose4Byte*>(dst.raw_ptr);
MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<Transpose4Byte>(
trans_param.batch, trans_param.m, trans_param.n, sptr, dptr,
trans_param.stride_m));
return; return;
} }
exec_after_preprocess(src, dst, trans ? &trans_param : nullptr); exec_after_preprocess(src, dst, trans ? &trans_param : nullptr);
} }
......
...@@ -321,6 +321,12 @@ __ai void vst1q_f32_x2(const float* p, float32x4x2_t v) { ...@@ -321,6 +321,12 @@ __ai void vst1q_f32_x2(const float* p, float32x4x2_t v) {
} }
#endif #endif
#if !defined(vld1q_u32_x2) && (__GNUC__ < 8 || (__GNUC__ == 8 && __GNUC_MINOR__ < 3))
__ai uint32x4x2_t vld1q_u32_x2(const uint32_t* p) {
return {{vld1q_u32(p), vld1q_u32(p + 4)}};
}
#endif
__ai int8x16_t vtranslq_s8(int8x8_t a) { __ai int8x16_t vtranslq_s8(int8x8_t a) {
int8x16_t ret; int8x16_t ret;
#if MEGDNN_AARCH64 #if MEGDNN_AARCH64
......
...@@ -23,7 +23,8 @@ namespace { ...@@ -23,7 +23,8 @@ namespace {
//! whether current shape is [b][n][m][c] and is a transpose of contig //! whether current shape is [b][n][m][c] and is a transpose of contig
//! [b][m][n][c] //! [b][m][n][c]
bool is_transpose_single(const TensorLayout& layout, TransposeParam& p) { bool is_transpose_single(
const TensorLayout& layout, TransposeParam& p, bool allow_no_contig) {
/* /*
* assuming contig layout is: * assuming contig layout is:
* shape: b, m, n, c * shape: b, m, n, c
...@@ -42,8 +43,9 @@ bool is_transpose_single(const TensorLayout& layout, TransposeParam& p) { ...@@ -42,8 +43,9 @@ bool is_transpose_single(const TensorLayout& layout, TransposeParam& p) {
* *
* if b == 1 && c == 1: * if b == 1 && c == 1:
* shape: n, m * shape: n, m
* stride: 1, n * stride: 1, n(stride_m for no-contig)
*/ */
p.stride_m = 0;
auto strd = [&](size_t idx, ptrdiff_t v) { return layout.stride[idx] == v; }; auto strd = [&](size_t idx, ptrdiff_t v) { return layout.stride[idx] == v; };
if (layout.ndim == 4) { if (layout.ndim == 4) {
p.batch = layout[0]; p.batch = layout[0];
...@@ -80,7 +82,15 @@ bool is_transpose_single(const TensorLayout& layout, TransposeParam& p) { ...@@ -80,7 +82,15 @@ bool is_transpose_single(const TensorLayout& layout, TransposeParam& p) {
p.n = layout.shape[0]; p.n = layout.shape[0];
p.m = layout.shape[1]; p.m = layout.shape[1];
p.c = 1; p.c = 1;
return strd(0, 1) && strd(1, p.n); if (strd(0, 1) && strd(1, p.n)) {
return true;
} else if (
strd(0, 1) && layout.stride[1] > 0 &&
(size_t)(layout.stride[1]) >= p.n && allow_no_contig) {
//! stride_m used in no-contig mode, stride_m >= p.n
p.stride_m = layout.stride[1];
return true;
}
} }
return false; return false;
} }
...@@ -98,15 +108,16 @@ void RelayoutForward::check_layout_and_canonize(TensorLayout& src, TensorLayout& ...@@ -98,15 +108,16 @@ void RelayoutForward::check_layout_and_canonize(TensorLayout& src, TensorLayout&
} }
bool relayout::is_transpose( bool relayout::is_transpose(
const TensorLayout& src, const TensorLayout& dst, TransposeParam& p) { const TensorLayout& src, const TensorLayout& dst, TransposeParam& p,
if (is_contig(dst) && is_transpose_single(src, p)) { bool allow_non_contig) {
if (is_contig(dst) && is_transpose_single(src, p, allow_non_contig)) {
// if the original intention is to transpose (m, n) to (n, m), // if the original intention is to transpose (m, n) to (n, m),
// then we should use (n, m) as the contig dst and use a corrsponding // then we should use (n, m) as the contig dst and use a corrsponding
// non-contig src with the same (n, m) shape (remember relayout is // non-contig src with the same (n, m) shape (remember relayout is
// defined on element correspondence on the logical view) // defined on element correspondence on the logical view)
return true; return true;
} }
if (is_contig(src) && is_transpose_single(dst, p)) { if (is_contig(src) && is_transpose_single(dst, p, allow_non_contig)) {
std::swap(p.m, p.n); std::swap(p.m, p.n);
return true; return true;
} }
......
...@@ -27,7 +27,7 @@ static inline bool is_contig(const TensorLayout& layout) { ...@@ -27,7 +27,7 @@ static inline bool is_contig(const TensorLayout& layout) {
//! [b][m][n][c] to [b][n][m][c] //! [b][m][n][c] to [b][n][m][c]
struct TransposeParam { struct TransposeParam {
size_t batch, m, n, c; size_t batch, m, n, c, stride_m;
}; };
/** /**
...@@ -36,7 +36,9 @@ struct TransposeParam { ...@@ -36,7 +36,9 @@ struct TransposeParam {
* Note that \p src and \p dst should have been processed by * Note that \p src and \p dst should have been processed by
* RelayoutForward::check_layout_and_canonize * RelayoutForward::check_layout_and_canonize
*/ */
bool is_transpose(const TensorLayout& src, const TensorLayout& dst, TransposeParam& p); bool is_transpose(
const TensorLayout& src, const TensorLayout& dst, TransposeParam& p,
bool allow_non_contig = false);
namespace transpose_fallback { namespace transpose_fallback {
...@@ -105,20 +107,23 @@ void transpose_block( ...@@ -105,20 +107,23 @@ void transpose_block(
* \brief transpose contiguous (batch, m, n) to (batch, n, m) * \brief transpose contiguous (batch, m, n) to (batch, n, m)
*/ */
template <typename T> template <typename T>
void transpose(size_t batch, size_t m, size_t n, T* src, T* dst) { void transpose(size_t batch, size_t m, size_t n, T* src, T* dst, size_t stride_m = 0) {
if (stride_m == 0) {
stride_m = n;
}
auto batch_src = src; auto batch_src = src;
auto batch_dst = dst; auto batch_dst = dst;
constexpr size_t B = transpose_traits<T>::block_size; constexpr size_t B = transpose_traits<T>::block_size;
auto work_block = [m, n, &batch_src, &batch_dst]( auto work_block = [m, stride_m, &batch_src, &batch_dst](
const size_t i, const size_t j, const size_t h, const size_t i, const size_t j, const size_t h,
const size_t w) { const size_t w) {
auto src = batch_src + i * n + j, dst = batch_dst + j * m + i; auto src = batch_src + i * stride_m + j, dst = batch_dst + j * m + i;
MIDOUT_BEGIN(transpose_fallback, midout_iv(0)) { MIDOUT_BEGIN(transpose_fallback, midout_iv(0)) {
if (h == B && w == B) { if (h == B && w == B) {
transpose_block(src, dst, n, m); transpose_block(src, dst, stride_m, m);
} else { } else {
transpose_block(src, dst, n, m, h, w); transpose_block(src, dst, stride_m, m, h, w);
} }
} }
MIDOUT_END(); MIDOUT_END();
...@@ -141,7 +146,7 @@ void transpose(size_t batch, size_t m, size_t n, T* src, T* dst) { ...@@ -141,7 +146,7 @@ void transpose(size_t batch, size_t m, size_t n, T* src, T* dst) {
if (i < m) { if (i < m) {
work_row(i, m - i); work_row(i, m - i);
} }
batch_src += m * n; batch_src += m * stride_m;
batch_dst += m * n; batch_dst += m * n;
} }
} }
......
...@@ -48,10 +48,12 @@ void memcpy_noncont2cont(void* cont, void* non_cont, size_t size) { ...@@ -48,10 +48,12 @@ void memcpy_noncont2cont(void* cont, void* non_cont, size_t size) {
} }
template <typename T> template <typename T>
void call_transpose(size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst) { void call_transpose(
size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst,
size_t stride_m) {
megdnn_assert(ch == 1); megdnn_assert(ch == 1);
relayout::transpose_fallback::transpose<T>( relayout::transpose_fallback::transpose<T>(
batch, m, n, static_cast<T*>(src), static_cast<T*>(dst)); batch, m, n, static_cast<T*>(src), static_cast<T*>(dst), stride_m);
} }
//! one operand contiguous, and the other non-contiguous //! one operand contiguous, and the other non-contiguous
...@@ -186,7 +188,10 @@ void transpose_cv_row( ...@@ -186,7 +188,10 @@ void transpose_cv_row(
} }
template <typename ctype> template <typename ctype>
void transpose_cv(size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst) { void transpose_cv(
size_t batch, size_t m, size_t n, size_t ch, void* src, void* dst,
size_t stride_m) {
megdnn_assert(stride_m == 0);
constexpr size_t B = BLOCK_SIZE; constexpr size_t B = BLOCK_SIZE;
auto batch_src = static_cast<ctype*>(src); auto batch_src = static_cast<ctype*>(src);
auto batch_dst = static_cast<ctype*>(dst); auto batch_dst = static_cast<ctype*>(dst);
...@@ -237,7 +242,7 @@ void RelayoutForwardImpl::exec( ...@@ -237,7 +242,7 @@ void RelayoutForwardImpl::exec(
} }
relayout::TransposeParam trans_param; relayout::TransposeParam trans_param;
bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param); bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param, true);
exec_after_preprocess(src, dst, trans ? &trans_param : nullptr); exec_after_preprocess(src, dst, trans ? &trans_param : nullptr);
} }
...@@ -245,7 +250,7 @@ void RelayoutForwardImpl::exec_after_preprocess( ...@@ -245,7 +250,7 @@ void RelayoutForwardImpl::exec_after_preprocess(
const TensorND& src, const TensorND& dst, relayout::TransposeParam* transpose) { const TensorND& src, const TensorND& dst, relayout::TransposeParam* transpose) {
if (transpose) { if (transpose) {
auto dsize = src.layout.dtype.size() * transpose->c; auto dsize = src.layout.dtype.size() * transpose->c;
void (*kptr)(size_t, size_t, size_t, size_t, void*, void*) = nullptr; void (*kptr)(size_t, size_t, size_t, size_t, void*, void*, size_t) = nullptr;
auto src_addr = reinterpret_cast<uintptr_t>(src.raw_ptr), auto src_addr = reinterpret_cast<uintptr_t>(src.raw_ptr),
dst_addr = reinterpret_cast<uintptr_t>(dst.raw_ptr); dst_addr = reinterpret_cast<uintptr_t>(dst.raw_ptr);
if (dsize == 1) { if (dsize == 1) {
...@@ -293,7 +298,9 @@ void RelayoutForwardImpl::exec_after_preprocess( ...@@ -293,7 +298,9 @@ void RelayoutForwardImpl::exec_after_preprocess(
if (kptr) { if (kptr) {
auto kern = [t = *transpose, sptr = src.raw_ptr, dptr = dst.raw_ptr, auto kern = [t = *transpose, sptr = src.raw_ptr, dptr = dst.raw_ptr,
kptr]() { kptr(t.batch, t.m, t.n, t.c, sptr, dptr); }; kptr]() {
kptr(t.batch, t.m, t.n, t.c, sptr, dptr, t.stride_m);
};
static_cast<naive::HandleImpl*>(handle())->dispatch_kern(kern); static_cast<naive::HandleImpl*>(handle())->dispatch_kern(kern);
return; return;
} else { } else {
......
/**
* \file dnn/test/aarch64/fixture.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "test/aarch64/fixture.h"
#include "test/common/memory_manager.h"
#include "test/common/random_state.h"
#include "test/common/utils.h"
namespace megdnn {
namespace test {
Handle* AARCH64::fallback_handle() {
if (!m_fallback_handle) {
m_fallback_handle = create_cpu_handle(1);
}
return m_fallback_handle.get();
}
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
...@@ -19,7 +19,13 @@ ...@@ -19,7 +19,13 @@
namespace megdnn { namespace megdnn {
namespace test { namespace test {
class AARCH64 : public ARM_COMMON {}; class AARCH64 : public ARM_COMMON {
public:
Handle* fallback_handle();
private:
std::unique_ptr<Handle> m_handle, m_fallback_handle;
};
class AARCH64_MULTI_THREADS : public ARM_COMMON_MULTI_THREADS {}; class AARCH64_MULTI_THREADS : public ARM_COMMON_MULTI_THREADS {};
......
/**
* \file dnn/test/aarch64/relayout.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "test/aarch64/fixture.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
#include "test/common/relayout.h"
#include "test/common/rng.h"
namespace megdnn {
namespace test {
namespace {
template <typename tag>
class AARCH64_RELAYOUT : public AARCH64 {};
TYPED_TEST_CASE(AARCH64_RELAYOUT, relayout::test_types);
TYPED_TEST(AARCH64_RELAYOUT, run) {
relayout::run_test<TypeParam>(this->handle());
}
} // namespace
TEST_F(AARCH64, Relayout) {
Checker<Relayout> checker(handle());
std::vector<::megdnn::DType> dtype_vec;
dtype_vec.push_back(dtype::Float32());
dtype_vec.push_back(dtype::Int16());
dtype_vec.push_back(dtype::Uint16());
dtype_vec.push_back(dtype::Int8());
for (auto dtype : dtype_vec) {
TensorLayout src({1, 54, 112, 256}, {54, 1, 16384, 64}, dtype);
TensorLayout dst({1, 54, 112, 256}, {1548288, 28672, 256, 1}, dtype);
checker.execl({src, dst});
}
}
TEST_F(AARCH64, RelayoutBig) {
Checker<Relayout> checker(handle());
ConsecutiveRNG rng;
checker.set_rng(0, &rng);
int m = 512;
int n = 512;
TensorLayout src({(size_t)m, (size_t)n}, {1, n}, dtype::Float32());
TensorLayout dst({(size_t)m, (size_t)n}, {n, 1}, dtype::Float32());
checker.execl({src, dst});
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(AARCH64, BENCHMARK_Relayout) {
constexpr size_t WARM_RUNS = 100;
constexpr size_t RUNS = 600;
auto dtype = dtype::Float32();
Benchmarker<Relayout> benchmarker_relayout(handle());
Benchmarker<Relayout> benchmarker_fbk_relayout(fallback_handle());
benchmarker_relayout.set_times(WARM_RUNS);
benchmarker_fbk_relayout.set_times(WARM_RUNS);
int m = 512;
int n = 512;
TensorLayout src({(size_t)m, (size_t)n}, {1, n}, dtype);
TensorLayout dst({(size_t)m, (size_t)n}, {n, 1}, dtype);
TensorLayoutArray tensor_case;
tensor_case.push_back(src);
tensor_case.push_back(dst);
benchmarker_relayout.exec(tensor_case);
benchmarker_fbk_relayout.exec(tensor_case);
benchmarker_relayout.set_times(RUNS);
benchmarker_fbk_relayout.set_times(RUNS);
auto used = benchmarker_relayout.exec(tensor_case) / RUNS;
auto fbk_used = benchmarker_fbk_relayout.exec(tensor_case) / RUNS;
float bw = 2.f * m * n * 1e-6 / used * dtype.size();
float fbk_bw = 2.f * m * n * 1e-6 / fbk_used * dtype.size();
printf("run: %s -> %s , %f GB/s, fbk %f GB/s, speedup %f\n",
src.to_string().c_str(), dst.to_string().c_str(), bw, fbk_bw, bw / fbk_bw);
}
TEST_F(AARCH64, BENCHMARK_Relayout_2) {
constexpr size_t WARM_RUNS = 100;
constexpr size_t RUNS = 600;
auto dtype = dtype::Float32();
Benchmarker<Relayout> benchmarker_relayout(handle());
Benchmarker<Relayout> benchmarker_fbk_relayout(fallback_handle());
benchmarker_relayout.set_times(WARM_RUNS);
benchmarker_fbk_relayout.set_times(WARM_RUNS);
int m = 54;
int n = 28762;
TensorLayout src({1, 54, 112, 256}, {54, 1, 16384, 64}, dtype);
TensorLayout dst({1, 54, 112, 256}, {1548288, 28672, 256, 1}, dtype);
TensorLayoutArray tensor_case;
tensor_case.push_back(src);
tensor_case.push_back(dst);
benchmarker_relayout.exec(tensor_case);
benchmarker_fbk_relayout.exec(tensor_case);
benchmarker_relayout.set_times(RUNS);
benchmarker_fbk_relayout.set_times(RUNS);
auto used = benchmarker_relayout.exec(tensor_case) / RUNS;
auto fbk_used = benchmarker_fbk_relayout.exec(tensor_case) / RUNS;
float bw = 2.f * m * n * 1e-6 / used * dtype.size();
float fbk_bw = 2.f * m * n * 1e-6 / fbk_used * dtype.size();
printf("run: %s -> %s , %f GB/s, fbk %f GB/s, speedup %f\n",
src.to_string().c_str(), dst.to_string().c_str(), bw, fbk_bw, bw / fbk_bw);
}
#endif
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
...@@ -180,11 +180,11 @@ TEST(RELAYOUT, TRANSPOSE_DET) { ...@@ -180,11 +180,11 @@ TEST(RELAYOUT, TRANSPOSE_DET) {
ASSERT_EQ(p_get.c, p.c); ASSERT_EQ(p_get.c, p.c);
} }
}; };
run({2, 3}, {1, 0}, true, {1, 2, 3, 1}); run({2, 3}, {1, 0}, true, {1, 2, 3, 1, 0});
run({2, 3, 5}, {1, 0, 2}, true, {1, 2, 3, 5}); run({2, 3, 5}, {1, 0, 2}, true, {1, 2, 3, 5, 0});
run({2, 3, 5}, {0, 2, 1}, true, {2, 3, 5, 1}); run({2, 3, 5}, {0, 2, 1}, true, {2, 3, 5, 1, 0});
run({3, 2, 3, 5}, {0, 2, 1, 3}, true, {3, 2, 3, 5}); run({3, 2, 3, 5}, {0, 2, 1, 3}, true, {3, 2, 3, 5, 0});
run({3, 2, 3, 5}, {0, 1, 3, 2}, true, {6, 3, 5, 1}); run({3, 2, 3, 5}, {0, 1, 3, 2}, true, {6, 3, 5, 1, 0});
run({2, 3, 5}, {2, 1, 0}, false); run({2, 3, 5}, {2, 1, 0}, false);
run({3, 2, 3, 5}, {3, 2, 1, 0}, false); run({3, 2, 3, 5}, {3, 2, 1, 0}, false);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册