Commit 2ae9fdef authored by Megvii Engine Team, committed by Xinran Xu

feat(dnn/arm): add arm common nchw44 avg pooling

GitOrigin-RevId: 25eab33e14ef4000480acea17a49e4360d42bdd7
Parent 0293d58a
This diff is collapsed.
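Editor's note: the new average-pooling kernels below all share one arithmetic convention: int8 inputs are widened and accumulated in int16, then divided by the window area with round-half-away-from-zero semantics. A minimal scalar sketch of that division (the helper name is ours, not from the patch):

#include <cstdint>

// Matches the sum01_avg/sum23_avg/do_avg macros in the kernels below:
// bias the sum by half the window area toward its own sign, then let
// the truncating integer division round it half away from zero.
static inline int8_t rounded_avg(int16_t sum, int16_t filter_size) {
    int16_t biased = sum > 0 ? sum + filter_size / 2
                             : sum - filter_size / 2;
    return static_cast<int8_t>(biased / filter_size);
}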
......@@ -83,34 +83,34 @@ public:
void exec(const PoolingKernParam& param) const override;
};
class PoolingImpl::AlgoFilter3MaxStridexNCHW44 final : public AlgoBase {
class PoolingImpl::AlgoFilter3ModexStridexNCHW44 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "ARM_POOLING_FILTER3_MAX_STRIDEX_NCHW44"; }
const char* name() const override { return "ARM_POOLING_FILTER3_MODEX_STRIDEX_NCHW44"; }
bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override;
};
class PoolingImpl::AlgoFilter2MaxStridexNCHW44 final : public AlgoBase {
class PoolingImpl::AlgoFilter2ModexStridexNCHW44 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "ARM_POOLING_FILTER2_MAX_STRIDEX_NCHW44"; }
const char* name() const override { return "ARM_POOLING_FILTER2_MODEX_STRIDEX_NCHW44"; }
bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override;
};
class PoolingImpl::AlgoFilter4MaxStridexNCHW44 final : public AlgoBase {
class PoolingImpl::AlgoFilter4ModexStridexNCHW44 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "ARM_POOLING_FILTER4_MAX_STRIDEX_NCHW44"; }
const char* name() const override { return "ARM_POOLING_FILTER4_MODEX_STRIDEX_NCHW44"; }
bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override;
};
class PoolingImpl::AlgoFilter5MaxStridexNCHW44 final : public AlgoBase {
class PoolingImpl::AlgoFilter5ModexStridexNCHW44 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "ARM_POOLING_FILTER5_MAX_STRIDEX_NCHW44"; }
const char* name() const override { return "ARM_POOLING_FILTER5_MODEX_STRIDEX_NCHW44"; }
bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override;
};
......@@ -122,7 +122,7 @@ WorkspaceBundle get_bundle_nchw44(
const int8_t* handle_padding(const int8_t* src, size_t IH, size_t IW,
size_t& IH2, size_t& IW2, size_t PH, size_t PW,
const WorkspaceBundle& ws);
const WorkspaceBundle& ws, bool is_max_mode);
} // namespace arm_common
} // namespace megdnn
......
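handle_padding now takes an is_max_mode flag so the two pooling modes can fill the padded border differently. Its implementation sits in a collapsed hunk, so the fill values below are an assumption rather than quoted code; they follow the usual convention:

// Assumed behaviour of the new flag: max pooling pads with the int8
// minimum so a padded cell can never win the max, while average
// pooling pads with zero so the border does not bias the sum.
int8_t pad_value = is_max_mode ? INT8_MIN : 0;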
/**
* \file dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.cpp
* \file dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......@@ -24,7 +24,7 @@ void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
const WorkspaceBundle& ws) {
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws);
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true);
size_t oh = 0;
for (; oh < OH; ++oh) {
size_t ih = oh;
......@@ -70,7 +70,7 @@ void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
const WorkspaceBundle& ws) {
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws);
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true);
size_t oh = 0;
for (; oh < OH; ++oh) {
size_t ih = oh << 1;
......@@ -120,6 +120,206 @@ void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
}
}
void do_avg_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW,
size_t OH, size_t OW,
size_t PH, size_t PW,
const WorkspaceBundle& ws) {
int16_t filter_size = 4;
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false);
size_t oh = 0;
for (; oh < OH; ++oh) {
size_t ih = oh;
const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4;
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4;
int8_t* __restrict dptr = dst + oh * OW * 4;
size_t ow = 0;
for (; ow + 3 < OW; ow += 4) {
int8x16_t src0123, src1234;
int16x8_t src01, src23, src12, src34;
int16x8_t sum01 = vdupq_n_s16(0);
int16x8_t sum23 = vdupq_n_s16(0);
#define CACULATE_ROW(i) \
src0123 = vld1q_s8(sptr##i); \
src1234 = vld1q_s8(sptr##i + 4); \
src01 = vmovl_s8(vget_low_s8(src0123)); \
src23 = vmovl_s8(vget_high_s8(src0123)); \
src12 = vmovl_s8(vget_low_s8(src1234)); \
src34 = vmovl_s8(vget_high_s8(src1234)); \
sum01 = vaddq_s16(sum01, src01); \
sum01 = vaddq_s16(sum01, src12); \
sum23 = vaddq_s16(sum23, src23); \
sum23 = vaddq_s16(sum23, src34);
UNROLL_CALL_NOWRAPPER(2, CACULATE_ROW)
#define sum_define(i) int16_t sum##i;
UNROLL_CALL_NOWRAPPER(8, sum_define)
#define sum01_avg(i) \
sum##i = vgetq_lane_s16(sum01, i) > 0 \
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \
filter_size;
#define sum23_avg(i) \
sum##i = vgetq_lane_s16(sum23, i) > 0 \
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \
filter_size;
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i);
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER(8, sum01_avg)
UNROLL_CALL_NOWRAPPER(8, store_sum01)
UNROLL_CALL_NOWRAPPER(8, sum23_avg)
UNROLL_CALL_NOWRAPPER(8, store_sum23)
sptr0 += 16;
sptr1 += 16;
dptr += 16;
#undef store_sum01
#undef store_sum23
#undef sum01_avg
#undef sum23_avg
#undef sum_define
#undef CACULATE_ROW
}
for (; ow < OW; ++ow) {
int8x8_t src001 = vld1_s8(sptr0);
int8x8_t src101 = vld1_s8(sptr1);
int16x8_t src00 = vmovl_s8(src001);
int16x8_t src10 = vmovl_s8(src101);
int16x8_t max_tmp = vaddq_s16(src00, src10);
#define do_acc(i) \
int16_t sum##i = \
vgetq_lane_s16(max_tmp, i) + vgetq_lane_s16(max_tmp, i + 4);
#define do_avg(i) \
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \
: (sum##i - filter_size / 2) / filter_size;
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER(4, do_acc)
UNROLL_CALL_NOWRAPPER(4, do_avg)
UNROLL_CALL_NOWRAPPER(4, store)
#undef store
#undef do_avg
#undef do_acc
sptr0 += 4;
sptr1 += 4;
dptr += 4;
}
}
}
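As a cross-check for the lane bookkeeping above, here is a scalar sketch of one output row of the 2x2 stride-1 kernel (names are illustrative; sptr0/sptr1 point at two padded input rows holding at least OW + 1 pixels):

#include <cstddef>
#include <cstdint>

// NCHW44: each pixel is 4 consecutive int8 channels.
void avg2x2_s1_row_ref(const int8_t* sptr0, const int8_t* sptr1,
                       int8_t* dptr, size_t OW) {
    for (size_t ow = 0; ow < OW; ++ow) {
        for (size_t c = 0; c < 4; ++c) {
            int16_t sum = sptr0[ow * 4 + c] + sptr0[(ow + 1) * 4 + c] +
                          sptr1[ow * 4 + c] + sptr1[(ow + 1) * 4 + c];
            // filter_size == 4, so the rounding bias is 2.
            int16_t biased = sum > 0 ? sum + 2 : sum - 2;
            dptr[ow * 4 + c] = static_cast<int8_t>(biased / 4);
        }
    }
}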
void do_avg_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW,
size_t OH, size_t OW,
size_t PH, size_t PW,
const WorkspaceBundle& ws) {
int16_t filter_size = 4;
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false);
size_t oh = 0;
for (; oh < OH; ++oh) {
size_t ih = oh << 1;
const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4;
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4;
int8_t* __restrict dptr = dst + oh * OW * 4;
size_t ow = 0;
for (; ow + 3 < OW; ow += 4) {
int32x4x2_t src_tmp;
int8x16_t src00, src04;
int32x4_t src0246, src1357;
int16x8_t src02, src13, src46, src57;
int16x8_t sum01 = vdupq_n_s16(0);
int16x8_t sum23 = vdupq_n_s16(0);
#define CACULATE_ROW(i) \
src00 = vld1q_s8(sptr##i); \
src04 = vld1q_s8(sptr##i + 4 * 4); \
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \
vreinterpretq_s32_s8(src04)); \
src0246 = src_tmp.val[0]; \
src1357 = src_tmp.val[1]; \
src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \
src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \
src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \
src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \
sum01 = vaddq_s16(sum01, src02); \
sum01 = vaddq_s16(sum01, src13); \
sum23 = vaddq_s16(sum23, src46); \
sum23 = vaddq_s16(sum23, src57);
UNROLL_CALL_NOWRAPPER(2, CACULATE_ROW)
#define sum_define(i) int16_t sum##i;
UNROLL_CALL_NOWRAPPER(8, sum_define)
#define sum01_avg(i) \
sum##i = vgetq_lane_s16(sum01, i) > 0 \
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \
filter_size;
#define sum23_avg(i) \
sum##i = vgetq_lane_s16(sum23, i) > 0 \
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \
filter_size;
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i);
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER(8, sum01_avg)
UNROLL_CALL_NOWRAPPER(8, store_sum01)
UNROLL_CALL_NOWRAPPER(8, sum23_avg)
UNROLL_CALL_NOWRAPPER(8, store_sum23)
sptr0 += 32;
sptr1 += 32;
dptr += 16;
#undef store_sum01
#undef store_sum23
#undef sum01_avg
#undef sum23_avg
#undef sum_define
#undef CACULATE_ROW
}
for (; ow < OW; ++ow) {
int8x8_t src001 = vld1_s8(sptr0);
int8x8_t src101 = vld1_s8(sptr1);
int16x8_t src00 = vmovl_s8(src001);
int16x8_t src10 = vmovl_s8(src101);
int16x8_t max_tmp = vaddq_s16(src00, src10);
#define do_acc(i) \
int16_t sum##i = \
vgetq_lane_s16(max_tmp, i) + vgetq_lane_s16(max_tmp, i + 4);
#define do_avg(i) \
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \
: (sum##i - filter_size / 2) / filter_size;
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER(4, do_acc)
UNROLL_CALL_NOWRAPPER(4, do_avg)
UNROLL_CALL_NOWRAPPER(4, store)
#undef do_avg
#undef do_acc
#undef store
sptr0 += 8;
sptr1 += 8;
dptr += 4;
}
}
}
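The stride-2 variants lean on one NEON trick worth spelling out: an NCHW44 pixel is exactly one 32-bit lane, so de-interleaving two adjacent 16-byte loads with vuzpq_s32 splits even and odd pixels in a single step. A standalone sketch (function and variable names are ours):

#include <arm_neon.h>

// Split eight consecutive NCHW44 pixels into even/odd pixel vectors:
// val[0] = pixels 0,2,4,6 (left taps of four stride-2 windows),
// val[1] = pixels 1,3,5,7 (right taps).
static inline int32x4x2_t split_even_odd(const int8_t* sptr0) {
    int8x16_t p0123 = vld1q_s8(sptr0);       // pixels 0,1,2,3
    int8x16_t p4567 = vld1q_s8(sptr0 + 16);  // pixels 4,5,6,7
    return vuzpq_s32(vreinterpretq_s32_s8(p0123),
                     vreinterpretq_s32_s8(p4567));
}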
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.h
* \file dnn/src/arm_common/pooling/do_pooling_2x2_nchw44.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......@@ -15,16 +15,15 @@
namespace megdnn {
namespace arm_common {
void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW,
size_t OH, size_t OW,
size_t PH, size_t PW,
const WorkspaceBundle& ws);
void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW,
size_t OH, size_t OW,
size_t PH, size_t PW,
const WorkspaceBundle& ws);
#define KERN(mode, stride, ctype) \
void do_##mode##_pooling_2x2_##stride##_##ctype##_nchw44_NEON( \
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \
size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws);
KERN(max, stride1, int8)
KERN(max, stride2, int8)
KERN(avg, stride1, int8)
KERN(avg, stride2, int8)
#undef KERN
} // namespace arm_common
} // namespace megdnn
......
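For reference, each KERN line above is shorthand for a declaration like the following (the literal expansion of KERN(avg, stride1, int8)):

void do_avg_pooling_2x2_stride1_int8_nchw44_NEON(
        const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH,
        size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws);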
/**
* \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.cpp
* \file dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......@@ -24,7 +24,7 @@ void do_max_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
const WorkspaceBundle& ws) {
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws);
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true);
size_t oh = 0;
for (; oh < OH; ++oh) {
size_t ih = oh;
......@@ -99,7 +99,7 @@ void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
const WorkspaceBundle& ws) {
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws);
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true);
size_t oh = 0;
for (; oh < OH; ++oh) {
size_t ih = oh << 1;
......@@ -190,6 +190,241 @@ void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
}
}
void do_avg_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW,
size_t OH, size_t OW,
size_t PH, size_t PW,
const WorkspaceBundle& ws) {
int16_t filter_size = 9;
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false);
size_t oh = 0;
for (; oh < OH; ++oh) {
size_t ih = oh;
const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4;
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4;
const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4;
int8_t* __restrict dptr = dst + oh * OW * 4;
size_t ow = 0;
for (; ow + 3 < OW; ow += 4) {
int8x16_t src0123, src1234, src2345;
int16x8_t src01, src23, src12, src34, src45;
int16x8_t sum01 = vdupq_n_s16(0);
int16x8_t sum23 = vdupq_n_s16(0);
#define CACULATE_ROW(i) \
src0123 = vld1q_s8(sptr##i); \
src1234 = vld1q_s8(sptr##i + 4); \
src2345 = vld1q_s8(sptr##i + 8); \
src01 = vmovl_s8(vget_low_s8(src0123)); \
src23 = vmovl_s8(vget_high_s8(src0123)); \
src12 = vmovl_s8(vget_low_s8(src1234)); \
src34 = vmovl_s8(vget_high_s8(src1234)); \
src45 = vmovl_s8(vget_high_s8(src2345)); \
sum01 = vaddq_s16(sum01, src01); \
sum01 = vaddq_s16(sum01, src12); \
sum01 = vaddq_s16(sum01, src23); \
sum23 = vaddq_s16(sum23, src23); \
sum23 = vaddq_s16(sum23, src34); \
sum23 = vaddq_s16(sum23, src45);
UNROLL_CALL_NOWRAPPER(3, CACULATE_ROW)
#define sum_define(i) int16_t sum##i;
UNROLL_CALL_NOWRAPPER(8, sum_define)
#define sum01_avg(i) \
sum##i = vgetq_lane_s16(sum01, i) > 0 \
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \
filter_size;
#define sum23_avg(i) \
sum##i = vgetq_lane_s16(sum23, i) > 0 \
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \
filter_size;
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i);
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER(8, sum01_avg)
UNROLL_CALL_NOWRAPPER(8, store_sum01)
UNROLL_CALL_NOWRAPPER(8, sum23_avg)
UNROLL_CALL_NOWRAPPER(8, store_sum23)
sptr0 += 16;
sptr1 += 16;
sptr2 += 16;
dptr += 16;
#undef store_sum01
#undef store_sum23
#undef sum01_avg
#undef sum23_avg
#undef sum_define
#undef CACULATE_ROW
}
for (; ow < OW; ++ow) {
int8x8_t src001, src012;
int16x8_t src01, src12, sum01, sum02;
sum01 = vdupq_n_s16(0);
sum02 = vdupq_n_s16(0);
#define CACULATE_ROW(i) \
src001 = vld1_s8(sptr##i); \
src012 = vld1_s8(sptr##i + 4); \
src01 = vmovl_s8(src001); \
src12 = vmovl_s8(src012); \
sum01 = vaddq_s16(sum01, src01); \
sum02 = vaddq_s16(sum02, src12);
UNROLL_CALL_NOWRAPPER(3, CACULATE_ROW)
#define do_acc(i) \
int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4) + \
vgetq_lane_s16(sum02, i + 4);
#define do_avg(i) \
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \
: (sum##i - filter_size / 2) / filter_size;
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER(4, do_acc)
UNROLL_CALL_NOWRAPPER(4, do_avg)
UNROLL_CALL_NOWRAPPER(4, store)
#undef store
#undef do_avg
#undef do_acc
#undef CACULATE_ROW
sptr0 += 4;
sptr1 += 4;
sptr2 += 4;
dptr += 4;
}
}
}
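The per-pixel tail loop above packs the three window columns into two vectors; the lane arithmetic is compact enough to deserve a note (a description of the code above, not an addition to it):

// sum01 lanes 0-3 hold column 0 and lanes 4-7 hold column 1, each
// already summed over the three rows; sum02 lanes 4-7 hold column 2
// the same way. Channel c of the output window is therefore
//     sum01[c] + sum01[c + 4] + sum02[c + 4]
// which is exactly what the do_acc(i) macro computes before do_avg
// applies the rounded division by 9.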
void do_avg_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW,
size_t OH, size_t OW,
size_t PH, size_t PW,
const WorkspaceBundle& ws) {
int16_t filter_size = 9;
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false);
size_t oh = 0;
for (; oh < OH; ++oh) {
size_t ih = oh << 1;
const int8_t* sptr0 = sptr + (ih + 0) * IW2 * 4;
const int8_t* sptr1 = sptr + (ih + 1) * IW2 * 4;
const int8_t* sptr2 = sptr + (ih + 2) * IW2 * 4;
int8_t* __restrict dptr = dst + oh * OW * 4;
size_t ow = 0;
for (; ow + 3 < OW; ow += 4) {
int32x4x2_t src_tmp;
int8x16_t src00, src04;
int32x4_t src0246, src1357, src2468, src08;
int16x8_t src02, src46, src13, src57, src24, src68;
int16x8_t sum01 = vdupq_n_s16(0);
int16x8_t sum23 = vdupq_n_s16(0);
#define CACULATE_ROW(i) \
src00 = vld1q_s8(sptr##i); \
src04 = vld1q_s8(sptr##i + 4 * 4); \
src08 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 8)); \
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \
vreinterpretq_s32_s8(src04)); \
src0246 = src_tmp.val[0]; \
src1357 = src_tmp.val[1]; \
src2468 = vextq_s32(src0246, src08, 1); \
src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \
src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \
src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \
src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \
src24 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src2468))); \
src68 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src2468))); \
sum01 = vaddq_s16(sum01, src02); \
sum01 = vaddq_s16(sum01, src13); \
sum01 = vaddq_s16(sum01, src24); \
sum23 = vaddq_s16(sum23, src46); \
sum23 = vaddq_s16(sum23, src57); \
sum23 = vaddq_s16(sum23, src68);
UNROLL_CALL_NOWRAPPER(3, CACULATE_ROW)
#define sum_define(i) int16_t sum##i;
UNROLL_CALL_NOWRAPPER(8, sum_define)
#define sum01_avg(i) \
sum##i = vgetq_lane_s16(sum01, i) > 0 \
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \
filter_size;
#define sum23_avg(i) \
sum##i = vgetq_lane_s16(sum23, i) > 0 \
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \
filter_size;
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i);
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER(8, sum01_avg)
UNROLL_CALL_NOWRAPPER(8, store_sum01)
UNROLL_CALL_NOWRAPPER(8, sum23_avg)
UNROLL_CALL_NOWRAPPER(8, store_sum23)
sptr0 += 32;
sptr1 += 32;
sptr2 += 32;
dptr += 16;
#undef store_sum01
#undef store_sum23
#undef sum01_avg
#undef sum23_avg
#undef sum_define
#undef CACULATE_ROW
}
for (; ow < OW; ++ow) {
int8x8_t src001, src012;
int16x8_t src01, src12, sum01, sum02;
sum01 = vdupq_n_s16(0);
sum02 = vdupq_n_s16(0);
#define CACULATE_ROW(i) \
src001 = vld1_s8(sptr##i); \
src012 = vld1_s8(sptr##i + 4); \
src01 = vmovl_s8(src001); \
src12 = vmovl_s8(src012); \
sum01 = vaddq_s16(sum01, src01); \
sum02 = vaddq_s16(sum02, src12);
UNROLL_CALL_NOWRAPPER(3, CACULATE_ROW)
#define do_acc(i) \
int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4) + \
vgetq_lane_s16(sum02, i + 4);
#define do_avg(i) \
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \
: (sum##i - filter_size / 2) / filter_size;
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER(4, do_acc)
UNROLL_CALL_NOWRAPPER(4, do_avg)
UNROLL_CALL_NOWRAPPER(4, store)
#undef store
#undef do_avg
#undef do_acc
sptr0 += 8;
sptr1 += 8;
sptr2 += 8;
dptr += 4;
}
}
}
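The stride-2 3x3 kernel adds one more lane trick on top of the vuzpq_s32 split: vextq_s32 slides the even-pixel vector by one lane to manufacture the centre taps (again a description of the code above):

// With src0246 = {p0,p2,p4,p6} and src08 = {p8,p8,p8,p8} (a duplicated
// single-pixel load), vextq_s32(src0246, src08, 1) = {p2,p4,p6,p8}:
// the centre column of four stride-2 windows, with no third gather.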
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h
* \file dnn/src/arm_common/pooling/do_pooling_3x3_nchw44.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......@@ -15,16 +15,15 @@
namespace megdnn {
namespace arm_common {
void do_max_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW,
size_t OH, size_t OW,
size_t PH, size_t PW,
const WorkspaceBundle& ws);
void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW, size_t OH,
size_t OW, size_t PH, size_t PW,
const WorkspaceBundle& ws);
#define KERN(mode, stride, ctype) \
void do_##mode##_pooling_3x3_##stride##_##ctype##_nchw44_NEON( \
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \
size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws);
KERN(max, stride1, int8)
KERN(max, stride2, int8)
KERN(avg, stride1, int8)
KERN(avg, stride2, int8)
#undef KERN
} // namespace arm_common
} // namespace megdnn
......
......@@ -24,7 +24,7 @@ void do_max_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
const WorkspaceBundle& ws) {
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws);
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true);
size_t oh = 0;
for (; oh < OH; ++oh) {
size_t ih = oh;
......@@ -99,7 +99,7 @@ void do_max_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
const WorkspaceBundle& ws) {
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws);
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true);
size_t oh = 0;
for (; oh < OH; ++oh) {
size_t ih = oh << 1;
......@@ -171,6 +171,252 @@ void do_max_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
}
}
void do_avg_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW,
size_t OH, size_t OW,
size_t PH, size_t PW,
const WorkspaceBundle& ws) {
int16_t filter_size = 16;
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false);
size_t oh = 0;
for (; oh < OH; ++oh) {
size_t ih = oh;
const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4;
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4;
const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4;
const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW2 * 4;
int8_t* __restrict dptr = dst + oh * OW * 4;
size_t ow = 0;
for (; ow + 3 < OW; ow += 4) {
int16x8_t src01, src23, src12, src34, src45, src56;
int16x8_t sum01 = vdupq_n_s16(0);
int16x8_t sum23 = vdupq_n_s16(0);
#define CACULATE_ROW(i) \
src01 = vmovl_s8(vld1_s8(sptr##i)); \
src23 = vmovl_s8(vld1_s8(sptr##i + 8)); \
src12 = vmovl_s8(vld1_s8(sptr##i + 4)); \
src34 = vmovl_s8(vld1_s8(sptr##i + 12)); \
src45 = vmovl_s8(vld1_s8(sptr##i + 16)); \
src56 = vmovl_s8(vld1_s8(sptr##i + 20)); \
sum01 = vaddq_s16(sum01, src01); \
sum01 = vaddq_s16(sum01, src12); \
sum01 = vaddq_s16(sum01, src23); \
sum01 = vaddq_s16(sum01, src34); \
sum23 = vaddq_s16(sum23, src23); \
sum23 = vaddq_s16(sum23, src34); \
sum23 = vaddq_s16(sum23, src45); \
sum23 = vaddq_s16(sum23, src56);
UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW)
#define sum_define(i) int16_t sum##i;
UNROLL_CALL_NOWRAPPER(8, sum_define)
#define sum01_avg(i) \
sum##i = vgetq_lane_s16(sum01, i) > 0 \
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \
filter_size;
#define sum23_avg(i) \
sum##i = vgetq_lane_s16(sum23, i) > 0 \
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \
filter_size;
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i);
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER(8, sum01_avg)
UNROLL_CALL_NOWRAPPER(8, store_sum01)
UNROLL_CALL_NOWRAPPER(8, sum23_avg)
UNROLL_CALL_NOWRAPPER(8, store_sum23)
sptr0 += 16;
sptr1 += 16;
sptr2 += 16;
sptr3 += 16;
dptr += 16;
#undef store_sum01
#undef store_sum23
#undef sum01_avg
#undef sum23_avg
#undef sum_define
#undef CACULATE_ROW
}
for (; ow < OW; ++ow) {
int16x8_t src01, src23, sum01;
sum01 = vdupq_n_s16(0);
#define CACULATE_ROW(i) \
src01 = vmovl_s8(vld1_s8(sptr##i)); \
src23 = vmovl_s8(vld1_s8(sptr##i + 8)); \
sum01 = vaddq_s16(sum01, src01); \
sum01 = vaddq_s16(sum01, src23);
UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW)
#define do_acc(i) \
int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4);
#define do_avg(i) \
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \
: (sum##i - filter_size / 2) / filter_size;
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER(4, do_acc)
UNROLL_CALL_NOWRAPPER(4, do_avg)
UNROLL_CALL_NOWRAPPER(4, store)
#undef store
#undef do_avg
#undef do_acc
#undef CACULATE_ROW
sptr0 += 4;
sptr1 += 4;
sptr2 += 4;
sptr3 += 4;
dptr += 4;
}
}
}
void do_avg_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW,
size_t OH, size_t OW,
size_t PH, size_t PW,
const WorkspaceBundle& ws) {
int16_t filter_size = 16;
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false);
size_t oh = 0;
for (; oh < OH; ++oh) {
size_t ih = oh << 1;
const int8_t* sptr0 = sptr + (ih + 0) * IW2 * 4;
const int8_t* sptr1 = sptr + (ih + 1) * IW2 * 4;
const int8_t* sptr2 = sptr + (ih + 2) * IW2 * 4;
const int8_t* sptr3 = sptr + (ih + 3) * IW2 * 4;
int8_t* __restrict dptr = dst + oh * OW * 4;
size_t ow = 0;
for (; ow + 3 < OW; ow += 4) {
int32x4x2_t src_tmp;
int8x16_t src00, src04;
int16x8_t src02, src13, src57, src24, src68, src35, src79, src46;
int32x4_t src08, src09, src0246, src1357, src2468, src3579;
int16x8_t sum01 = vdupq_n_s16(0);
int16x8_t sum23 = vdupq_n_s16(0);
#define CACULATE_ROW(i) \
src00 = vld1q_s8(sptr##i); \
src04 = vld1q_s8(sptr##i + 4 * 4); \
src08 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 8)); \
src09 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 9)); \
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \
vreinterpretq_s32_s8(src04)); \
src0246 = src_tmp.val[0]; \
src1357 = src_tmp.val[1]; \
src2468 = vextq_s32(src0246, src08, 1); \
src3579 = vextq_s32(src1357, src09, 1); \
src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \
src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \
src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \
src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \
src24 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src2468))); \
src68 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src2468))); \
src35 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src3579))); \
src79 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src3579))); \
sum01 = vaddq_s16(sum01, src02); \
sum01 = vaddq_s16(sum01, src13); \
sum01 = vaddq_s16(sum01, src24); \
sum01 = vaddq_s16(sum01, src35); \
sum23 = vaddq_s16(sum23, src46); \
sum23 = vaddq_s16(sum23, src57); \
sum23 = vaddq_s16(sum23, src68); \
sum23 = vaddq_s16(sum23, src79);
UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW)
#define sum_define(i) int16_t sum##i;
UNROLL_CALL_NOWRAPPER(8, sum_define)
#define sum01_avg(i) \
sum##i = vgetq_lane_s16(sum01, i) > 0 \
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \
filter_size;
#define sum23_avg(i) \
sum##i = vgetq_lane_s16(sum23, i) > 0 \
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \
filter_size;
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i);
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER(8, sum01_avg)
UNROLL_CALL_NOWRAPPER(8, store_sum01)
UNROLL_CALL_NOWRAPPER(8, sum23_avg)
UNROLL_CALL_NOWRAPPER(8, store_sum23)
sptr0 += 32;
sptr1 += 32;
sptr2 += 32;
sptr3 += 32;
dptr += 16;
#undef store_sum01
#undef store_sum23
#undef sum01_avg
#undef sum23_avg
#undef sum_define
#undef CACULATE_ROW
}
for (; ow < OW; ++ow) {
int8x8_t src001, src023;
int16x8_t src01, src23, sum01;
sum01 = vdupq_n_s16(0);
#define CACULATE_ROW(i) \
src001 = vld1_s8(sptr##i); \
src023 = vld1_s8(sptr##i + 8); \
src01 = vmovl_s8(src001); \
src23 = vmovl_s8(src023); \
sum01 = vaddq_s16(sum01, src01); \
sum01 = vaddq_s16(sum01, src23);
UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW)
#define do_acc(i) \
int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4);
#define do_avg(i) \
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \
: (sum##i - filter_size / 2) / filter_size;
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER(4, do_acc)
UNROLL_CALL_NOWRAPPER(4, do_avg)
UNROLL_CALL_NOWRAPPER(4, store)
#undef store
#undef do_avg
#undef do_acc
#undef CACULATE_ROW
sptr0 += 8;
sptr1 += 8;
sptr2 += 8;
sptr3 += 8;
dptr += 4;
}
}
}
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.h
* \file dnn/src/arm_common/pooling/do_pooling_4x4_nchw44.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......@@ -15,15 +15,16 @@
namespace megdnn {
namespace arm_common {
#define KERN(strdie) \
void do_max_pooling_4x4_##strdie##_int8_nchw44_NEON( \
#define KERN(mode, stride, ctype) \
void do_##mode##_pooling_4x4_##stride##_##ctype##_nchw44_NEON( \
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \
size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws);
KERN(stride1)
KERN(stride2)
KERN(max, stride1, int8)
KERN(max, stride2, int8)
KERN(avg, stride1, int8)
KERN(avg, stride2, int8)
#undef KERN
} // namespace arm_common
} // namespace megdnn
......
/**
* \file dnn/src/arm_common/pooling/do_max_pooling_5x5_nchw44.cpp
* \file dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......@@ -24,7 +24,7 @@ void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
const WorkspaceBundle& ws) {
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws);
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true);
size_t oh = 0;
for (; oh < OH; ++oh) {
size_t ih = oh;
......@@ -118,7 +118,7 @@ void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
const WorkspaceBundle& ws) {
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws);
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, true);
size_t oh = 0;
for (; oh < OH; ++oh) {
size_t ih = oh << 1;
......@@ -213,6 +213,284 @@ void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
}
}
void do_avg_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW,
size_t OH, size_t OW,
size_t PH, size_t PW,
const WorkspaceBundle& ws) {
int16_t filter_size = 25;
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false);
size_t oh = 0;
for (; oh < OH; ++oh) {
size_t ih = oh;
const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4;
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4;
const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4;
const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW2 * 4;
const int8_t* __restrict sptr4 = sptr + (ih + 4) * IW2 * 4;
int8_t* __restrict dptr = dst + oh * OW * 4;
size_t ow = 0;
for (; ow + 3 < OW; ow += 4) {
int16x8_t src01, src23, src12, src34, src45, src56, src67;
int16x8_t sum01 = vdupq_n_s16(0);
int16x8_t sum23 = vdupq_n_s16(0);
#define CACULATE_ROW(i) \
src01 = vmovl_s8(vld1_s8(sptr##i)); \
src23 = vmovl_s8(vld1_s8(sptr##i + 8)); \
src12 = vmovl_s8(vld1_s8(sptr##i + 4)); \
src34 = vmovl_s8(vld1_s8(sptr##i + 12)); \
src45 = vmovl_s8(vld1_s8(sptr##i + 16)); \
src56 = vmovl_s8(vld1_s8(sptr##i + 20)); \
src67 = vmovl_s8(vld1_s8(sptr##i + 24)); \
sum01 = vaddq_s16(sum01, src01); \
sum01 = vaddq_s16(sum01, src12); \
sum01 = vaddq_s16(sum01, src23); \
sum01 = vaddq_s16(sum01, src34); \
sum01 = vaddq_s16(sum01, src45); \
sum23 = vaddq_s16(sum23, src23); \
sum23 = vaddq_s16(sum23, src34); \
sum23 = vaddq_s16(sum23, src45); \
sum23 = vaddq_s16(sum23, src56); \
sum23 = vaddq_s16(sum23, src67);
UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW)
#define sum_define(i) int16_t sum##i;
UNROLL_CALL_NOWRAPPER(8, sum_define)
#define sum01_avg(i) \
sum##i = vgetq_lane_s16(sum01, i) > 0 \
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \
filter_size;
#define sum23_avg(i) \
sum##i = vgetq_lane_s16(sum23, i) > 0 \
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \
filter_size;
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i);
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER(8, sum01_avg)
UNROLL_CALL_NOWRAPPER(8, store_sum01)
UNROLL_CALL_NOWRAPPER(8, sum23_avg)
UNROLL_CALL_NOWRAPPER(8, store_sum23)
sptr0 += 16;
sptr1 += 16;
sptr2 += 16;
sptr3 += 16;
sptr4 += 16;
dptr += 16;
#undef store_sum01
#undef store_sum23
#undef sum01_avg
#undef sum23_avg
#undef sum_define
#undef CACULATE_ROW
}
for (; ow < OW; ++ow) {
int32x2_t src004;
int8x8_t src001, src023;
int16x8_t src01, src23, src04, sum01, sum02;
sum01 = vdupq_n_s16(0);
sum02 = vdupq_n_s16(0);
#define CACULATE_ROW(i) \
src001 = vld1_s8(sptr##i); \
src023 = vld1_s8(sptr##i + 8); \
src004 = vld1_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 4)); \
src01 = vmovl_s8(src001); \
src23 = vmovl_s8(src023); \
src04 = vmovl_s8(vreinterpret_s8_s32(src004)); \
sum01 = vaddq_s16(sum01, src01); \
sum01 = vaddq_s16(sum01, src23); \
sum02 = vaddq_s16(sum02, src04);
UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW)
#define do_acc(i) \
int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4) + \
vgetq_lane_s16(sum02, i + 4);
#define do_avg(i) \
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \
: (sum##i - filter_size / 2) / filter_size;
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER(4, do_acc)
UNROLL_CALL_NOWRAPPER(4, do_avg)
UNROLL_CALL_NOWRAPPER(4, store)
#undef store
#undef do_avg
#undef do_acc
#undef CACULATE_ROW
sptr0 += 4;
sptr1 += 4;
sptr2 += 4;
sptr3 += 4;
sptr4 += 4;
dptr += 4;
}
}
}
void do_avg_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW,
size_t OH, size_t OW,
size_t PH, size_t PW,
const WorkspaceBundle& ws) {
int16_t filter_size = 25;
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws, false);
size_t oh = 0;
for (; oh < OH; ++oh) {
size_t ih = oh << 1;
const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4;
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4;
const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4;
const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW2 * 4;
const int8_t* __restrict sptr4 = sptr + (ih + 4) * IW2 * 4;
int8_t* __restrict dptr = dst + oh * OW * 4;
size_t ow = 0;
for (; ow + 3 < OW; ow += 4) {
int32x4x2_t src_tmp;
int8x16_t src00, src04;
int16x8_t src02, src13, src57, src24, src68, src35, src79, src46,
src810;
int32x4_t src08, src09, src10, src0246, src1357, src2468, src3579,
src46810;
int16x8_t sum01 = vdupq_n_s16(0);
int16x8_t sum23 = vdupq_n_s16(0);
#define CACULATE_ROW(i) \
src00 = vld1q_s8(sptr##i); \
src04 = vld1q_s8(sptr##i + 4 * 4); \
src08 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 8)); \
src09 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 9)); \
src10 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 10)); \
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \
vreinterpretq_s32_s8(src04)); \
src0246 = src_tmp.val[0]; \
src1357 = src_tmp.val[1]; \
src2468 = vextq_s32(src0246, src08, 1); \
src3579 = vextq_s32(src1357, src09, 1); \
src46810 = vextq_s32(src2468, src10, 1); \
src02 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src0246))); \
src46 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src0246))); \
src13 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src1357))); \
src57 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src1357))); \
src24 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src2468))); \
src68 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src2468))); \
src35 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src3579))); \
src79 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src3579))); \
src46 = vmovl_s8(vget_low_s8(vreinterpretq_s8_s32(src46810))); \
src810 = vmovl_s8(vget_high_s8(vreinterpretq_s8_s32(src46810))); \
sum01 = vaddq_s16(sum01, src02); \
sum01 = vaddq_s16(sum01, src13); \
sum01 = vaddq_s16(sum01, src24); \
sum01 = vaddq_s16(sum01, src35); \
sum01 = vaddq_s16(sum01, src46); \
sum23 = vaddq_s16(sum23, src46); \
sum23 = vaddq_s16(sum23, src57); \
sum23 = vaddq_s16(sum23, src68); \
sum23 = vaddq_s16(sum23, src79); \
sum23 = vaddq_s16(sum23, src810);
UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW)
#define sum_define(i) int16_t sum##i;
UNROLL_CALL_NOWRAPPER(8, sum_define)
#define sum01_avg(i) \
sum##i = vgetq_lane_s16(sum01, i) > 0 \
? (vgetq_lane_s16(sum01, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum01, i) - filter_size / 2) / \
filter_size;
#define sum23_avg(i) \
sum##i = vgetq_lane_s16(sum23, i) > 0 \
? (vgetq_lane_s16(sum23, i) + filter_size / 2) / \
filter_size \
: (vgetq_lane_s16(sum23, i) - filter_size / 2) / \
filter_size;
#define store_sum01(i) *(dptr + i) = static_cast<int8_t>(sum##i);
#define store_sum23(i) *(dptr + i + 8) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER(8, sum01_avg)
UNROLL_CALL_NOWRAPPER(8, store_sum01)
UNROLL_CALL_NOWRAPPER(8, sum23_avg)
UNROLL_CALL_NOWRAPPER(8, store_sum23)
sptr0 += 32;
sptr1 += 32;
sptr2 += 32;
sptr3 += 32;
sptr4 += 32;
dptr += 16;
#undef store_sum01
#undef store_sum23
#undef sum01_avg
#undef sum23_avg
#undef sum_define
#undef CACULATE_ROW
}
for (; ow < OW; ++ow) {
int32x2_t src004;
int8x8_t src001, src023;
int16x8_t src01, src23, src04, sum01, sum02;
sum01 = vdupq_n_s16(0);
sum02 = vdupq_n_s16(0);
#define CACULATE_ROW(i) \
src001 = vld1_s8(sptr##i); \
src023 = vld1_s8(sptr##i + 8); \
src004 = vld1_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 4)); \
src01 = vmovl_s8(src001); \
src23 = vmovl_s8(src023); \
src04 = vmovl_s8(vreinterpret_s8_s32(src004)); \
sum01 = vaddq_s16(sum01, src01); \
sum01 = vaddq_s16(sum01, src23); \
sum02 = vaddq_s16(sum02, src04);
UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW)
#define do_acc(i) \
int16_t sum##i = vgetq_lane_s16(sum01, i) + vgetq_lane_s16(sum01, i + 4) + \
vgetq_lane_s16(sum02, i + 4);
#define do_avg(i) \
sum##i = sum##i > 0 ? (sum##i + filter_size / 2) / filter_size \
: (sum##i - filter_size / 2) / filter_size;
#define store(i) *(dptr + i) = static_cast<int8_t>(sum##i);
UNROLL_CALL_NOWRAPPER(4, do_acc)
UNROLL_CALL_NOWRAPPER(4, do_avg)
UNROLL_CALL_NOWRAPPER(4, store)
#undef store
#undef do_avg
#undef do_acc
#undef CACULATE_ROW
sptr0 += 8;
sptr1 += 8;
sptr2 += 8;
sptr3 += 8;
sptr4 += 8;
dptr += 4;
}
}
}
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
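A note on overflow, offered as our own observation rather than a comment from the patch: even the largest window here keeps the int16 accumulators safe, since 25 taps of magnitude at most 128 sum to 3200.

#include <cstdint>

// Worst case across these kernels: the 5x5 window, 25 int8 taps.
static_assert(25 * 128 <= INT16_MAX,
              "5x5 int8 window sums fit comfortably in int16");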
/**
* \file dnn/src/arm_common/pooling/do_max_pooling_4x4_nchw44.h
* \file dnn/src/arm_common/pooling/do_pooling_5x5_nchw44.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
......@@ -15,15 +15,16 @@
namespace megdnn {
namespace arm_common {
#define KERN(strdie) \
void do_max_pooling_5x5_##strdie##_int8_nchw44_NEON( \
#define KERN(mode, stride, ctype) \
void do_##mode##_pooling_5x5_##stride##_##ctype##_nchw44_NEON( \
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \
size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws);
KERN(stride1)
KERN(stride2)
KERN(max, stride1, int8)
KERN(max, stride2, int8)
KERN(avg, stride1, int8)
KERN(avg, stride2, int8)
#undef KERN
} // namespace arm_common
} // namespace megdnn
......
......@@ -25,10 +25,10 @@ class PoolingImpl::AlgoPack : NonCopyableObj {
AlgoFilter5MaxStride2 algo_filter5_max_stride2;
AlgoInt8Filter2MaxStride2 algo_int8_filter2_max_stride2;
AlgoInt8Filter3MaxStride2 algo_int8_filter3_max_stride2;
AlgoFilter2MaxStridexNCHW44 algo_filter2_max_stridex_nchw4;
AlgoFilter3MaxStridexNCHW44 algo_filter3_max_stridex_nchw4;
AlgoFilter4MaxStridexNCHW44 algo_filter4_max_stridex_nchw4;
AlgoFilter5MaxStridexNCHW44 algo_filter5_max_stridex_nchw4;
AlgoFilter2ModexStridexNCHW44 algo_filter2_modex_stridex_nchw4;
AlgoFilter3ModexStridexNCHW44 algo_filter3_modex_stridex_nchw4;
AlgoFilter4ModexStridexNCHW44 algo_filter4_modex_stridex_nchw4;
AlgoFilter5ModexStridexNCHW44 algo_filter5_modex_stridex_nchw4;
public:
AlgoPack() {
......@@ -40,10 +40,10 @@ public:
all_algos.emplace_back(&algo_filter5_max_stride2);
all_algos.emplace_back(&algo_int8_filter2_max_stride2);
all_algos.emplace_back(&algo_int8_filter3_max_stride2);
all_algos.emplace_back(&algo_filter3_max_stridex_nchw4);
all_algos.emplace_back(&algo_filter2_max_stridex_nchw4);
all_algos.emplace_back(&algo_filter4_max_stridex_nchw4);
all_algos.emplace_back(&algo_filter5_max_stridex_nchw4);
all_algos.emplace_back(&algo_filter3_modex_stridex_nchw4);
all_algos.emplace_back(&algo_filter2_modex_stridex_nchw4);
all_algos.emplace_back(&algo_filter4_modex_stridex_nchw4);
all_algos.emplace_back(&algo_filter5_modex_stridex_nchw4);
}
SmallVector<AlgoBase*> all_algos;
};
......
......@@ -83,10 +83,10 @@ private:
class AlgoFilter5MaxStride2;
class AlgoInt8Filter2MaxStride2;
class AlgoInt8Filter3MaxStride2;
class AlgoFilter2MaxStridexNCHW44;
class AlgoFilter3MaxStridexNCHW44;
class AlgoFilter4MaxStridexNCHW44;
class AlgoFilter5MaxStridexNCHW44;
class AlgoFilter2ModexStridexNCHW44;
class AlgoFilter3ModexStridexNCHW44;
class AlgoFilter4ModexStridexNCHW44;
class AlgoFilter5ModexStridexNCHW44;
class AlgoPack;
};
} // namespace arm_common
......
......@@ -10,6 +10,7 @@
*/
#include <vector>
#include "megdnn/dtype.h"
#include "megdnn/opr_param_defs.h"
#include "test/arm_common/fixture.h"
#include "test/common/pooling.h"
......@@ -56,13 +57,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING) {
}
}
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_NCHW44)
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W3x3_NCHW44)
{
// clang-format off
for (size_t ih: {3, 5, 10})
for (size_t iw: {3, 5, 7, 9, 15, 20})
for (size_t ph: {0, 1, 2})
for (size_t pw: {0, 1, 2})
for(auto mode: {param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE})
if (ih+2*ph >= 3 && iw+2*pw >= 3)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
......@@ -71,7 +73,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_NCHW44)
checker.set_rng(0,&rng);
param::Pooling param;
param.mode = param::Pooling::Mode::MAX;
param.mode = mode;
param.format = param::Pooling::Format::NCHW44;
param.pad_h = ph;
param.pad_w = pw;
......@@ -86,13 +88,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_NCHW44)
// clang-format on
}
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_NCHW44)
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W2x2_NCHW44)
{
// clang-format off
for (size_t ih: {2, 5, 10, 17})
for (size_t iw: {2, 6, 8, 16, 26})
for (size_t ph: {0, 1})
for (size_t pw: {0, 1})
for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE})
if (ih+2*ph >= 2 && iw+2*pw >= 2)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
......@@ -101,7 +104,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_NCHW44)
checker.set_rng(0,&rng);
param::Pooling param;
param.mode = param::Pooling::Mode::MAX;
param.mode = mode;
param.format = param::Pooling::Format::NCHW44;
param.pad_h = ph;
param.pad_w = pw;
......@@ -115,13 +118,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_NCHW44)
// clang-format on
}
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_NCHW44)
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W4x4_NCHW44)
{
// clang-format off
for (size_t ih: {4, 10, 18, 25, 30})
for (size_t iw: {4, 12, 17, 20, 25})
for (size_t ph: {0, 1, 2})
for (size_t pw: {0, 1, 2})
for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE})
if (ih+2*ph >= 4 && iw+2*pw >= 4)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
......@@ -130,7 +134,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_NCHW44)
checker.set_rng(0,&rng);
param::Pooling param;
param.mode = param::Pooling::Mode::MAX;
param.mode = mode;
param.format = param::Pooling::Format::NCHW44;
param.pad_h = ph;
param.pad_w = pw;
......@@ -143,13 +147,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_NCHW44)
}
// clang-format on
}
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_NCHW44)
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_W5x5_NCHW44)
{
// clang-format off
for (size_t ih: {5, 9, 19, 20, 39})
for (size_t iw: {5, 12, 23, 27, 39})
for (size_t ph: {0, 1, 2})
for (size_t pw: {0, 1, 2})
for(auto mode: {param::Pooling::Mode::MAX,param::Pooling::Mode::AVERAGE})
if (ih+2*ph >= 5 && iw+2*pw >= 5)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
......@@ -158,7 +163,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_NCHW44)
checker.set_rng(0,&rng);
param::Pooling param;
param.mode = param::Pooling::Mode::MAX;
param.mode = mode;
param.format = param::Pooling::Format::NCHW44;
param.pad_h = ph;
param.pad_w = pw;
......@@ -477,31 +482,37 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_POOLING_NCHW44) {
std::vector<SmallVector<TensorShape>> shapes;
std::vector<std::vector<size_t>> filter_and_stride = {
{2, 1}, {2, 2}, {3, 1}, {3, 2}, {4, 1}, {4, 2}, {5, 1}, {5, 2}};
for(auto filter:filter_and_stride){
shapes.push_back({{1, 32 * 4, 215, 215}, {}});
shapes.push_back({{1, 32 * 4, 128, 128}, {}});
shapes.push_back({{1, 16 * 4, 56, 56}, {}});
param.window_h = param.window_w = filter[0];
param.stride_h = param.stride_w = filter[1];
param.format = Param::Format::NCHW;
printf("NCHW Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n", param.window_h,
param.window_h, param.stride_h, static_cast<int>(param.mode));
benchmark_impl<Pooling>(param, shapes, RUNS, {4, {4, 5, 6, 7}}, {1, {4}},
dtype::QuantizedS8(1.1f));
shapes.clear();
shapes.push_back({{1, 32, 215, 215,4}, {}});
shapes.push_back({{1, 32, 128, 128,4}, {}});
shapes.push_back({{1, 16, 56, 56, 4}, {}});
param.format = Param::Format::NCHW44;
printf("NCHW44 Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n", param.window_h,
param.window_w, param.stride_h, static_cast<int>(param.mode));
benchmark_impl<Pooling>(param, shapes, RUNS, {4, {4, 5, 6, 7}}, {1, {4}},
dtype::QuantizedS8(1.1f));
shapes.clear();
}
for (auto mode :
{param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE}) {
for (auto filter : filter_and_stride) {
shapes.push_back({{1, 32 * 4, 215, 215}, {}});
shapes.push_back({{1, 32 * 4, 128, 128}, {}});
shapes.push_back({{1, 16 * 4, 56, 56}, {}});
param.mode = mode;
param.window_h = param.window_w = filter[0];
param.stride_h = param.stride_w = filter[1];
param.format = Param::Format::NCHW;
printf("NCHW Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n",
param.window_h, param.window_w, param.stride_h,
static_cast<int>(param.mode));
benchmark_impl<Pooling>(param, shapes, RUNS, {4, {4, 5, 6, 7}},
{1, {4}}, dtype::QuantizedS8(1.1f));
shapes.clear();
shapes.push_back({{1, 32, 215, 215, 4}, {}});
shapes.push_back({{1, 32, 128, 128, 4}, {}});
shapes.push_back({{1, 16, 56, 56, 4}, {}});
param.format = Param::Format::NCHW44;
printf("NCHW44 Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n",
param.window_h, param.window_w, param.stride_h,
static_cast<int>(param.mode));
benchmark_impl<Pooling>(param, shapes, RUNS, {4, {4, 5, 6, 7}},
{1, {4}}, dtype::QuantizedS8(1.1f));
shapes.clear();
}
}
}
#endif
......