提交 df8931b6 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(dnn/arm): add padding support for nchw44 arm pooling and opt code

GitOrigin-RevId: f125004e1f271656f2c2646913aea6afdd112e15
上级 07dd6b6c
...@@ -11,14 +11,13 @@ ...@@ -11,14 +11,13 @@
*/ */
#include "src/arm_common/pooling/algo.h" #include "src/arm_common/pooling/algo.h"
#include "megdnn/opr_param_defs.h" #include "megdnn/opr_param_defs.h"
#include "src/arm_common/pooling/do_max_pooling_2x2_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_4x4_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_5x5_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h" #include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h" #include "src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h"
#include "src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h" #include "src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h"
#include "src/arm_common/pooling/do_pooling_2x2_nchw44.h"
#include "src/arm_common/pooling/do_pooling_3x3_nchw44.h"
#include "src/arm_common/pooling/do_pooling_4x4_nchw44.h"
#include "src/arm_common/pooling/do_pooling_5x5_nchw44.h"
#include "midout.h" #include "midout.h"
...@@ -57,6 +56,41 @@ WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param) { ...@@ -57,6 +56,41 @@ WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param) {
return ws; return ws;
} }
/**
 * \brief Describe the scratch space needed to hold one padded NCHW44 plane.
 *
 * Asserts that \p param describes QuantizedS8 data in NCHW44 format. When
 * either padding value is non-zero, the bundle holds a single chunk large
 * enough for (IW + 2 * PW) * (IH + 2 * PH) pixels of 4 int8 values each;
 * otherwise the chunk size is zero.
 *
 * \param param pooling size description (input size, padding, dtype, format)
 * \return a WorkspaceBundle with a null base pointer and one chunk
 */
WorkspaceBundle get_bundle_nchw44(
        const PoolingImpl::PoolingKernSizeParam& param) {
    megdnn_assert((param.src_type.enumv() == DTypeEnum::QuantizedS8) &&
                  (param.format == param::Pooling::Format::NCHW44));
    auto in_h = param.isz[0];
    auto in_w = param.isz[1];
    auto pad_h = param.padding[0];
    auto pad_w = param.padding[1];
    const bool has_padding = !((pad_h == 0) && (pad_w == 0));
    // Without padding the kernels read the source directly, so no scratch
    // space is requested.
    const size_t padding_size =
            has_padding ? (in_w + 2 * pad_w) * (in_h + 2 * pad_h) * 4 *
                                  sizeof(int8_t)
                        : 0;
    return WorkspaceBundle(nullptr, {padding_size});
}
/**
 * \brief Return a pointer to the plane a pooling kernel should read from,
 *        materialising a padded copy in the workspace when needed.
 *
 * With PH == 0 and PW == 0 the source pointer is returned unchanged and
 * IH2/IW2 are set to IH/IW. Otherwise chunk 0 of \p ws is filled with -128
 * (the smallest int8 value, so border pixels never win a max), the source
 * rows are copied into its interior, and IH2/IW2 receive the padded extents.
 *
 * \param src source plane, IH rows of IW pixels, 4 int8 values per pixel
 * \param IH,IW source height and width in pixels
 * \param[out] IH2,IW2 height and width of the plane that is returned
 * \param PH,PW padding added on each side vertically / horizontally
 * \param ws workspace whose chunk 0 receives the padded plane
 * \return \p src when no padding is requested, else the workspace buffer
 */
const int8_t* handle_padding(const int8_t* src, size_t IH, size_t IW,
                             size_t& IH2, size_t& IW2, size_t PH, size_t PW,
                             const WorkspaceBundle& ws) {
    // Fast path: nothing to pad, read straight from the source.
    if (PH == 0 && PW == 0) {
        IH2 = IH;
        IW2 = IW;
        return src;
    }

    IH2 = IH + 2 * PH;
    IW2 = IW + 2 * PW;
    int8_t* padded = static_cast<int8_t*>(ws.get(0));
    // Fill the whole padded plane with the border value first ...
    memset(padded, -128, sizeof(int8_t) * IH2 * IW2 * 4);
    // ... then overwrite the interior one source row at a time.
    for (size_t row = 0; row < IH; ++row) {
        int8_t* dst_row = padded + (row + PH) * IW2 * 4 + PW * 4;
        const int8_t* src_row = src + row * IW * 4;
        std::memcpy(dst_row, src_row, sizeof(int8_t) * IW * 4);
    }
    return padded;
}
bool PoolingImpl::AlgoFilterxModexStride1::usable( bool PoolingImpl::AlgoFilterxModexStride1::usable(
const PoolingKernSizeParam& param) const { const PoolingKernSizeParam& param) const {
auto SH = param.stride[0]; auto SH = param.stride[0];
...@@ -563,47 +597,50 @@ void PoolingImpl::AlgoInt8Filter3MaxStride2::exec( ...@@ -563,47 +597,50 @@ void PoolingImpl::AlgoInt8Filter3MaxStride2::exec(
MIDOUT_END(); MIDOUT_END();
} }
bool PoolingImpl::AlgoFilter3MaxStride2NCHW44::usable( bool PoolingImpl::AlgoFilter3MaxStridexNCHW44::usable(
const PoolingKernSizeParam& param) const { const PoolingKernSizeParam& param) const {
auto SH = param.stride[0]; auto SH = param.stride[0];
auto SW = param.stride[1]; auto SW = param.stride[1];
auto FH = param.filter[0]; auto FH = param.filter[0];
auto FW = param.filter[1]; auto FW = param.filter[1];
auto PH = param.padding[0];
auto PW = param.padding[1];
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.format == Param::Format::NCHW44 && param.format == Param::Format::NCHW44 &&
param.mode == Mode::MAX && FH == 3 && FW == 3 && SH == 2 && param.mode == Mode::MAX && FH == 3 && FW == 3 && SW == SH &&
SW == 2 && PH == 0 && PW == 0; (SH == 1 || SW == 2);
return avaible; return avaible;
} }
void PoolingImpl::AlgoFilter3MaxStride2NCHW44::exec( void PoolingImpl::AlgoFilter3MaxStridexNCHW44::exec(
const PoolingKernParam& param) const { const PoolingKernParam& param) const {
auto IH = param.isz[0], IW = param.isz[1]; auto IH = param.isz[0], IW = param.isz[1];
auto OH = param.osz[0], OW = param.osz[1]; auto OH = param.osz[0], OW = param.osz[1];
auto N = param.n, C = param.ic; auto N = param.n, C = param.ic;
auto PH = param.padding[0]; auto PH = param.padding[0];
auto PW = param.padding[1]; auto PW = param.padding[1];
auto SW = param.stride[0];
void* src_ptr = param.src_ptr; void* src_ptr = param.src_ptr;
void* dst_ptr = param.dst_ptr; void* dst_ptr = param.dst_ptr;
#define DISPATCH_FUNC(type, func, midout_type_id) \ #define DISPATCH_FUNC(type, func, i) \
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \
midout_iv(midout_type_id)) { \ midout_iv(#type #i##_hash)) { \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \ WorkspaceBundle wbundle = get_bundle_nchw44(param); \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
wbundle = wbundle, \
workspace_ptr = param.workspace<dt_byte>()]( \
size_t index, size_t thread_id) { \ size_t index, size_t thread_id) { \
MEGDNN_MARK_USED_VAR(thread_id); \ auto ws = wbundle; \
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
size_t n = index / C; \ size_t n = index / C; \
size_t c = index % C; \ size_t c = index % C; \
do_max_pooling_3x3_s2x2_##func##_nchw44_NEON( \ do_max_pooling_3x3_stride##i##_##func##_nchw44_NEON( \
static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \ static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
c * IH * IW * 4, \ c * IH * IW * 4, \
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
c * OH * OW * 4, \ c * OH * OW * 4, \
IH, IW, OH, OW, PH, PW); \ IH, IW, OH, OW, PH, PW, ws); \
}; \ }; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
...@@ -611,61 +648,23 @@ void PoolingImpl::AlgoFilter3MaxStride2NCHW44::exec( ...@@ -611,61 +648,23 @@ void PoolingImpl::AlgoFilter3MaxStride2NCHW44::exec(
} \ } \
MIDOUT_END(); MIDOUT_END();
DISPATCH_FUNC(int8_t, int8, 9); #define DISPATCH_STRIDE(type, func) \
switch (SW) { \
#undef DISPATCH_FUNC case 1: { \
} DISPATCH_FUNC(type, func, 1); \
break; \
bool PoolingImpl::AlgoFilter3MaxStride1NCHW44::usable( } \
const PoolingKernSizeParam& param) const { case 2: { \
auto SH = param.stride[0]; DISPATCH_FUNC(type, func, 2); \
auto SW = param.stride[1]; break; \
auto FH = param.filter[0]; } \
auto FW = param.filter[1]; default: \
auto PH = param.padding[0]; megdnn_assert(0, "unsupport stride size"); \
auto PW = param.padding[1]; }
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.format == Param::Format::NCHW44 &&
param.mode == Mode::MAX && FH == 3 && FW == 3 && SH == 1 &&
SW == 1 && PH == 0 && PW == 0;
return avaible;
}
void PoolingImpl::AlgoFilter3MaxStride1NCHW44::exec(
const PoolingKernParam& param) const {
auto IH = param.isz[0], IW = param.isz[1];
auto OH = param.osz[0], OW = param.osz[1];
auto N = param.n, C = param.ic;
auto PH = param.padding[0];
auto PW = param.padding[1];
void* src_ptr = param.src_ptr;
void* dst_ptr = param.dst_ptr;
#define DISPATCH_FUNC(type, func, midout_type_id) \
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \
midout_iv(midout_type_id)) { \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \
size_t index, size_t thread_id) { \
MEGDNN_MARK_USED_VAR(thread_id); \
size_t n = index / C; \
size_t c = index % C; \
do_max_pooling_3x3_s1x1_##func##_nchw44_NEON( \
static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
c * IH * IW * 4, \
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
c * OH * OW * 4, \
IH, IW, OH, OW, PH, PW); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
run); \
} \
MIDOUT_END();
DISPATCH_FUNC(int8_t, int8, 10); DISPATCH_STRIDE(int8_t, int8);
#undef DISPATCH_STRIDE
#undef DISPATCH_FUNC #undef DISPATCH_FUNC
} }
...@@ -675,13 +674,11 @@ bool PoolingImpl::AlgoFilter2MaxStridexNCHW44::usable( ...@@ -675,13 +674,11 @@ bool PoolingImpl::AlgoFilter2MaxStridexNCHW44::usable(
auto SW = param.stride[1]; auto SW = param.stride[1];
auto FH = param.filter[0]; auto FH = param.filter[0];
auto FW = param.filter[1]; auto FW = param.filter[1];
auto PH = param.padding[0];
auto PW = param.padding[1];
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.format == Param::Format::NCHW44 && param.format == Param::Format::NCHW44 &&
param.mode == Mode::MAX && FH == 2 && FW == 2 && SH == SW && param.mode == Mode::MAX && FH == 2 && FW == 2 && SH == SW &&
(SW == 1 || SW == 2) && PH == 0 && PW == 0; (SW == 1 || SW == 2);
return avaible; return avaible;
} }
...@@ -697,12 +694,16 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec( ...@@ -697,12 +694,16 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec(
void* src_ptr = param.src_ptr; void* src_ptr = param.src_ptr;
void* dst_ptr = param.dst_ptr; void* dst_ptr = param.dst_ptr;
#define DISPATCH_FUNC(type, func, midout_type_id, i) \ #define DISPATCH_FUNC(type, func, i) \
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \
midout_iv(midout_type_id)) { \ midout_iv(#func #i##_hash)) { \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \ WorkspaceBundle wbundle = get_bundle_nchw44(param); \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
wbundle = wbundle, \
workspace_ptr = param.workspace<dt_byte>()]( \
size_t index, size_t thread_id) { \ size_t index, size_t thread_id) { \
MEGDNN_MARK_USED_VAR(thread_id); \ auto ws = wbundle; \
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
size_t n = index / C; \ size_t n = index / C; \
size_t c = index % C; \ size_t c = index % C; \
do_max_pooling_2x2_stride##i##_##func##_nchw44_NEON( \ do_max_pooling_2x2_stride##i##_##func##_nchw44_NEON( \
...@@ -710,7 +711,7 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec( ...@@ -710,7 +711,7 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec(
c * IH * IW * 4, \ c * IH * IW * 4, \
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
c * OH * OW * 4, \ c * OH * OW * 4, \
IH, IW, OH, OW, PH, PW); \ IH, IW, OH, OW, PH, PW, ws); \
}; \ }; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
...@@ -718,21 +719,21 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec( ...@@ -718,21 +719,21 @@ void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec(
} \ } \
MIDOUT_END(); MIDOUT_END();
#define DISPATCH_STRIDE(type, func, midout_type_id) \ #define DISPATCH_STRIDE(type, func) \
switch (SW) { \ switch (SW) { \
case 1: { \ case 1: { \
DISPATCH_FUNC(type, func, midout_type_id, 1); \ DISPATCH_FUNC(type, func, 1); \
break; \ break; \
} \ } \
case 2: { \ case 2: { \
DISPATCH_FUNC(type, func, midout_type_id, 2); \ DISPATCH_FUNC(type, func, 2); \
break; \ break; \
} \ } \
default: \ default: \
megdnn_assert(0, "unsupport stride size"); \ megdnn_assert(0, "unsupport stride size"); \
} }
DISPATCH_STRIDE(int8_t, int8, 10); DISPATCH_STRIDE(int8_t, int8);
#undef DISPATCH_STRIDE #undef DISPATCH_STRIDE
#undef DISPATCH_FUNC #undef DISPATCH_FUNC
...@@ -744,13 +745,11 @@ bool PoolingImpl::AlgoFilter4MaxStridexNCHW44::usable( ...@@ -744,13 +745,11 @@ bool PoolingImpl::AlgoFilter4MaxStridexNCHW44::usable(
auto SW = param.stride[1]; auto SW = param.stride[1];
auto FH = param.filter[0]; auto FH = param.filter[0];
auto FW = param.filter[1]; auto FW = param.filter[1];
auto PH = param.padding[0];
auto PW = param.padding[1];
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.format == Param::Format::NCHW44 && param.format == Param::Format::NCHW44 &&
param.mode == Mode::MAX && FH == 4 && FW == 4 && SH == SW && param.mode == Mode::MAX && FH == 4 && FW == 4 && SH == SW &&
(SW == 1 || SW == 2) && PH == 0 && PW == 0; (SW == 1 || SW == 2);
return avaible; return avaible;
} }
...@@ -766,12 +765,16 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec( ...@@ -766,12 +765,16 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec(
void* src_ptr = param.src_ptr; void* src_ptr = param.src_ptr;
void* dst_ptr = param.dst_ptr; void* dst_ptr = param.dst_ptr;
#define DISPATCH_FUNC(type, func, midout_type_id, i) \ #define DISPATCH_FUNC(type, func, i) \
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \
midout_iv(midout_type_id)) { \ midout_iv(#func #i##_hash)) { \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \ WorkspaceBundle wbundle = get_bundle_nchw44(param); \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
wbundle = wbundle, \
workspace_ptr = param.workspace<dt_byte>()]( \
size_t index, size_t thread_id) { \ size_t index, size_t thread_id) { \
MEGDNN_MARK_USED_VAR(thread_id); \ auto ws = wbundle; \
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
size_t n = index / C; \ size_t n = index / C; \
size_t c = index % C; \ size_t c = index % C; \
do_max_pooling_4x4_stride##i##_##func##_nchw44_NEON( \ do_max_pooling_4x4_stride##i##_##func##_nchw44_NEON( \
...@@ -779,7 +782,7 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec( ...@@ -779,7 +782,7 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec(
c * IH * IW * 4, \ c * IH * IW * 4, \
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
c * OH * OW * 4, \ c * OH * OW * 4, \
IH, IW, OH, OW, PH, PW); \ IH, IW, OH, OW, PH, PW, ws); \
}; \ }; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
...@@ -787,21 +790,21 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec( ...@@ -787,21 +790,21 @@ void PoolingImpl::AlgoFilter4MaxStridexNCHW44::exec(
} \ } \
MIDOUT_END(); MIDOUT_END();
#define DISPATCH_STRIDE(type, func, midout_type_id) \ #define DISPATCH_STRIDE(type, func) \
switch (SW) { \ switch (SW) { \
case 1: { \ case 1: { \
DISPATCH_FUNC(type, func, midout_type_id, 1); \ DISPATCH_FUNC(type, func, 1); \
break; \ break; \
} \ } \
case 2: { \ case 2: { \
DISPATCH_FUNC(type, func, midout_type_id, 2); \ DISPATCH_FUNC(type, func, 2); \
break; \ break; \
} \ } \
default: \ default: \
megdnn_assert(0, "unsupport stride size"); \ megdnn_assert(0, "unsupport stride size"); \
} }
DISPATCH_STRIDE(int8_t, int8, 11); DISPATCH_STRIDE(int8_t, int8);
#undef DISPATCH_STRIDE #undef DISPATCH_STRIDE
#undef DISPATCH_FUNC #undef DISPATCH_FUNC
...@@ -813,13 +816,11 @@ bool PoolingImpl::AlgoFilter5MaxStridexNCHW44::usable( ...@@ -813,13 +816,11 @@ bool PoolingImpl::AlgoFilter5MaxStridexNCHW44::usable(
auto SW = param.stride[1]; auto SW = param.stride[1];
auto FH = param.filter[0]; auto FH = param.filter[0];
auto FW = param.filter[1]; auto FW = param.filter[1];
auto PH = param.padding[0];
auto PW = param.padding[1];
bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 && bool avaible = param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.format == Param::Format::NCHW44 && param.format == Param::Format::NCHW44 &&
param.mode == Mode::MAX && FH == 5 && FW == 5 && SH == SW && param.mode == Mode::MAX && FH == 5 && FW == 5 && SH == SW &&
(SW == 1 || SW == 2) && PH == 0 && PW == 0; (SW == 1 || SW == 2);
return avaible; return avaible;
} }
...@@ -835,12 +836,16 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec( ...@@ -835,12 +836,16 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec(
void* src_ptr = param.src_ptr; void* src_ptr = param.src_ptr;
void* dst_ptr = param.dst_ptr; void* dst_ptr = param.dst_ptr;
#define DISPATCH_FUNC(type, func, midout_type_id, i) \ #define DISPATCH_FUNC(type, func, i) \
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \ MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \
midout_iv(midout_type_id)) { \ midout_iv(#func #i##_hash)) { \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \ WorkspaceBundle wbundle = get_bundle_nchw44(param); \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
wbundle = wbundle, \
workspace_ptr = param.workspace<dt_byte>()]( \
size_t index, size_t thread_id) { \ size_t index, size_t thread_id) { \
MEGDNN_MARK_USED_VAR(thread_id); \ auto ws = wbundle; \
ws.set(workspace_ptr + ws.total_size_in_bytes() * thread_id); \
size_t n = index / C; \ size_t n = index / C; \
size_t c = index % C; \ size_t c = index % C; \
do_max_pooling_5x5_stride##i##_##func##_nchw44_NEON( \ do_max_pooling_5x5_stride##i##_##func##_nchw44_NEON( \
...@@ -848,7 +853,7 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec( ...@@ -848,7 +853,7 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec(
c * IH * IW * 4, \ c * IH * IW * 4, \
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \ static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
c * OH * OW * 4, \ c * OH * OW * 4, \
IH, IW, OH, OW, PH, PW); \ IH, IW, OH, OW, PH, PW, ws); \
}; \ }; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \ static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
...@@ -856,21 +861,21 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec( ...@@ -856,21 +861,21 @@ void PoolingImpl::AlgoFilter5MaxStridexNCHW44::exec(
} \ } \
MIDOUT_END(); MIDOUT_END();
#define DISPATCH_STRIDE(type, func, midout_type_id) \ #define DISPATCH_STRIDE(type, func) \
switch (SW) { \ switch (SW) { \
case 1: { \ case 1: { \
DISPATCH_FUNC(type, func, midout_type_id, 1); \ DISPATCH_FUNC(type, func, 1); \
break; \ break; \
} \ } \
case 2: { \ case 2: { \
DISPATCH_FUNC(type, func, midout_type_id, 2); \ DISPATCH_FUNC(type, func, 2); \
break; \ break; \
} \ } \
default: \ default: \
megdnn_assert(0, "unsupport stride size"); \ megdnn_assert(0, "unsupport stride size"); \
} }
DISPATCH_STRIDE(int8_t, int8, 12); DISPATCH_STRIDE(int8_t, int8);
#undef DISPATCH_STRIDE #undef DISPATCH_STRIDE
#undef DISPATCH_FUNC #undef DISPATCH_FUNC
......
...@@ -83,18 +83,10 @@ public: ...@@ -83,18 +83,10 @@ public:
void exec(const PoolingKernParam& param) const override; void exec(const PoolingKernParam& param) const override;
}; };
class PoolingImpl::AlgoFilter3MaxStride2NCHW44 final : public AlgoBase { class PoolingImpl::AlgoFilter3MaxStridexNCHW44 final : public AlgoBase {
public: public:
bool is_reproducible() const override { return true; } bool is_reproducible() const override { return true; }
const char* name() const override { return "ARM_POOLING_FILTER3_MAX_STRIDE2_NCHW44"; } const char* name() const override { return "ARM_POOLING_FILTER3_MAX_STRIDEX_NCHW44"; }
bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override;
};
class PoolingImpl::AlgoFilter3MaxStride1NCHW44 final : public AlgoBase {
public:
bool is_reproducible() const override { return true; }
const char* name() const override { return "ARM_POOLING_FILTER3_MAX_STRIDE1_NCHW44"; }
bool usable(const PoolingKernSizeParam& param) const override; bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override; void exec(const PoolingKernParam& param) const override;
}; };
...@@ -125,6 +117,12 @@ public: ...@@ -125,6 +117,12 @@ public:
WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param); WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param);
/**
 * \brief Workspace layout for NCHW44 QuantizedS8 pooling: one chunk sized
 *        for a padded input plane, or an empty chunk when both padding
 *        values are zero.
 */
WorkspaceBundle get_bundle_nchw44(
const PoolingImpl::PoolingKernSizeParam& param);
/**
 * \brief Return the plane a kernel should read: \p src itself when PH and PW
 *        are both zero, otherwise a copy in chunk 0 of \p ws whose border is
 *        filled with -128. IH2 / IW2 receive the extents of the returned
 *        plane.
 */
const int8_t* handle_padding(const int8_t* src, size_t IH, size_t IW,
size_t& IH2, size_t& IW2, size_t PH, size_t PW,
const WorkspaceBundle& ws);
} // namespace arm_common } // namespace arm_common
} // namespace megdnn } // namespace megdnn
......
/**
* \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
namespace megdnn {
namespace arm_common {
// 3x3 window, stride 1 max pooling over int8 data in which every pixel is
// stored as 4 consecutive int8 values.
//
// src: input plane; rows are IW * 4 bytes apart.
// dst: output plane; rows are OW * 4 bytes apart.
// OH, OW: number of output rows / output pixels per row to produce.
// IH, PH, PW: accepted for signature uniformity but not read by this body;
//             the window always starts at input row `oh` and no padding is
//             applied here.
void do_max_pooling_3x3_s1x1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW, size_t OH,
size_t OW, size_t PH, size_t PW) {
size_t oh = 0;
for (; oh < OH; ++oh) {
// Stride 1: output row oh reads input rows oh, oh + 1 and oh + 2.
size_t ih = oh;
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4;
int8_t* __restrict dptr = dst + oh * OW * 4;
size_t ow = 0;
// Main loop: 4 output pixels (16 bytes) per iteration.
for (; ow + 3 < OW; ow += 4) {
// Three 16-byte loads offset by one pixel (4 bytes) each; their lane-wise
// max is the horizontal 3-pixel max for 4 adjacent outputs of this row.
int8x16_t src0123 = vld1q_s8(sptr0);
int8x16_t src1234 = vld1q_s8(sptr0 + 4);
int8x16_t src2345 = vld1q_s8(sptr0 + 8);
int8x16_t max0 = vmaxq_s8(src0123, src1234);
max0 = vmaxq_s8(max0, src2345);
// Same horizontal reduction for the second input row.
src0123 = vld1q_s8(sptr1);
src1234 = vld1q_s8(sptr1 + 4);
src2345 = vld1q_s8(sptr1 + 8);
int8x16_t max1 = vmaxq_s8(src0123, src1234);
max1 = vmaxq_s8(max1, src2345);
// ... and for the third input row.
src0123 = vld1q_s8(sptr2);
src1234 = vld1q_s8(sptr2 + 4);
src2345 = vld1q_s8(sptr2 + 8);
int8x16_t max2 = vmaxq_s8(src0123, src1234);
max2 = vmaxq_s8(max2, src2345);
// Vertical reduction across the three row maxima.
int8x16_t max_out = vmaxq_s8(max0, max1);
max_out = vmaxq_s8(max_out, max2);
vst1q_s8(dptr, max_out);
// Advance by 4 pixels on every row and on the output.
sptr0 += 16;
sptr1 += 16;
sptr2 += 16;
dptr += 16;
}
// Tail loop: one output pixel per iteration.
for (; ow < OW; ++ow) {
// Per row, two 8-byte loads: pixels {0, 1} and pixels {1, 2} of the window.
int8x8_t src001 = vld1_s8(sptr0);
int8x8_t src012 = vld1_s8(sptr0 + 4);
int8x8_t src101 = vld1_s8(sptr1);
int8x8_t src112 = vld1_s8(sptr1 + 4);
int8x8_t src201 = vld1_s8(sptr2);
int8x8_t src212 = vld1_s8(sptr2 + 4);
// Reduce vertically first: lanes 0-3 / 4-7 of max01_tmp hold the column
// maxima of window pixels 0 / 1, lanes 4-7 of max12_tmp those of pixel 2.
int8x8_t max01_tmp = vmax_s8(src001, src101);
max01_tmp = vmax_s8(max01_tmp, src201);
int8x8_t max12_tmp = vmax_s8(src012, src112);
max12_tmp = vmax_s8(max12_tmp, src212);
// cb(i): horizontal max of the three column maxima for value i of the pixel.
// store(i): write that result to output byte i.
#define cb(i) \
int8_t dst##i = std::max(std::max(max01_tmp[i], max01_tmp[i + 4]), \
max12_tmp[i + 4]);
#define store(i) *(dptr + i) = dst##i;
UNROLL_CALL_NOWRAPPER(4, cb)
UNROLL_CALL_NOWRAPPER(4, store)
#undef store
#undef cb
// Advance by one pixel on every row and on the output.
sptr0 += 4;
sptr1 += 4;
sptr2 += 4;
dptr += 4;
}
}
}
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/common/utils.h"
namespace megdnn {
namespace arm_common {
/**
 * \brief NEON max pooling with a 3x3 window and stride 1 on int8 data whose
 *        pixels are stored as 4 consecutive int8 values.
 *
 * \param src input plane, rows IW * 4 bytes apart
 * \param dst output plane, rows OW * 4 bytes apart
 * \param IH,IW input height and width in pixels
 * \param OH,OW output height and width in pixels
 * \param PH,PW padding; NOTE(review): the definition seen alongside this
 *        header does not read PH/PW (or IH) — confirm callers only pass 0
 */
void do_max_pooling_3x3_s1x1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW, size_t OH,
size_t OW, size_t PH, size_t PW);
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
...@@ -9,7 +9,8 @@ ...@@ -9,7 +9,8 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. * implied.
*/ */
#include "src/arm_common/pooling/do_max_pooling_2x2_nchw44.h" #include "src/arm_common/pooling/do_pooling_2x2_nchw44.h"
#include "src/arm_common/pooling/algo.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
...@@ -19,12 +20,16 @@ namespace arm_common { ...@@ -19,12 +20,16 @@ namespace arm_common {
void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW, size_t IH, size_t IW,
size_t OH, size_t OW, size_t OH, size_t OW,
size_t PH, size_t PW) { size_t PH, size_t PW,
const WorkspaceBundle& ws) {
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws);
size_t oh = 0; size_t oh = 0;
for (; oh < OH; ++oh) { for (; oh < OH; ++oh) {
size_t ih = oh; size_t ih = oh;
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4;
int8_t* __restrict dptr = dst + oh * OW * 4; int8_t* __restrict dptr = dst + oh * OW * 4;
size_t ow = 0; size_t ow = 0;
for (; ow + 3 < OW; ow += 4) { for (; ow + 3 < OW; ow += 4) {
...@@ -46,15 +51,10 @@ void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, ...@@ -46,15 +51,10 @@ void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
} }
for (; ow < OW; ++ow) { for (; ow < OW; ++ow) {
int8x8_t src001 = vld1_s8(sptr0); int8x8_t src001 = vld1_s8(sptr0);
int8x8_t src012 = vld1_s8(sptr0 + 4);
int8x8_t src101 = vld1_s8(sptr1); int8x8_t src101 = vld1_s8(sptr1);
int8x8_t src112 = vld1_s8(sptr1 + 4);
int8x8_t max01_tmp = vmax_s8(src001, src101); int8x8_t max_out = vmax_s8(src001, src101);
int8x8_t max12_tmp = vmax_s8(src012, src112); #define store(i) *(dptr + i) = std::max(max_out[i], max_out[i + 4]);
int8x8_t mat_out = vmax_s8(max01_tmp, max12_tmp);
#define store(i) *(dptr + i) = mat_out[i];
UNROLL_CALL_NOWRAPPER(4, store) UNROLL_CALL_NOWRAPPER(4, store)
#undef store #undef store
sptr0 += 4; sptr0 += 4;
...@@ -66,12 +66,16 @@ void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, ...@@ -66,12 +66,16 @@ void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW, size_t IH, size_t IW,
size_t OH, size_t OW, size_t OH, size_t OW,
size_t PH, size_t PW) { size_t PH, size_t PW,
const WorkspaceBundle& ws) {
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws);
size_t oh = 0; size_t oh = 0;
for (; oh < OH; ++oh) { for (; oh < OH; ++oh) {
size_t ih = oh << 1; size_t ih = oh << 1;
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4;
int8_t* __restrict dptr = dst + oh * OW * 4; int8_t* __restrict dptr = dst + oh * OW * 4;
size_t ow = 0; size_t ow = 0;
for (; ow + 3 < OW; ow += 4) { for (; ow + 3 < OW; ow += 4) {
...@@ -103,15 +107,10 @@ void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, ...@@ -103,15 +107,10 @@ void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
} }
for (; ow < OW; ++ow) { for (; ow < OW; ++ow) {
int8x8_t src001 = vld1_s8(sptr0); int8x8_t src001 = vld1_s8(sptr0);
int8x8_t src012 = vld1_s8(sptr0 + 4);
int8x8_t src101 = vld1_s8(sptr1); int8x8_t src101 = vld1_s8(sptr1);
int8x8_t src112 = vld1_s8(sptr1 + 4);
int8x8_t max01_tmp = vmax_s8(src001, src101); int8x8_t max_out = vmax_s8(src001, src101);
int8x8_t max12_tmp = vmax_s8(src012, src112); #define store(i) *(dptr + i) = std::max(max_out[i], max_out[i + 4]);
int8x8_t mat_out = vmax_s8(max01_tmp, max12_tmp);
#define store(i) *(dptr + i) = mat_out[i];
UNROLL_CALL_NOWRAPPER(4, store) UNROLL_CALL_NOWRAPPER(4, store)
#undef store #undef store
sptr0 += 8; sptr0 += 8;
......
...@@ -18,11 +18,13 @@ namespace arm_common { ...@@ -18,11 +18,13 @@ namespace arm_common {
void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW, size_t IH, size_t IW,
size_t OH, size_t OW, size_t OH, size_t OW,
size_t PH, size_t PW); size_t PH, size_t PW,
const WorkspaceBundle& ws);
void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW, size_t IH, size_t IW,
size_t OH, size_t OW, size_t OH, size_t OW,
size_t PH, size_t PW); size_t PH, size_t PW,
const WorkspaceBundle& ws);
} // namespace arm_common } // namespace arm_common
} // namespace megdnn } // namespace megdnn
......
...@@ -9,60 +9,143 @@ ...@@ -9,60 +9,143 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. * implied.
*/ */
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h" #include "src/arm_common/pooling/do_pooling_3x3_nchw44.h"
#include "src/arm_common/pooling/algo.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
namespace megdnn { namespace megdnn {
namespace arm_common { namespace arm_common {
void do_max_pooling_3x3_s2x2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, void do_max_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW, size_t OH, size_t IH, size_t IW,
size_t OW, size_t PH, size_t PW) { size_t OH, size_t OW,
size_t PH, size_t PW,
const WorkspaceBundle& ws) {
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws);
size_t oh = 0;
for (; oh < OH; ++oh) {
size_t ih = oh;
const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4;
const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4;
const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4;
int8_t* __restrict dptr = dst + oh * OW * 4;
size_t ow = 0;
for (; ow + 3 < OW; ow += 4) {
int8x16_t src0123 = vld1q_s8(sptr0);
int8x16_t src1234 = vld1q_s8(sptr0 + 4);
int8x16_t src2345 = vld1q_s8(sptr0 + 8);
int8x16_t max0 = vmaxq_s8(src0123, src1234);
max0 = vmaxq_s8(max0, src2345);
src0123 = vld1q_s8(sptr1);
src1234 = vld1q_s8(sptr1 + 4);
src2345 = vld1q_s8(sptr1 + 8);
int8x16_t max1 = vmaxq_s8(src0123, src1234);
max1 = vmaxq_s8(max1, src2345);
src0123 = vld1q_s8(sptr2);
src1234 = vld1q_s8(sptr2 + 4);
src2345 = vld1q_s8(sptr2 + 8);
int8x16_t max2 = vmaxq_s8(src0123, src1234);
max2 = vmaxq_s8(max2, src2345);
int8x16_t max_out = vmaxq_s8(max0, max1);
max_out = vmaxq_s8(max_out, max2);
vst1q_s8(dptr, max_out);
sptr0 += 16;
sptr1 += 16;
sptr2 += 16;
dptr += 16;
}
for (; ow < OW; ++ow) {
int8x8_t src001 = vld1_s8(sptr0);
int8x8_t src012 = vld1_s8(sptr0 + 4);
int8x8_t src101 = vld1_s8(sptr1);
int8x8_t src112 = vld1_s8(sptr1 + 4);
int8x8_t src201 = vld1_s8(sptr2);
int8x8_t src212 = vld1_s8(sptr2 + 4);
int8x8_t max01_tmp = vmax_s8(src001, src101);
max01_tmp = vmax_s8(max01_tmp, src201);
int8x8_t max12_tmp = vmax_s8(src012, src112);
max12_tmp = vmax_s8(max12_tmp, src212);
#define cb(i) \
int8_t dst##i = std::max(std::max(max01_tmp[i], max01_tmp[i + 4]), \
max12_tmp[i + 4]);
#define store(i) *(dptr + i) = dst##i;
UNROLL_CALL_NOWRAPPER(4, cb)
UNROLL_CALL_NOWRAPPER(4, store)
#undef store
#undef cb
sptr0 += 4;
sptr1 += 4;
sptr2 += 4;
dptr += 4;
}
}
}
void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW,
size_t OH, size_t OW,
size_t PH, size_t PW,
const WorkspaceBundle& ws) {
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws);
size_t oh = 0; size_t oh = 0;
for (; oh < OH; ++oh) { for (; oh < OH; ++oh) {
size_t ih = oh << 1; size_t ih = oh << 1;
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; const int8_t* sptr0 = sptr + (ih + 0) * IW2 * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; const int8_t* sptr1 = sptr + (ih + 1) * IW2 * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; const int8_t* sptr2 = sptr + (ih + 2) * IW2 * 4;
int8_t* __restrict dptr = dst + oh * OW * 4; int8_t* __restrict dptr = dst + oh * OW * 4;
size_t ow = 0; size_t ow = 0;
for (; ow + 3 < OW; ow += 4) { for (; ow + 3 < OW; ow += 4) {
int8x16_t src00 = vld1q_s8(sptr0); int8x16_t src00 = vld1q_s8(sptr0);
int8x16_t src04 = vld1q_s8(sptr0 + 4 * 4); int8x16_t src04 = vld1q_s8(sptr0 + 4 * 4);
int8x16_t src08 = vld1q_s8(sptr0 + 4 * 8); int32x4_t src08 = vld1q_dup_s32(
reinterpret_cast<const int32_t*>(sptr0 + 4 * 8));
int32x4x2_t src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), int32x4x2_t src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00),
vreinterpretq_s32_s8(src04)); vreinterpretq_s32_s8(src04));
int32x4_t src0246 = src_tmp.val[0]; int32x4_t src0246 = src_tmp.val[0];
int32x4_t src1357 = src_tmp.val[1]; int32x4_t src1357 = src_tmp.val[1];
int32x4_t src2468 = int32x4_t src2468 = vextq_s32(src0246, src08, 1);
vextq_s32(src0246, vreinterpretq_s32_s8(src08), 1);
int8x16_t max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246), int8x16_t max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246),
vreinterpretq_s8_s32(src1357)); vreinterpretq_s8_s32(src1357));
int8x16_t max0 = vmaxq_s8(max_tmp, vreinterpretq_s8_s32(src2468)); int8x16_t max0 = vmaxq_s8(max_tmp, vreinterpretq_s8_s32(src2468));
int8x16_t src10 = vld1q_s8(sptr1); int8x16_t src10 = vld1q_s8(sptr1);
int8x16_t src14 = vld1q_s8(sptr1 + 4 * 4); int8x16_t src14 = vld1q_s8(sptr1 + 4 * 4);
int8x16_t src18 = vld1q_s8(sptr1 + 4 * 8); int32x4_t src18 = vld1q_dup_s32(
reinterpret_cast<const int32_t*>(sptr1 + 4 * 8));
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src10), src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src10),
vreinterpretq_s32_s8(src14)); vreinterpretq_s32_s8(src14));
src0246 = src_tmp.val[0]; src0246 = src_tmp.val[0];
src1357 = src_tmp.val[1]; src1357 = src_tmp.val[1];
src2468 = vextq_s32(src0246, vreinterpretq_s32_s8(src18), 1); src2468 = vextq_s32(src0246, src18, 1);
max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246), max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246),
vreinterpretq_s8_s32(src1357)); vreinterpretq_s8_s32(src1357));
int8x16_t max1 = vmaxq_s8(max_tmp, vreinterpretq_s8_s32(src2468)); int8x16_t max1 = vmaxq_s8(max_tmp, vreinterpretq_s8_s32(src2468));
int8x16_t src20 = vld1q_s8(sptr2); int8x16_t src20 = vld1q_s8(sptr2);
int8x16_t src24 = vld1q_s8(sptr2 + 4 * 4); int8x16_t src24 = vld1q_s8(sptr2 + 4 * 4);
int8x16_t src28 = vld1q_s8(sptr2 + 4 * 8); int32x4_t src28 = vld1q_dup_s32(
reinterpret_cast<const int32_t*>(sptr2 + 4 * 8));
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src20), src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src20),
vreinterpretq_s32_s8(src24)); vreinterpretq_s32_s8(src24));
src0246 = src_tmp.val[0]; src0246 = src_tmp.val[0];
src1357 = src_tmp.val[1]; src1357 = src_tmp.val[1];
src2468 = vextq_s32(src0246, vreinterpretq_s32_s8(src28), 1); src2468 = vextq_s32(src0246, src28, 1);
max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246), max_tmp = vmaxq_s8(vreinterpretq_s8_s32(src0246),
vreinterpretq_s8_s32(src1357)); vreinterpretq_s8_s32(src1357));
......
...@@ -15,9 +15,16 @@ ...@@ -15,9 +15,16 @@
namespace megdnn { namespace megdnn {
namespace arm_common { namespace arm_common {
void do_max_pooling_3x3_s2x2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, void do_max_pooling_3x3_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW,
size_t OH, size_t OW,
size_t PH, size_t PW,
const WorkspaceBundle& ws);
void do_max_pooling_3x3_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW, size_t OH, size_t IH, size_t IW, size_t OH,
size_t OW, size_t PH, size_t PW); size_t OW, size_t PH, size_t PW,
const WorkspaceBundle& ws);
} // namespace arm_common } // namespace arm_common
} // namespace megdnn } // namespace megdnn
......
...@@ -9,7 +9,8 @@ ...@@ -9,7 +9,8 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. * implied.
*/ */
#include "src/arm_common/pooling/do_max_pooling_4x4_nchw44.h" #include "src/arm_common/pooling/do_pooling_4x4_nchw44.h"
#include "src/arm_common/pooling/algo.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
...@@ -19,14 +20,18 @@ namespace arm_common { ...@@ -19,14 +20,18 @@ namespace arm_common {
void do_max_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, void do_max_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW, size_t IH, size_t IW,
size_t OH, size_t OW, size_t OH, size_t OW,
size_t PH, size_t PW) { size_t PH, size_t PW,
const WorkspaceBundle& ws) {
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws);
size_t oh = 0; size_t oh = 0;
for (; oh < OH; ++oh) { for (; oh < OH; ++oh) {
size_t ih = oh; size_t ih = oh;
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4;
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4; const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW2 * 4;
int8_t* __restrict dptr = dst + oh * OW * 4; int8_t* __restrict dptr = dst + oh * OW * 4;
size_t ow = 0; size_t ow = 0;
for (; ow + 3 < OW; ow += 4) { for (; ow + 3 < OW; ow += 4) {
...@@ -90,35 +95,38 @@ void do_max_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, ...@@ -90,35 +95,38 @@ void do_max_pooling_4x4_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
void do_max_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, void do_max_pooling_4x4_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW, size_t IH, size_t IW,
size_t OH, size_t OW, size_t OH, size_t OW,
size_t PH, size_t PW) { size_t PH, size_t PW,
const WorkspaceBundle& ws) {
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws);
size_t oh = 0; size_t oh = 0;
for (; oh < OH; ++oh) { for (; oh < OH; ++oh) {
size_t ih = oh << 1; size_t ih = oh << 1;
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; const int8_t* __restrict sptr0 = sptr + (ih + 0) * IW2 * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; const int8_t* __restrict sptr1 = sptr + (ih + 1) * IW2 * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; const int8_t* __restrict sptr2 = sptr + (ih + 2) * IW2 * 4;
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4; const int8_t* __restrict sptr3 = sptr + (ih + 3) * IW2 * 4;
int8_t* __restrict dptr = dst + oh * OW * 4; int8_t* __restrict dptr = dst + oh * OW * 4;
size_t ow = 0; size_t ow = 0;
for (; ow + 3 < OW; ow += 4) { for (; ow + 3 < OW; ow += 4) {
int8x16_t src00, src04, src08, src09, max_tmp0, max_tmp1, max_tmp2, int8x16_t src00, src04, max_tmp0, max_tmp1, max_tmp2, max_tmp3;
max_tmp3; int32x4_t src0246, src1357, src2468, src3579, src08, src09;
int32x4_t src0246, src1357, src2468, src3579;
int32x4x2_t src_tmp; int32x4x2_t src_tmp;
#define CACULATE_ROW(i) \ #define CACULATE_ROW(i) \
src00 = vld1q_s8(sptr##i); \ src00 = vld1q_s8(sptr##i); \
src04 = vld1q_s8(sptr##i + 4 * 4); \ src04 = vld1q_s8(sptr##i + 4 * 4); \
src08 = vld1q_s8(sptr##i + 4 * 8); \ src08 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 8)); \
src09 = vld1q_s8(sptr##i + 4 * 9); \ src09 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 9)); \
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \
vreinterpretq_s32_s8(src04)); \ vreinterpretq_s32_s8(src04)); \
src0246 = src_tmp.val[0]; \ src0246 = src_tmp.val[0]; \
src1357 = src_tmp.val[1]; \ src1357 = src_tmp.val[1]; \
src2468 = vextq_s32(src0246, vreinterpretq_s32_s8(src08), 1); \ src2468 = vextq_s32(src0246, src08, 1); \
src3579 = vextq_s32(src1357, vreinterpretq_s32_s8(src09), 1); \ src3579 = vextq_s32(src1357, src09, 1); \
max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246), \ max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246), \
vreinterpretq_s8_s32(src1357)); \ vreinterpretq_s8_s32(src1357)); \
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468)); \ max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468)); \
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579)); max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579));
UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW) UNROLL_CALL_NOWRAPPER(4, CACULATE_ROW)
......
...@@ -18,7 +18,7 @@ namespace arm_common { ...@@ -18,7 +18,7 @@ namespace arm_common {
#define KERN(strdie) \ #define KERN(strdie) \
void do_max_pooling_4x4_##strdie##_int8_nchw44_NEON( \ void do_max_pooling_4x4_##strdie##_int8_nchw44_NEON( \
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \ const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \
size_t OW, size_t PH, size_t PW); size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws);
KERN(stride1) KERN(stride1)
KERN(stride2) KERN(stride2)
......
...@@ -9,7 +9,8 @@ ...@@ -9,7 +9,8 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. * implied.
*/ */
#include "src/arm_common/pooling/do_max_pooling_5x5_nchw44.h" #include "src/arm_common/pooling/do_pooling_5x5_nchw44.h"
#include "src/arm_common/pooling/algo.h"
#include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h" #include "src/common/unroll_macro.h"
...@@ -19,15 +20,19 @@ namespace arm_common { ...@@ -19,15 +20,19 @@ namespace arm_common {
void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW, size_t IH, size_t IW,
size_t OH, size_t OW, size_t OH, size_t OW,
size_t PH, size_t PW) { size_t PH, size_t PW,
const WorkspaceBundle& ws) {
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws);
size_t oh = 0; size_t oh = 0;
for (; oh < OH; ++oh) { for (; oh < OH; ++oh) {
size_t ih = oh; size_t ih = oh;
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; const int8_t* sptr0 = sptr + (ih + 0) * IW2 * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; const int8_t* sptr1 = sptr + (ih + 1) * IW2 * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; const int8_t* sptr2 = sptr + (ih + 2) * IW2 * 4;
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4; const int8_t* sptr3 = sptr + (ih + 3) * IW2 * 4;
const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4; const int8_t* sptr4 = sptr + (ih + 4) * IW2 * 4;
int8_t* __restrict dptr = dst + oh * OW * 4; int8_t* __restrict dptr = dst + oh * OW * 4;
size_t ow = 0; size_t ow = 0;
for (; ow + 3 < OW; ow += 4) { for (; ow + 3 < OW; ow += 4) {
...@@ -80,13 +85,16 @@ void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, ...@@ -80,13 +85,16 @@ void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
max_out = vmax_s8(max_out, max_tmp3); max_out = vmax_s8(max_out, max_tmp3);
max_out = vmax_s8(max_out, max_tmp4); max_out = vmax_s8(max_out, max_tmp4);
#define COMPARE_SRC45(i) int8x8_t src##i##_45 = vld1_s8(sptr##i + 4 * 4); #define COMPARE_SRC45(i) \
int32x2_t src##i##_45 = \
vld1_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 4));
UNROLL_CALL_NOWRAPPER(5, COMPARE_SRC45) UNROLL_CALL_NOWRAPPER(5, COMPARE_SRC45)
int8x8_t max_45 = vmax_s8(src0_45, src1_45); int8x8_t max_45 = vmax_s8(vreinterpret_s8_s32(src0_45),
max_45 = vmax_s8(max_45, src1_45); vreinterpret_s8_s32(src1_45));
max_45 = vmax_s8(max_45, src2_45); max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src1_45));
max_45 = vmax_s8(max_45, src3_45); max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src2_45));
max_45 = vmax_s8(max_45, src4_45); max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src3_45));
max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src4_45));
#define store(i) \ #define store(i) \
*(dptr + i) = std::max(std::max(max_out[i], max_out[i + 4]), max_45[i]); *(dptr + i) = std::max(std::max(max_out[i], max_out[i + 4]), max_45[i]);
...@@ -106,39 +114,44 @@ void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst, ...@@ -106,39 +114,44 @@ void do_max_pooling_5x5_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW, size_t IH, size_t IW,
size_t OH, size_t OW, size_t OH, size_t OW,
size_t PH, size_t PW) { size_t PH, size_t PW,
const WorkspaceBundle& ws) {
const int8_t* sptr = nullptr;
size_t IH2, IW2;
sptr = handle_padding(src, IH, IW, IH2, IW2, PH, PW, ws);
size_t oh = 0; size_t oh = 0;
for (; oh < OH; ++oh) { for (; oh < OH; ++oh) {
size_t ih = oh << 1; size_t ih = oh << 1;
const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4; const int8_t* sptr0 = sptr + (ih + 0) * IW2 * 4;
const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4; const int8_t* sptr1 = sptr + (ih + 1) * IW2 * 4;
const int8_t* __restrict sptr2 = src + (ih + 2) * IW * 4; const int8_t* sptr2 = sptr + (ih + 2) * IW2 * 4;
const int8_t* __restrict sptr3 = src + (ih + 3) * IW * 4; const int8_t* sptr3 = sptr + (ih + 3) * IW2 * 4;
const int8_t* __restrict sptr4 = src + (ih + 4) * IW * 4; const int8_t* sptr4 = sptr + (ih + 4) * IW2 * 4;
int8_t* __restrict dptr = dst + oh * OW * 4; int8_t* __restrict dptr = dst + oh * OW * 4;
size_t ow = 0; size_t ow = 0;
for (; ow + 3 < OW; ow += 4) { for (; ow + 3 < OW; ow += 4) {
int8x16_t src00, src04, src08, src09, src10, max_tmp0, max_tmp1, int8x16_t src00, src04, max_tmp0, max_tmp1, max_tmp2, max_tmp3,
max_tmp2, max_tmp3, max_tmp4; max_tmp4;
int32x4_t src0246, src1357, src2468, src3579, src46810; int32x4_t src0246, src1357, src2468, src3579, src46810, src10,
src09, src08;
int32x4x2_t src_tmp; int32x4x2_t src_tmp;
#define CACULATE_ROW(i) \ #define CACULATE_ROW(i) \
src00 = vld1q_s8(sptr##i); \ src00 = vld1q_s8(sptr##i); \
src04 = vld1q_s8(sptr##i + 4 * 4); \ src04 = vld1q_s8(sptr##i + 4 * 4); \
src08 = vld1q_s8(sptr##i + 4 * 8); \ src08 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 8)); \
src09 = vld1q_s8(sptr##i + 4 * 9); \ src09 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 9)); \
src10 = vld1q_s8(sptr##i + 4 * 10); \ src10 = vld1q_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 10)); \
src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \ src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00), \
vreinterpretq_s32_s8(src04)); \ vreinterpretq_s32_s8(src04)); \
src0246 = src_tmp.val[0]; \ src0246 = src_tmp.val[0]; \
src1357 = src_tmp.val[1]; \ src1357 = src_tmp.val[1]; \
src2468 = vextq_s32(src0246, vreinterpretq_s32_s8(src08), 1); \ src2468 = vextq_s32(src0246, src08, 1); \
src3579 = vextq_s32(src1357, vreinterpretq_s32_s8(src09), 1); \ src3579 = vextq_s32(src1357, src09, 1); \
src46810 = vextq_s32(src2468, vreinterpretq_s32_s8(src10), 1); \ src46810 = vextq_s32(src2468, src10, 1); \
max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246), \ max_tmp##i = vmaxq_s8(vreinterpretq_s8_s32(src0246), \
vreinterpretq_s8_s32(src1357)); \ vreinterpretq_s8_s32(src1357)); \
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468)); \ max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src2468)); \
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579)); \ max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src3579)); \
max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src46810)); max_tmp##i = vmaxq_s8(max_tmp##i, vreinterpretq_s8_s32(src46810));
UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW) UNROLL_CALL_NOWRAPPER(5, CACULATE_ROW)
...@@ -173,13 +186,16 @@ void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst, ...@@ -173,13 +186,16 @@ void do_max_pooling_5x5_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
max_out = vmax_s8(max_out, max_tmp3); max_out = vmax_s8(max_out, max_tmp3);
max_out = vmax_s8(max_out, max_tmp4); max_out = vmax_s8(max_out, max_tmp4);
#define COMPARE_SRC45(i) int8x8_t src##i##_45 = vld1_s8(sptr##i + 4 * 4); #define COMPARE_SRC45(i) \
int32x2_t src##i##_45 = \
vld1_dup_s32(reinterpret_cast<const int32_t*>(sptr##i + 4 * 4));
UNROLL_CALL_NOWRAPPER(5, COMPARE_SRC45) UNROLL_CALL_NOWRAPPER(5, COMPARE_SRC45)
int8x8_t max_45 = vmax_s8(src0_45, src1_45); int8x8_t max_45 = vmax_s8(vreinterpret_s8_s32(src0_45),
max_45 = vmax_s8(max_45, src1_45); vreinterpret_s8_s32(src1_45));
max_45 = vmax_s8(max_45, src2_45); max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src1_45));
max_45 = vmax_s8(max_45, src3_45); max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src2_45));
max_45 = vmax_s8(max_45, src4_45); max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src3_45));
max_45 = vmax_s8(max_45, vreinterpret_s8_s32(src4_45));
#define store(i) \ #define store(i) \
*(dptr + i) = std::max(std::max(max_out[i], max_out[i + 4]), max_45[i]); *(dptr + i) = std::max(std::max(max_out[i], max_out[i + 4]), max_45[i]);
......
...@@ -18,7 +18,7 @@ namespace arm_common { ...@@ -18,7 +18,7 @@ namespace arm_common {
#define KERN(strdie) \ #define KERN(strdie) \
void do_max_pooling_5x5_##strdie##_int8_nchw44_NEON( \ void do_max_pooling_5x5_##strdie##_int8_nchw44_NEON( \
const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \ const int8_t* src, int8_t* dst, size_t IH, size_t IW, size_t OH, \
size_t OW, size_t PH, size_t PW); size_t OW, size_t PH, size_t PW, const WorkspaceBundle& ws);
KERN(stride1) KERN(stride1)
KERN(stride2) KERN(stride2)
......
...@@ -26,8 +26,7 @@ class PoolingImpl::AlgoPack : NonCopyableObj { ...@@ -26,8 +26,7 @@ class PoolingImpl::AlgoPack : NonCopyableObj {
AlgoInt8Filter2MaxStride2 algo_int8_filter2_max_stride2; AlgoInt8Filter2MaxStride2 algo_int8_filter2_max_stride2;
AlgoInt8Filter3MaxStride2 algo_int8_filter3_max_stride2; AlgoInt8Filter3MaxStride2 algo_int8_filter3_max_stride2;
AlgoFilter2MaxStridexNCHW44 algo_filter2_max_stridex_nchw4; AlgoFilter2MaxStridexNCHW44 algo_filter2_max_stridex_nchw4;
AlgoFilter3MaxStride2NCHW44 algo_filter3_max_stride2_nchw4; AlgoFilter3MaxStridexNCHW44 algo_filter3_max_stridex_nchw4;
AlgoFilter3MaxStride1NCHW44 algo_filter3_max_stride1_nchw4;
AlgoFilter4MaxStridexNCHW44 algo_filter4_max_stridex_nchw4; AlgoFilter4MaxStridexNCHW44 algo_filter4_max_stridex_nchw4;
AlgoFilter5MaxStridexNCHW44 algo_filter5_max_stridex_nchw4; AlgoFilter5MaxStridexNCHW44 algo_filter5_max_stridex_nchw4;
...@@ -41,8 +40,7 @@ public: ...@@ -41,8 +40,7 @@ public:
all_algos.emplace_back(&algo_filter5_max_stride2); all_algos.emplace_back(&algo_filter5_max_stride2);
all_algos.emplace_back(&algo_int8_filter2_max_stride2); all_algos.emplace_back(&algo_int8_filter2_max_stride2);
all_algos.emplace_back(&algo_int8_filter3_max_stride2); all_algos.emplace_back(&algo_int8_filter3_max_stride2);
all_algos.emplace_back(&algo_filter3_max_stride2_nchw4); all_algos.emplace_back(&algo_filter3_max_stridex_nchw4);
all_algos.emplace_back(&algo_filter3_max_stride1_nchw4);
all_algos.emplace_back(&algo_filter2_max_stridex_nchw4); all_algos.emplace_back(&algo_filter2_max_stridex_nchw4);
all_algos.emplace_back(&algo_filter4_max_stridex_nchw4); all_algos.emplace_back(&algo_filter4_max_stridex_nchw4);
all_algos.emplace_back(&algo_filter5_max_stridex_nchw4); all_algos.emplace_back(&algo_filter5_max_stridex_nchw4);
...@@ -119,6 +117,12 @@ size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src, ...@@ -119,6 +117,12 @@ size_t PoolingImpl::get_workspace_in_bytes(const TensorLayout& src,
arm_common_workspace = ws.total_size_in_bytes() * nr_threads; arm_common_workspace = ws.total_size_in_bytes() * nr_threads;
} }
if ((param.src_type.enumv() == DTypeEnum::QuantizedS8) &&
(param.format == param::Pooling::Format::NCHW44)) {
WorkspaceBundle ws = get_bundle_nchw44(param);
arm_common_workspace = ws.total_size_in_bytes() * nr_threads;
}
if (find_algo) { if (find_algo) {
return arm_common_workspace; return arm_common_workspace;
} else { } else {
......
...@@ -84,8 +84,7 @@ private: ...@@ -84,8 +84,7 @@ private:
class AlgoInt8Filter2MaxStride2; class AlgoInt8Filter2MaxStride2;
class AlgoInt8Filter3MaxStride2; class AlgoInt8Filter3MaxStride2;
class AlgoFilter2MaxStridexNCHW44; class AlgoFilter2MaxStridexNCHW44;
class AlgoFilter3MaxStride2NCHW44; class AlgoFilter3MaxStridexNCHW44;
class AlgoFilter3MaxStride1NCHW44;
class AlgoFilter4MaxStridexNCHW44; class AlgoFilter4MaxStridexNCHW44;
class AlgoFilter5MaxStridexNCHW44; class AlgoFilter5MaxStridexNCHW44;
class AlgoPack; class AlgoPack;
......
...@@ -8,8 +8,6 @@ ...@@ -8,8 +8,6 @@
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#include "megdnn/dtype.h"
#include "megdnn/opr_param_defs.h"
#include "test/arm_common/fixture.h" #include "test/arm_common/fixture.h"
#include "test/common/pooling.h" #include "test/common/pooling.h"
...@@ -102,209 +100,6 @@ TEST_F(ARM_COMMON, POOLING_INT8_W3x3_S2x2) ...@@ -102,209 +100,6 @@ TEST_F(ARM_COMMON, POOLING_INT8_W3x3_S2x2)
// clang-format on // clang-format on
} }
// Max pooling with a 3x3 window and stride 2 on QuantizedS8 input in NCHW44
// format. Every (ih, iw) pair below is run through Checker<Pooling> with zero
// padding on a {2, 2, ih, iw, 4} source shape.
TEST_F(ARM_COMMON, POOLING_MAX_W3x3_S2x2_NCHW44) {
    for (size_t ih : {3, 5, 10}) {
        for (size_t iw : {3, 5, 7, 9, 15, 20}) {
            for (size_t ph : {0}) {
                for (size_t pw : {0}) {
                    // The padded input must be at least as large as the
                    // 3x3 window.
                    if (ih + 2 * ph < 3 || iw + 2 * pw < 3) {
                        continue;
                    }

                    // Inputs are drawn from half of the int8 range.
                    UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
                    Checker<Pooling> checker(handle());
                    checker.set_dtype(0, dtype::QuantizedS8(1.1f));
                    checker.set_rng(0, &rng);

                    param::Pooling pooling_param;
                    pooling_param.format = param::Pooling::Format::NCHW44;
                    pooling_param.mode = param::Pooling::Mode::MAX;
                    pooling_param.window_h = 3;
                    pooling_param.window_w = 3;
                    pooling_param.stride_h = 2;
                    pooling_param.stride_w = 2;
                    pooling_param.pad_h = ph;
                    pooling_param.pad_w = pw;

                    // The destination shape is left empty here.
                    checker.set_param(pooling_param)
                            .exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
                }
            }
        }
    }
}
// Max pooling with a 3x3 window and stride 1 on QuantizedS8 input in NCHW44
// format. Every (ih, iw) pair below is run through Checker<Pooling> with zero
// padding on a {2, 2, ih, iw, 4} source shape.
TEST_F(ARM_COMMON, POOLING_MAX_W3x3_S1x1_NCHW44) {
    for (size_t ih : {3, 5, 10}) {
        for (size_t iw : {3, 5, 7, 9, 15, 20}) {
            for (size_t ph : {0}) {
                for (size_t pw : {0}) {
                    // The padded input must be at least as large as the
                    // 3x3 window.
                    if (ih + 2 * ph < 3 || iw + 2 * pw < 3) {
                        continue;
                    }

                    // Inputs are drawn from half of the int8 range.
                    UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
                    Checker<Pooling> checker(handle());
                    checker.set_dtype(0, dtype::QuantizedS8(1.1f));
                    checker.set_rng(0, &rng);

                    param::Pooling pooling_param;
                    pooling_param.format = param::Pooling::Format::NCHW44;
                    pooling_param.mode = param::Pooling::Mode::MAX;
                    pooling_param.window_h = 3;
                    pooling_param.window_w = 3;
                    pooling_param.stride_h = 1;
                    pooling_param.stride_w = 1;
                    pooling_param.pad_h = ph;
                    pooling_param.pad_w = pw;

                    // The destination shape is left empty here.
                    checker.set_param(pooling_param)
                            .exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
                }
            }
        }
    }
}
// Max pooling with a 2x2 window and stride 1 on QuantizedS8 input in NCHW44
// format. Every (ih, iw) pair below is run through Checker<Pooling> with zero
// padding on a {2, 2, ih, iw, 4} source shape.
TEST_F(ARM_COMMON, POOLING_MAX_W2x2_S1x1_NCHW44) {
    for (size_t ih : {2, 5, 10, 17}) {
        for (size_t iw : {2, 6, 8, 16, 26}) {
            for (size_t ph : {0}) {
                for (size_t pw : {0}) {
                    // The padded input must be at least as large as the
                    // 2x2 window.
                    if (ih + 2 * ph < 2 || iw + 2 * pw < 2) {
                        continue;
                    }

                    // Inputs are drawn from half of the int8 range.
                    UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
                    Checker<Pooling> checker(handle());
                    checker.set_dtype(0, dtype::QuantizedS8(1.1f));
                    checker.set_rng(0, &rng);

                    param::Pooling pooling_param;
                    pooling_param.format = param::Pooling::Format::NCHW44;
                    pooling_param.mode = param::Pooling::Mode::MAX;
                    pooling_param.window_h = 2;
                    pooling_param.window_w = 2;
                    pooling_param.stride_h = 1;
                    pooling_param.stride_w = 1;
                    pooling_param.pad_h = ph;
                    pooling_param.pad_w = pw;

                    // The destination shape is left empty here.
                    checker.set_param(pooling_param)
                            .exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
                }
            }
        }
    }
}
// Max pooling with a 2x2 window and stride 2 on QuantizedS8 input in NCHW44
// format. Every (ih, iw) pair below is run through Checker<Pooling> with zero
// padding on a {2, 2, ih, iw, 4} source shape.
TEST_F(ARM_COMMON, POOLING_MAX_W2x2_S2x2_NCHW44) {
    for (size_t ih : {2, 5, 10, 17}) {
        for (size_t iw : {2, 6, 8, 16, 26}) {
            for (size_t ph : {0}) {
                for (size_t pw : {0}) {
                    // The padded input must be at least as large as the
                    // 2x2 window.
                    if (ih + 2 * ph < 2 || iw + 2 * pw < 2) {
                        continue;
                    }

                    // Inputs are drawn from half of the int8 range.
                    UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
                    Checker<Pooling> checker(handle());
                    checker.set_dtype(0, dtype::QuantizedS8(1.1f));
                    checker.set_rng(0, &rng);

                    param::Pooling pooling_param;
                    pooling_param.format = param::Pooling::Format::NCHW44;
                    pooling_param.mode = param::Pooling::Mode::MAX;
                    pooling_param.window_h = 2;
                    pooling_param.window_w = 2;
                    pooling_param.stride_h = 2;
                    pooling_param.stride_w = 2;
                    pooling_param.pad_h = ph;
                    pooling_param.pad_w = pw;

                    // The destination shape is left empty here.
                    checker.set_param(pooling_param)
                            .exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
                }
            }
        }
    }
}
// Max pooling with a 4x4 window and stride 1 on QuantizedS8 input in NCHW44
// format. Every (ih, iw) pair below is run through Checker<Pooling> with zero
// padding on a {2, 2, ih, iw, 4} source shape.
TEST_F(ARM_COMMON, POOLING_MAX_W4x4_S1x1_NCHW44) {
    for (size_t ih : {4, 7, 10, 17, 20}) {
        for (size_t iw : {4, 8, 10, 21, 32}) {
            for (size_t ph : {0}) {
                for (size_t pw : {0}) {
                    // The padded input must be at least as large as the
                    // 4x4 window. The bound is the window size (4), so the
                    // guard stays valid if smaller sizes or padding values
                    // are added to the lists above.
                    if (ih + 2 * ph < 4 || iw + 2 * pw < 4) {
                        continue;
                    }

                    // Inputs are drawn from half of the int8 range.
                    UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
                    Checker<Pooling> checker(handle());
                    checker.set_dtype(0, dtype::QuantizedS8(1.1f));
                    checker.set_rng(0, &rng);

                    param::Pooling pooling_param;
                    pooling_param.format = param::Pooling::Format::NCHW44;
                    pooling_param.mode = param::Pooling::Mode::MAX;
                    pooling_param.window_h = 4;
                    pooling_param.window_w = 4;
                    pooling_param.stride_h = 1;
                    pooling_param.stride_w = 1;
                    pooling_param.pad_h = ph;
                    pooling_param.pad_w = pw;

                    // The destination shape is left empty here.
                    checker.set_param(pooling_param)
                            .exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
                }
            }
        }
    }
}
// Max pooling with a 4x4 window and stride 2 on QuantizedS8 input in NCHW44
// format. Every (ih, iw) pair below is run through Checker<Pooling> with zero
// padding on a {2, 2, ih, iw, 4} source shape.
TEST_F(ARM_COMMON, POOLING_MAX_W4x4_S2x2_NCHW44) {
    for (size_t ih : {4, 10, 18, 25, 30}) {
        for (size_t iw : {4, 12, 17, 20, 25}) {
            for (size_t ph : {0}) {
                for (size_t pw : {0}) {
                    // The padded input must be at least as large as the
                    // 4x4 window. The bound is the window size (4), so the
                    // guard stays valid if smaller sizes or padding values
                    // are added to the lists above.
                    if (ih + 2 * ph < 4 || iw + 2 * pw < 4) {
                        continue;
                    }

                    // Inputs are drawn from half of the int8 range.
                    UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
                    Checker<Pooling> checker(handle());
                    checker.set_dtype(0, dtype::QuantizedS8(1.1f));
                    checker.set_rng(0, &rng);

                    param::Pooling pooling_param;
                    pooling_param.format = param::Pooling::Format::NCHW44;
                    pooling_param.mode = param::Pooling::Mode::MAX;
                    pooling_param.window_h = 4;
                    pooling_param.window_w = 4;
                    pooling_param.stride_h = 2;
                    pooling_param.stride_w = 2;
                    pooling_param.pad_h = ph;
                    pooling_param.pad_w = pw;

                    // The destination shape is left empty here.
                    checker.set_param(pooling_param)
                            .exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
                }
            }
        }
    }
}
// Max pooling with a 5x5 window and stride 1 on QuantizedS8 input in NCHW44
// format. Every (ih, iw) pair below is run through Checker<Pooling> with zero
// padding on a {2, 2, ih, iw, 4} source shape.
TEST_F(ARM_COMMON, POOLING_MAX_W5x5_S1x1_NCHW44) {
    for (size_t ih : {5, 9, 19, 20, 39}) {
        for (size_t iw : {5, 12, 23, 27, 39}) {
            for (size_t ph : {0}) {
                for (size_t pw : {0}) {
                    // The padded input must be at least as large as the
                    // 5x5 window.
                    if (ih + 2 * ph < 5 || iw + 2 * pw < 5) {
                        continue;
                    }

                    // Inputs are drawn from half of the int8 range.
                    UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
                    Checker<Pooling> checker(handle());
                    checker.set_dtype(0, dtype::QuantizedS8(1.1f));
                    checker.set_rng(0, &rng);

                    param::Pooling pooling_param;
                    pooling_param.format = param::Pooling::Format::NCHW44;
                    pooling_param.mode = param::Pooling::Mode::MAX;
                    pooling_param.window_h = 5;
                    pooling_param.window_w = 5;
                    pooling_param.stride_h = 1;
                    pooling_param.stride_w = 1;
                    pooling_param.pad_h = ph;
                    pooling_param.pad_w = pw;

                    // The destination shape is left empty here.
                    checker.set_param(pooling_param)
                            .exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
                }
            }
        }
    }
}
// Max pooling with a 5x5 window and stride 2 on QuantizedS8 input in NCHW44
// format. Every (ih, iw) pair below is run through Checker<Pooling> with zero
// padding on a {2, 2, ih, iw, 4} source shape.
TEST_F(ARM_COMMON, POOLING_MAX_W5x5_S2x2_NCHW44) {
    for (size_t ih : {5, 9, 19, 20, 39}) {
        for (size_t iw : {5, 12, 23, 27, 39}) {
            for (size_t ph : {0}) {
                for (size_t pw : {0}) {
                    // The padded input must be at least as large as the
                    // 5x5 window.
                    if (ih + 2 * ph < 5 || iw + 2 * pw < 5) {
                        continue;
                    }

                    // Inputs are drawn from half of the int8 range.
                    UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
                    Checker<Pooling> checker(handle());
                    checker.set_dtype(0, dtype::QuantizedS8(1.1f));
                    checker.set_rng(0, &rng);

                    param::Pooling pooling_param;
                    pooling_param.format = param::Pooling::Format::NCHW44;
                    pooling_param.mode = param::Pooling::Mode::MAX;
                    pooling_param.window_h = 5;
                    pooling_param.window_w = 5;
                    pooling_param.stride_h = 2;
                    pooling_param.stride_w = 2;
                    pooling_param.pad_h = ph;
                    pooling_param.pad_w = pw;

                    // The destination shape is left empty here.
                    checker.set_param(pooling_param)
                            .exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
                }
            }
        }
    }
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON, POOLING_FP16) { TEST_F(ARM_COMMON, POOLING_FP16) {
Checker<Pooling> checker(handle()); Checker<Pooling> checker(handle());
......
...@@ -8,6 +8,8 @@ ...@@ -8,6 +8,8 @@
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#include <vector>
#include "megdnn/dtype.h"
#include "test/arm_common/fixture.h" #include "test/arm_common/fixture.h"
#include "test/common/pooling.h" #include "test/common/pooling.h"
...@@ -53,38 +55,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING) { ...@@ -53,38 +55,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING) {
checker.set_param(param).exec({{2, 3, ih, iw}, {}}); checker.set_param(param).exec({{2, 3, ih, iw}, {}});
} }
} }
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_S2x2_NCHW44)
{
// clang-format off
for (size_t ih: {3, 5, 10})
for (size_t iw: {3, 5, 7, 9, 15, 20})
for (size_t ph: {0})
for (size_t pw: {0})
if (ih+2*ph >= 3 && iw+2*pw >= 3)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle());
checker.set_dtype(0, dtype::QuantizedS8(1.1f));
checker.set_rng(0,&rng);
param::Pooling param; TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_NCHW44)
param.mode = param::Pooling::Mode::MAX;
param.format = param::Pooling::Format::NCHW44;
param.pad_h = ph;
param.pad_w = pw;
param.stride_h = param.stride_w = 2;
param.window_h = param.window_w = 3;
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
}
// clang-format on
}
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_S1x1_NCHW44)
{ {
// clang-format off // clang-format off
for (size_t ih: {3, 5, 10}) for (size_t ih: {3, 5, 10})
for (size_t iw: {3, 5, 7, 9, 15, 20}) for (size_t iw: {3, 5, 7, 9, 15, 20})
for (size_t ph: {0}) for (size_t ph: {0, 1, 2})
for (size_t pw: {0}) for (size_t pw: {0, 1, 2})
if (ih+2*ph >= 3 && iw+2*pw >= 3) if (ih+2*ph >= 3 && iw+2*pw >= 3)
{ {
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
...@@ -100,18 +78,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_S1x1_NCHW44) ...@@ -100,18 +78,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_S1x1_NCHW44)
param.stride_h = param.stride_w = 1; param.stride_h = param.stride_w = 1;
param.window_h = param.window_w = 3; param.window_h = param.window_w = 3;
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
param.stride_h = param.stride_w = 2;
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
} }
// clang-format on // clang-format on
} }
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_S1x1_NCHW44) TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_NCHW44)
{ {
// clang-format off // clang-format off
for (size_t ih: {2, 5, 10, 17}) for (size_t ih: {2, 5, 10, 17})
for (size_t iw: {2, 6, 8, 16, 26}) for (size_t iw: {2, 6, 8, 16, 26})
for (size_t ph: {0}) for (size_t ph: {0, 1})
for (size_t pw: {0}) for (size_t pw: {0, 1})
if (ih+2*ph >= 3 && iw+2*pw >= 3) if (ih+2*ph >= 2 && iw+2*pw >= 2)
{ {
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle()); Checker<Pooling> checker(handle());
...@@ -126,41 +108,20 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_S1x1_NCHW44) ...@@ -126,41 +108,20 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_S1x1_NCHW44)
param.stride_h = param.stride_w = 1; param.stride_h = param.stride_w = 1;
param.window_h = param.window_w = 2; param.window_h = param.window_w = 2;
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
}
// clang-format on
}
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_S2x2_NCHW44)
{
// clang-format off
for (size_t ih: {2, 5, 10, 17})
for (size_t iw: {2, 6, 8, 16, 26})
for (size_t ph: {0})
for (size_t pw: {0})
if (ih+2*ph >= 3 && iw+2*pw >= 3)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle());
checker.set_dtype(0, dtype::QuantizedS8(1.1f));
checker.set_rng(0,&rng);
param::Pooling param;
param.mode = param::Pooling::Mode::MAX;
param.format = param::Pooling::Format::NCHW44;
param.pad_h = ph;
param.pad_w = pw;
param.stride_h = param.stride_w = 2; param.stride_h = param.stride_w = 2;
param.window_h = param.window_w = 2;
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
} }
// clang-format on // clang-format on
} }
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_S1x1_NCHW44)
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_NCHW44)
{ {
// clang-format off // clang-format off
for (size_t ih: {4, 7, 10, 17, 20}) for (size_t ih: {4, 10, 18, 25, 30})
for (size_t iw: {4, 8, 10, 21, 32}) for (size_t iw: {4, 12, 17, 20, 25})
for (size_t ph: {0}) for (size_t ph: {0, 1, 2})
for (size_t pw: {0}) for (size_t pw: {0, 1, 2})
if (ih+2*ph >= 4 && iw+2*pw >= 4) if (ih+2*ph >= 4 && iw+2*pw >= 4)
{ {
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
...@@ -176,41 +137,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_S1x1_NCHW44) ...@@ -176,41 +137,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_S1x1_NCHW44)
param.stride_h = param.stride_w = 1; param.stride_h = param.stride_w = 1;
param.window_h = param.window_w = 4; param.window_h = param.window_w = 4;
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
}
// clang-format on
}
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W4x4_S2x2_NCHW44)
{
// clang-format off
for (size_t ih: {4, 10, 18, 25, 30})
for (size_t iw: {4, 12, 17, 20, 25})
for (size_t ph: {0})
for (size_t pw: {0})
if (ih+2*ph >= 4 && iw+2*pw >= 4)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle());
checker.set_dtype(0, dtype::QuantizedS8(1.1f));
checker.set_rng(0,&rng);
param::Pooling param;
param.mode = param::Pooling::Mode::MAX;
param.format = param::Pooling::Format::NCHW44;
param.pad_h = ph;
param.pad_w = pw;
param.stride_h = param.stride_w = 2; param.stride_h = param.stride_w = 2;
param.window_h = param.window_w = 4;
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
} }
// clang-format on // clang-format on
} }
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_S1x1_NCHW44) TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_NCHW44)
{ {
// clang-format off // clang-format off
for (size_t ih: {5, 9, 19, 20, 39}) for (size_t ih: {5, 9, 19, 20, 39})
for (size_t iw: {5, 12, 23, 27, 39}) for (size_t iw: {5, 12, 23, 27, 39})
for (size_t ph: {0}) for (size_t ph: {0, 1, 2})
for (size_t pw: {0}) for (size_t pw: {0, 1, 2})
if (ih+2*ph >= 5 && iw+2*pw >= 5) if (ih+2*ph >= 5 && iw+2*pw >= 5)
{ {
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1}; UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
...@@ -226,31 +165,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_S1x1_NCHW44) ...@@ -226,31 +165,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_S1x1_NCHW44)
param.stride_h = param.stride_w = 1; param.stride_h = param.stride_w = 1;
param.window_h = param.window_w = 5; param.window_h = param.window_w = 5;
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
}
// clang-format on
}
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W5x5_S2x2_NCHW44)
{
// clang-format off
for (size_t ih: {5, 9, 19, 20, 39})
for (size_t iw: {5, 12, 23, 27, 39})
for (size_t ph: {0})
for (size_t pw: {0})
if (ih+2*ph >= 5 && iw+2*pw >= 5)
{
UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
Checker<Pooling> checker(handle());
checker.set_dtype(0, dtype::QuantizedS8(1.1f));
checker.set_rng(0,&rng);
param::Pooling param;
param.mode = param::Pooling::Mode::MAX;
param.format = param::Pooling::Format::NCHW44;
param.pad_h = ph;
param.pad_w = pw;
param.stride_h = param.stride_w = 2; param.stride_h = param.stride_w = 2;
param.window_h = param.window_w = 5;
checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}}); checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
} }
// clang-format on // clang-format on
} }
...@@ -473,13 +391,15 @@ template <typename Opr> ...@@ -473,13 +391,15 @@ template <typename Opr>
void benchmark_impl(const typename Opr::Param& param, void benchmark_impl(const typename Opr::Param& param,
std::vector<SmallVector<TensorShape>> shapes, size_t RUNS, std::vector<SmallVector<TensorShape>> shapes, size_t RUNS,
TaskExecutorConfig&& multi_thread_config, TaskExecutorConfig&& multi_thread_config,
TaskExecutorConfig&& single_thread_config) { TaskExecutorConfig&& single_thread_config,
DType data_type) {
std::vector<float> multi_thread_times, single_thread_times; std::vector<float> multi_thread_times, single_thread_times;
{ {
auto multi_thread_hanle = auto multi_thread_hanle =
create_cpu_handle(0, true, &multi_thread_config); create_cpu_handle(0, true, &multi_thread_config);
auto benchmarker = Benchmarker<Opr>(multi_thread_hanle.get()); auto benchmarker = Benchmarker<Opr>(multi_thread_hanle.get());
benchmarker.set_times(RUNS).set_display(false).set_param(param); benchmarker.set_times(RUNS).set_display(false).set_param(param);
benchmarker.set_dtype(0, data_type);
for (auto shape : shapes) { for (auto shape : shapes) {
multi_thread_times.push_back(benchmarker.exec(shape) / RUNS); multi_thread_times.push_back(benchmarker.exec(shape) / RUNS);
} }
...@@ -489,6 +409,7 @@ void benchmark_impl(const typename Opr::Param& param, ...@@ -489,6 +409,7 @@ void benchmark_impl(const typename Opr::Param& param,
create_cpu_handle(0, true, &single_thread_config); create_cpu_handle(0, true, &single_thread_config);
auto benchmarker = Benchmarker<Opr>(single_thread_handle.get()); auto benchmarker = Benchmarker<Opr>(single_thread_handle.get());
benchmarker.set_times(RUNS).set_display(false).set_param(param); benchmarker.set_times(RUNS).set_display(false).set_param(param);
benchmarker.set_dtype(0, data_type);
for (auto shape : shapes) { for (auto shape : shapes) {
single_thread_times.push_back(benchmarker.exec(shape) / RUNS); single_thread_times.push_back(benchmarker.exec(shape) / RUNS);
} }
...@@ -540,10 +461,47 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_POOLING) { ...@@ -540,10 +461,47 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_POOLING) {
param.stride_h = param.stride_w = 2; param.stride_h = param.stride_w = 2;
param.pad_h = param.pad_w = 1; param.pad_h = param.pad_w = 1;
printf("Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n", param.window_h, printf("Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n", param.window_h,
param.stride_h, param.pad_h, static_cast<int>(param.mode)); param.window_w, param.stride_h, static_cast<int>(param.mode));
benchmark_impl<Pooling>(param, shapes, RUNS, {4, {0, 1, 2, 3}}, {1, {0}}); benchmark_impl<Pooling>(param, shapes, RUNS, {4, {0, 1, 2, 3}}, {1, {0}}, dtype::Float32());
benchmark_impl<Pooling>(param, shapes, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}); benchmark_impl<Pooling>(param, shapes, RUNS, {4, {4, 5, 6, 7}}, {1, {4}}, dtype::Float32());
benchmark_impl<Pooling>(param, shapes, RUNS, {2, {0, 1}}, {1, {0}}); benchmark_impl<Pooling>(param, shapes, RUNS, {2, {0, 1}}, {1, {0}}, dtype::Float32());
}
/**
 * Benchmarks QuantizedS8 MAX pooling without padding in two layouts, NCHW and
 * NCHW44, for every {window, stride} pair listed in filter_and_stride.
 *
 * For each pair the same three workloads are timed in both layouts: the
 * NCHW shapes use C = 32*4 / 32*4 / 16*4 channels, the NCHW44 shapes use
 * C/4 channel blocks with a trailing dimension of 4, so both layouts hold the
 * same number of elements. Timing itself is delegated to benchmark_impl,
 * which receives a multi-thread and a single-thread executor config.
 */
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_POOLING_NCHW44) {
    constexpr size_t RUNS = 50;
    using Param = param::Pooling;
    Param param;
    param.pad_h = param.pad_w = 0;
    param.mode = Param::Mode::MAX;
    std::vector<SmallVector<TensorShape>> shapes;
    // Each entry is {window size, stride}; window and stride are applied to
    // both the height and the width.
    std::vector<std::vector<size_t>> filter_and_stride = {
            {2, 1}, {2, 2}, {3, 1}, {3, 2}, {4, 1}, {4, 2}, {5, 1}, {5, 2}};
    // Iterate by const reference: the entries are only read, so there is no
    // reason to copy the inner vector on every iteration.
    for (const auto& filter : filter_and_stride) {
        param.window_h = param.window_w = filter[0];
        param.stride_h = param.stride_w = filter[1];

        // Baseline: plain NCHW layout.
        shapes.push_back({{1, 32 * 4, 215, 215}, {}});
        shapes.push_back({{1, 32 * 4, 128, 128}, {}});
        shapes.push_back({{1, 16 * 4, 56, 56}, {}});
        param.format = Param::Format::NCHW;
        // Print window_h * window_w, matching the NCHW44 banner below.
        printf("NCHW Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n",
               param.window_h, param.window_w, param.stride_h,
               static_cast<int>(param.mode));
        // NOTE(review): the configs presumably mean {thread count, core ids},
        // i.e. 4 threads on cores 4-7 versus 1 thread on core 4 - confirm
        // against TaskExecutorConfig.
        benchmark_impl<Pooling>(param, shapes, RUNS, {4, {4, 5, 6, 7}},
                                {1, {4}}, dtype::QuantizedS8(1.1f));
        shapes.clear();

        // Same workloads in NCHW44 layout.
        shapes.push_back({{1, 32, 215, 215, 4}, {}});
        shapes.push_back({{1, 32, 128, 128, 4}, {}});
        shapes.push_back({{1, 16, 56, 56, 4}, {}});
        param.format = Param::Format::NCHW44;
        printf("NCHW44 Benchmark POOLING kernel:%d*%d stride:%d,mode %d\n",
               param.window_h, param.window_w, param.stride_h,
               static_cast<int>(param.mode));
        benchmark_impl<Pooling>(param, shapes, RUNS, {4, {4, 5, 6, 7}},
                                {1, {4}}, dtype::QuantizedS8(1.1f));
        shapes.clear();
    }
}
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册