Commit c9d265c7 authored by Megvii Engine Team

feat(dnn): add fp16 nchw88 pooling algo

GitOrigin-RevId: 7a5e9c7df242fd5d7d7811b1af9213e58be20e91
Parent 222410b0
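The patch adds an fp16 pooling algorithm for the NCHW88 layout on ARM, where channels are packed in groups of 8 so that one group fills a 128-bit NEON fp16 vector. As a quick orientation for the diff below, here is a minimal sketch (not part of the patch; the helper name and signature are illustrative) of how a logical (n, c, h, w) element is addressed in an NCHW88 tensor stored as (N, C/8, H, W, 8):

#include <cstddef>

// Illustrative only: address computation for an NCHW88 tensor, i.e. logical
// shape (N, C, H, W) stored as (N, C/8, H, W, 8) with 8 half-precision
// channels packed contiguously. C is assumed to be a multiple of 8.
inline size_t nchw88_offset(
        size_t n, size_t c, size_t h, size_t w,  // logical coordinates
        size_t C, size_t H, size_t W) {          // logical extents
    const size_t cb = c / 8;  // channel-block index
    const size_t cr = c % 8;  // lane inside the 8-channel block
    return (((n * (C / 8) + cb) * H + h) * W + w) * 8 + cr;
}

This is also why the kernel dispatch in the new algorithm advances the src/dst pointers by c_idx * ih * iw * 8 and c_idx * oh * ow * 8: each dispatched task processes one contiguous 8-channel slab.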
@@ -124,6 +124,19 @@ public:
MEGDNN_DECL_ALGO_TYPE(ARM_Filter5ModexStridexNCHW44)
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class PoolingImpl::AlgoFilterxModexStridexNCHW88 final : public AlgoBase {
public:
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override {
return "ARM_POOLING_FILTERX_MODEX_STRIDEX_NCHW88";
}
bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_Fp16FilterxModexStridexNCHW88)
};
#endif
class PoolingImpl::AlgoFallback final : public AlgoBase {
public:
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; };
...
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include "midout.h"
#include "src/arm_common/pooling/algo.h"
#include "src/arm_common/pooling/fp16/kern_fp16_nchw88_pooling.h"
MIDOUT_DECL(megdnn_arm_common_fp16_nchw88_pooling)
namespace megdnn {
namespace arm_common {
bool PoolingImpl::AlgoFilterxModexStridexNCHW88::usable(
const PoolingKernSizeParam& param) const {
uint32_t sh = param.stride[0];
uint32_t sw = param.stride[1];
uint32_t fh = param.filter[0];
uint32_t fw = param.filter[1];
bool usable = param.src_type.enumv() == DTypeEnum::Float16 &&
param.format == param::Pooling::Format::NCHW88 &&
(param.mode == PoolingBase::Mode::MAX ||
param.mode == PoolingBase::Mode::AVERAGE) &&
fh == fw && sh == sw;
bool size_ok =
(((fh == 2 || fh == 3 || fh == 4 || fh == 5) && (sh == 1 || sh == 2)) ||
((fh == 9 || fh == 13) && (sh == 1)));
return usable && size_ok;
}
void PoolingImpl::AlgoFilterxModexStridexNCHW88::exec(
const PoolingKernParam& param) const {
int ih = param.isz[0];
int iw = param.isz[1];
int oh = param.osz[0];
int ow = param.osz[1];
int n = param.n;
int ic = param.ic;
int ph = param.padding[0];
int pw = param.padding[1];
int sh = param.stride[0];
int fh = param.filter[0];
auto src = param.src_ptr;
auto dst = param.dst_ptr;
#define DISPATCH_FUNC(filter, stride, mode) \
MIDOUT_BEGIN( \
megdnn_arm_common_fp16_nchw88_pooling, midout_iv(0), \
midout_iv(#filter #stride #mode##_hash)) { \
auto run = [=](size_t index, size_t) { \
const int c_idx = index; \
pooling_fp16_nchw88<filter, stride, mode>( \
static_cast<const __fp16*>(src.get_ptr()) + c_idx * ih * iw * 8, \
static_cast<__fp16*>(dst.get_ptr()) + c_idx * oh * ow * 8, ih, iw, \
oh, ow, ph, pw); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), n* ic, run); \
} \
MIDOUT_END();
#define DISPATCH_MODE(filter, stride) \
switch (param.mode) { \
case PoolingBase::Mode::MAX: \
DISPATCH_FUNC(filter, stride, PoolingBase::Mode::MAX); \
break; \
case PoolingBase::Mode::AVERAGE: \
DISPATCH_FUNC(filter, stride, PoolingBase::Mode::AVERAGE); \
break; \
default: \
megdnn_assert(0, "invalid mode %u", static_cast<uint32_t>(param.mode)); \
}
#define DISPATCH_STRIDE(filter) \
switch (sh) { \
case 1: \
DISPATCH_MODE(filter, 1); \
break; \
case 2: \
DISPATCH_MODE(filter, 2); \
break; \
default: \
megdnn_assert( \
0, \
"Invalid stride %d. When the filter size is 2, 3, 4 or 5, stride " \
"can only be 1 or 2.", \
sh); \
}
#define DISPATCH_STRIDE1(filter) \
switch (sh) { \
case 1: \
DISPATCH_MODE(filter, 1); \
break; \
default: \
megdnn_assert( \
0, \
"Invalid stride %d. When the filter size is 9 or 13, stride " \
"can only be 1.", \
sh); \
}
#define DISPATCH_FILTER() \
switch (fh) { \
case 2: \
DISPATCH_STRIDE(2); \
break; \
case 3: \
DISPATCH_STRIDE(3); \
break; \
case 4: \
DISPATCH_STRIDE(4); \
break; \
case 5: \
DISPATCH_STRIDE(5); \
break; \
case 9: \
DISPATCH_STRIDE1(9); \
break; \
case 13: \
DISPATCH_STRIDE1(13); \
break; \
}
DISPATCH_FILTER();
}
} // namespace arm_common
} // namespace megdnn
#endif
\ No newline at end of file
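To make the macro chain above easier to follow: for a given parameter set, DISPATCH_FILTER() selects the filter size, DISPATCH_STRIDE / DISPATCH_STRIDE1 select the stride, and DISPATCH_MODE selects max or average pooling, ending in DISPATCH_FUNC, which instantiates the kernel template. Roughly (midout bookkeeping omitted, and assuming the surrounding exec() scope), the fh == 3, sh == 2, MAX case boils down to the following sketch:

// Sketch of the macro expansion for filter 3, stride 2, MAX pooling.
// pooling_fp16_nchw88 is declared in kern_fp16_nchw88_pooling.h; everything
// else is copied from the DISPATCH_FUNC body above.
auto run = [=](size_t index, size_t) {
    const int c_idx = index;
    pooling_fp16_nchw88<3, 2, PoolingBase::Mode::MAX>(
            static_cast<const __fp16*>(src.get_ptr()) + c_idx * ih * iw * 8,
            static_cast<__fp16*>(dst.get_ptr()) + c_idx * oh * ow * 8,
            ih, iw, oh, ow, ph, pw);
};
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(
        static_cast<::megdnn::naive::HandleImpl*>(param.handle), n * ic, run);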
@@ -22,6 +22,9 @@ private:
AlgoFilter3ModexStridexNCHW44 algo_filter3_modex_stridex_nchw4;
AlgoFilter4ModexStridexNCHW44 algo_filter4_modex_stridex_nchw4;
AlgoFilter5ModexStridexNCHW44 algo_filter5_modex_stridex_nchw4;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
AlgoFilterxModexStridexNCHW88 algo_fp16_filterx_modex_stridex_nchw88;
#endif
AlgoFallback algo_fallback;
public:
@@ -38,6 +41,9 @@ public:
all_algos.emplace_back(&algo_filter2_modex_stridex_nchw4);
all_algos.emplace_back(&algo_filter4_modex_stridex_nchw4);
all_algos.emplace_back(&algo_filter5_modex_stridex_nchw4);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
all_algos.emplace_back(&algo_fp16_filterx_modex_stridex_nchw88);
#endif
all_algos.emplace_back(&algo_fallback);
for (auto&& algo : all_algos) {
...
@@ -24,6 +24,9 @@ private:
class AlgoFilter3ModexStridexNCHW44;
class AlgoFilter4ModexStridexNCHW44;
class AlgoFilter5ModexStridexNCHW44;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class AlgoFilterxModexStridexNCHW88;
#endif
class AlgoFallback;
class AlgoPack;
static AlgoPack sm_algo_pack;
@@ -56,6 +59,9 @@ public:
ARM_Filter3ModexStridexNCHW44,
ARM_Filter4ModexStridexNCHW44,
ARM_Filter5ModexStridexNCHW44,
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
ARM_Fp16FilterxModexStridexNCHW88,
#endif
ARM_Fp32ModexStridexNCHW44,
ARM_Fallback
};
...
@@ -165,6 +165,10 @@
cb(8, 0, ##a) cb(8, 1, ##a) cb(8, 2, ##a) cb(8, 3, ##a) \
cb(8, 4, ##a) cb(8, 5, ##a) cb(8, 6, ##a) cb(8, 7, ##a) cb(8, 8, ##a)
#define UNROLL_RAW_3x2(cb, v0, a...) \
UNROLL_RAW_2x2(cb, v0, ##a) \
cb(2, 0, ##a) cb(2, 1, ##a)
#define UNROLL_RAW_4x2(cb, v0, a...) \
cb(0, 0, ##a) cb(0, 1, ##a) cb(1, 0, ##a) cb(1, 1, ##a) \
cb(2, 0, ##a) cb(2, 1, ##a) cb(3, 0, ##a) cb(3, 1, ##a)
@@ -177,6 +181,19 @@
UNROLL_RAW_5x2(cb, v0, ##a) \
cb(5, 0, ##a) cb(5, 1, ##a)
#define UNROLL_RAW_9x2(cb, v0, a...) \
UNROLL_RAW_6x2(cb, v0, ##a) \
cb(6, 0, ##a) cb(6, 1, ##a) \
cb(7, 0, ##a) cb(7, 1, ##a) \
cb(8, 0, ##a) cb(8, 1, ##a)
#define UNROLL_RAW_13x2(cb, v0, a...) \
UNROLL_RAW_9x2(cb, v0, ##a) \
cb(9, 0, ##a) cb(9, 1, ##a) \
cb(10, 0, ##a) cb(10, 1, ##a) \
cb(11, 0, ##a) cb(11, 1, ##a) \
cb(12, 0, ##a) cb(12, 1, ##a)
#define UNROLL_RAW_4x6(cb, v0, a...) \
cb(0, 0, ##a) cb(0, 1, ##a) cb(0, 2, ##a) cb(0, 3, ##a) cb(0, 4, ##a) cb(0, 5, ##a) \
cb(1, 0, ##a) cb(1, 1, ##a) cb(1, 2, ##a) cb(1, 3, ##a) cb(1, 4, ##a) cb(1, 5, ##a) \
@@ -186,6 +203,28 @@
UNROLL_RAW_4x6(cb, v0, ##a) \
cb(4, 0, ##a) cb(4, 1, ##a) cb(4, 2, ##a) cb(4, 3, ##a) cb(4, 4, ##a) cb(4, 5, ##a)
#define UNROLL_RAW_2x4(cb, v0, a...) \
cb(0, 0, ##a) cb(0, 1, ##a) cb(0, 2, ##a) cb(0, 3, ##a) \
cb(1, 0, ##a) cb(1, 1, ##a) cb(1, 2, ##a) cb(1, 3, ##a)
#define UNROLL_RAW_3x4(cb, v0, a...) \
UNROLL_RAW_2x4(cb, v0, ##a) \
cb(2, 0, ##a) cb(2, 1, ##a) cb(2, 2, ##a) cb(2, 3, ##a)
#define UNROLL_RAW_5x4(cb, v0, a...) \
UNROLL_RAW_4x4(cb, v0, ##a) \
cb(4, 0, ##a) cb(4, 1, ##a) cb(4, 2, ##a) cb(4, 3, ##a)
#define UNROLL_RAW_9x4(cb, v0, a...) \
UNROLL_RAW_5x4(cb, v0, ##a) \
cb(5, 0, ##a) cb(5, 1, ##a) cb(5, 2, ##a) cb(5, 3, ##a) \
cb(6, 0, ##a) cb(6, 1, ##a) cb(6, 2, ##a) cb(6, 3, ##a) \
cb(7, 0, ##a) cb(7, 1, ##a) cb(7, 2, ##a) cb(7, 3, ##a) \
cb(8, 0, ##a) cb(8, 1, ##a) cb(8, 2, ##a) cb(8, 3, ##a)
#define UNROLL_RAW_13x4(cb, v0, a...) \
UNROLL_RAW_9x4(cb, v0, ##a) \
cb(9, 0, ##a) cb(9, 1, ##a) cb(9, 2, ##a) cb(9, 3, ##a) \
cb(10, 0, ##a) cb(10, 1, ##a) cb(10, 2, ##a) cb(10, 3, ##a) \
cb(11, 0, ##a) cb(11, 1, ##a) cb(11, 2, ##a) cb(11, 3, ##a) \
cb(12, 0, ##a) cb(12, 1, ##a) cb(12, 2, ##a) cb(12, 3, ##a)
#define UNROLL_CALL0_D2(step, step2, cb, v...) \
UNROLL_RAW_##step##x##step2(cb, 0, ##v)
#define UNROLL_CALL1_D2(step, step2, cb, v...) \
...
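The new UNROLL_RAW_<rows>x<cols> helpers (3x2, 9x2, 13x2, 2x4, 3x4, 5x4, 9x4, 13x4) extend the existing unrolling macros so the fp16 NCHW88 kernels can emit up to 13 rows of 2- or 4-wide register blocks without writing each statement by hand. A minimal usage sketch, assuming a callback named cb and hypothetical acc/src_v register arrays (UNROLL_RAW_2x2 is presumed to cover rows 0 and 1, following the pattern of the fully spelled-out macros):

// UNROLL_CALL0_D2(3, 2, cb) expands to UNROLL_RAW_3x2(cb, 0), which emits
// cb(0, 0) cb(0, 1) cb(1, 0) cb(1, 1) cb(2, 0) cb(2, 1) at preprocessing time.
#define cb(row, col) acc[row][col] = vmaxq_f16(acc[row][col], src_v[row][col]);
UNROLL_CALL0_D2(3, 2, cb)
#undef cb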
@@ -216,6 +216,48 @@ TEST_F(ARM_COMMON, POOLING_FP16) {
checker.set_param(param).exec({{2, 3, ih, iw}, {}});
}
}
TEST_F(ARM_COMMON, POOLING_FP16_NCHW88) {
Checker<Pooling> checker(handle());
checker.set_dtype(0, dtype::Float16{});
checker.set_dtype(1, dtype::Float16{});
checker.set_dtype(2, dtype::Float16{});
checker.set_dtype(4, dtype::Float16{});
checker.set_epsilon(0.003);
for (size_t ic : {1, 2, 3, 5, 7, 11})
for (size_t ih : {20, 15})
for (size_t iw : {15, 20, 27, 51, 76, 101, 256})
for (size_t pad : {2, 3, 5})
for (auto mode :
{param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE}) {
param::Pooling param;
param.mode = mode;
param.format = param::Pooling::Format::NCHW88;
param.pad_h = pad;
param.pad_w = pad;
for (size_t kernel : {2, 3, 4, 5}) {
if (kernel > pad && ih + 2 * pad >= kernel &&
iw + 2 * pad >= kernel) {
param.window_h = param.window_w = kernel;
param.stride_h = param.stride_w = 1;
checker.set_param(param).exec(
TensorShapeArray{{2, ic, ih, iw, 8}, {}});
param.stride_h = param.stride_w = 2;
checker.set_param(param).exec(
TensorShapeArray{{2, ic, ih, iw, 8}, {}});
}
}
for (size_t kernel : {9, 13}) {
if (kernel > pad && ih + 2 * pad >= kernel &&
iw + 2 * pad >= kernel) {
param.window_h = param.window_w = kernel;
param.stride_h = param.stride_w = 1;
checker.set_param(param).exec(
TensorShapeArray{{2, ic, ih, iw, 8}, {}});
}
}
}
}
#endif
TEST_F(ARM_COMMON, POOLING_QUANTIZED) {
@@ -367,6 +409,72 @@ TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_POOLING_NCHW44_FP32) {
benchmark_nchw44_fp32(handle());
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void benchmark_nchw88_fp16(Handle* handle) {
using Param = param::Pooling;
auto run = [&](size_t n, size_t c, size_t h, size_t w, size_t filter, size_t stride,
size_t pad, Param::Mode mode) {
Param param;
param.window_h = param.window_w = filter;
param.stride_h = param.stride_w = stride;
param.pad_h = param.pad_w = pad;
param.format = Param::Format::NCHW44;
param.mode = mode;
TensorShape nchw44_shape = {n, c / 4, h, w, 4};
TensorShape nchw88_shape = {n, c / 8, h, w, 8};
TensorLayout dst_layout;
auto opr = handle->create_operator<Pooling>();
opr->param() = param;
opr->deduce_layout({nchw44_shape, dtype::Float32()}, dst_layout);
float calc_amount =
dst_layout.total_nr_elems() * param.window_h * param.window_w;
Benchmarker<Pooling> benchmarker_float16_nchw88(handle);
Benchmarker<Pooling> benchmarker_float32_nchw44(handle);
size_t RUN = 500;
auto t1 = benchmarker_float32_nchw44.set_display(false)
.set_times(RUN)
.set_param(param)
.exec({nchw44_shape, {}});
param.format = Param::Format::NCHW88;
auto t2 = benchmarker_float16_nchw88.set_display(false)
.set_dtype(0, dtype::Float16{})
.set_dtype(1, dtype::Float16{})
.set_dtype(2, dtype::Float16{})
.set_dtype(4, dtype::Float16{})
.set_times(RUN)
.set_param(param)
.exec({nchw88_shape, {}});
printf("{%zu %zu %zu %zu} filter = %zu, stride = %zu pad = %zu\n"
"nchw44_fp32={%.3f ms, %.3f Mflops}, "
"nchw88_fp16={%.3f ms, %.3f Mflops, speed_up %f}\n\n",
n, c, h, w, filter, stride, pad, t1 / RUN,
calc_amount / (t1 / RUN * 1000), t2 / RUN,
calc_amount / (t2 / RUN * 1000), t1 / t2);
};
// Resnet50
run(1, 64, 112, 112, 3, 2, 1, param::Pooling::Mode::MAX);
run(1, 2048, 7, 7, 7, 1, 0, param::Pooling::Mode::AVERAGE);
// VGG16
run(1, 64, 224, 224, 2, 2, 0, param::Pooling::Mode::MAX);
run(1, 128, 112, 112, 2, 2, 0, param::Pooling::Mode::MAX);
run(1, 256, 56, 56, 2, 2, 0, param::Pooling::Mode::MAX);
run(1, 512, 28, 28, 2, 2, 0, param::Pooling::Mode::MAX);
run(1, 512, 14, 14, 2, 2, 0, param::Pooling::Mode::MAX);
}
TEST_F(ARM_COMMON, BENCHMARK_POOLING_NCHW88_FP16) {
benchmark_nchw88_fp16(handle());
}
TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_POOLING_NCHW88_FP16) {
benchmark_nchw88_fp16(handle());
}
#endif
TEST_F(ARM_COMMON, BENCHMARK_POOLING_INT8_W3x3_S2x2) {
using Param = param::Pooling;
auto run = [&](const TensorShapeArray& shapes, Param param) {
...