From 10bcf75767026f776c285bfd707e7e69ba6fd5f7 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 27 Jul 2021 17:25:30 +0800
Subject: [PATCH] feat(dnn/x86): add algo for x86 max pooling for window size
 no smaller than 10 and S1 under NCHW88

GitOrigin-RevId: 613a18dd916f575fcd29448c2048dde9dae70d1e
---
 dnn/src/x86/pooling/algo.cpp   | 137 ++++++++++++++++++++++++++++++++-
 dnn/src/x86/pooling/algo.h     |   4 +-
 dnn/src/x86/pooling/opr_impl.h |   1 +
 dnn/test/x86/pooling.cpp       | 100 ++++++++++++++++++++++++
 4 files changed, 240 insertions(+), 2 deletions(-)

diff --git a/dnn/src/x86/pooling/algo.cpp b/dnn/src/x86/pooling/algo.cpp
index 30929b08..b3432149 100644
--- a/dnn/src/x86/pooling/algo.cpp
+++ b/dnn/src/x86/pooling/algo.cpp
@@ -20,6 +20,8 @@
 #include "src/x86/pooling/pooling_special_cases.h"
 #include "src/x86/utils.h"
 
+#include "src/x86/avx_helper.h"
+
 using namespace megdnn;
 using namespace x86;
 
@@ -65,6 +67,7 @@ PoolingImpl::AlgoPack::AlgoPack() {
     all_algos.push_back(&algo_mean_w2s2_sse3);
     all_algos.push_back(&algo_max_w2s2_sse);
     all_algos.push_back(&algo_max_w3s3_sse);
+    all_algos.push_back(&algo_max_w13s1_nchw88_avx);
 #if MEGDNN_X86_WITH_MKL_DNN
     all_algos.push_back(&algo_mkldnn_nchw);
     all_algos.push_back(&algo_mkldnn_nchw88);
@@ -362,4 +365,136 @@ void PoolingImpl::AlgoMKLDNNNCHW88::exec(const ExecArgs& args) const {
 
     MEGDNN_DISPATCH_CPU_KERN_OPR(run());
 }
-#endif
\ No newline at end of file
+#endif
+
+namespace {
+MEGDNN_ATTRIBUTE_TARGET("avx")
+void max_pooling_s1_nchw88_avx_kern(const float* src, float* dst, int IH,
+                                    int IW, int OH, int OW, int PH, int PW,
+                                    int WH, int WW) {
+    static float min_float = -std::numeric_limits<float>::max();
+    static int VECSIZE = 8;
+
+    __m256 ymm[16];
+    const float* psrc = src;
+    float* pdst = dst;
+
+    //! process all rows
+    for (int row = 0; row < IH; ++row) {
+        for (int j = 0; j < PW; ++j) {
+            ymm[j] = _mm256_set1_ps(min_float);
+        }
+        int col_end = WW - PW < IW ? WW - PW : IW;
+        for (int j = 0; j < col_end; ++j) {
+            ymm[j + PW] = _mm256_loadu_ps(psrc + j * VECSIZE);
+        }
+        for (int j = col_end + PW; j < WW; ++j) {
+            ymm[j] = _mm256_set1_ps(min_float);
+        }
+
+        int col_next = WW - PW;
+        for (int j = 0; j < OW; ++j) {
+            for (int i = WW - 2; i >= 0; --i) {
+                ymm[i] = _mm256_max_ps(ymm[i], ymm[i + 1]);
+            }
+            _mm256_storeu_ps(pdst, ymm[0]);
+            pdst += VECSIZE;
+            for (int i = 0; i < WW - 1; ++i) {
+                ymm[i] = ymm[i + 1];
+            }
+            if (col_next < IW) {
+                ymm[WW - 1] = _mm256_loadu_ps(psrc + col_next * VECSIZE);
+                col_next++;
+            } else {
+                ymm[WW - 1] = _mm256_set1_ps(min_float);
+            }
+        }
+        psrc += IW * VECSIZE;
+    }
+
+    //! process all columns
+    float* src1 = dst;
+    for (int col = 0; col < OW; ++col) {
+        for (int j = 0; j < PH; ++j) {
+            ymm[j] = _mm256_set1_ps(min_float);
+        }
+        int row_end = WH - PH < IH ? WH - PH : IH;
+        for (int j = 0; j < row_end; ++j) {
+            ymm[j + PH] = _mm256_loadu_ps(src1 + j * OW * VECSIZE);
+        }
+        for (int j = row_end + PH; j < WH; ++j) {
+            ymm[j] = _mm256_set1_ps(min_float);
+        }
+
+        int row_next = WH - PH;
+        pdst = src1;
+        for (int j = 0; j < OH; ++j) {
+            for (int i = WH - 2; i >= 0; --i) {
+                ymm[i] = _mm256_max_ps(ymm[i], ymm[i + 1]);
+            }
+            _mm256_storeu_ps(pdst, ymm[0]);
+            pdst += OW * VECSIZE;
+            for (int i = 0; i < WH - 1; ++i) {
+                ymm[i] = ymm[i + 1];
+            }
+            if (row_next < IH) {
+                ymm[WH - 1] = _mm256_loadu_ps(src1 + row_next * OW * VECSIZE);
+                row_next++;
+            } else {
+                ymm[WH - 1] = _mm256_set1_ps(min_float);
+            }
+        }
+        src1 += VECSIZE;
+    }
+}
+}  // namespace
+
+bool PoolingImpl::AlgoMaxS1NCHW88AVX::is_available(const SizeArgs& args) const {
+    bool is_dtype_ok = args.layout_src.dtype == dtype::Float32();
+    bool is_mode_ok = args.opr->param().mode == Mode::MAX;
+    bool is_format_ok = args.opr->param().format == Param::Format::NCHW88;
+    bool is_shape_ok = args.opr->param().window_h >= 10 &&
+                       args.opr->param().window_h <= 15 &&
+                       args.opr->param().window_w >= 10 &&
+                       args.opr->param().window_w <= 15;
+    bool is_stride_ok =
+            args.opr->param().stride_h == 1 && args.opr->param().stride_w == 1;
+    //! this condition guarantees that dst's memory is large enough, because
+    //! dst's memory will be used as workspace to store intermediate results.
+    bool is_pad_ok =
+            args.opr->param().pad_h >= args.opr->param().window_h / 2 &&
+            args.opr->param().pad_w >= args.opr->param().window_w / 2;
+    bool is_ins_ok = is_supported(SIMDType::AVX);
+    return is_dtype_ok && is_mode_ok && is_format_ok && is_shape_ok &&
+           is_pad_ok && is_stride_ok && is_ins_ok;
+}
+
+void PoolingImpl::AlgoMaxS1NCHW88AVX::exec(const ExecArgs& args) const {
+    auto handle = args.handle;
+    size_t N = args.layout_src.shape[0];
+    static size_t VECSIZE = 8;
+    size_t PH = args.opr->param().pad_h;
+    size_t PW = args.opr->param().pad_w;
+    size_t WH = args.opr->param().window_h;
+    size_t WW = args.opr->param().window_w;
+    size_t IC = args.layout_src.shape[1];
+    size_t IH = args.layout_src.shape[2];
+    size_t IW = args.layout_src.shape[3];
+    size_t OH = args.layout_dst.shape[2];
+    size_t OW = args.layout_dst.shape[3];
+    float* src_ptr = reinterpret_cast<float*>(args.src_tensor->raw_ptr);
+    float* dst_ptr = reinterpret_cast<float*>(args.dst_tensor->raw_ptr);
+
+    auto run = [IC, src_ptr, dst_ptr, IH, IW, OH, OW, PH, PW, WH, WW](
+                       size_t index, size_t) {
+        size_t n = index / IC;
+        size_t c = index % IC;
+        float* src =
+                src_ptr + n * IH * IW * IC * VECSIZE + IH * IW * c * VECSIZE;
+        float* dst =
+                dst_ptr + n * OH * OW * IC * VECSIZE + OH * OW * c * VECSIZE;
+        max_pooling_s1_nchw88_avx_kern(src, dst, IH, IW, OH, OW, PH, PW, WH,
+                                       WW);
+    };
+    MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(handle, N * IC, run);
+}
diff --git a/dnn/src/x86/pooling/algo.h b/dnn/src/x86/pooling/algo.h
index 8a16f2b8..1bc54b1b 100644
--- a/dnn/src/x86/pooling/algo.h
+++ b/dnn/src/x86/pooling/algo.h
@@ -29,6 +29,7 @@ public:
         X86_MeanW2S2SSE3,
         X86_MaxW2S2SSE,
         X86_MaxW3S3SSE,
+        X86_MaxS1NCHW88AVX,
 #if MEGDNN_X86_WITH_MKL_DNN
         X86_MKLDNNNCHW,
         X86_MKLDNNNCHW88,
@@ -87,11 +88,11 @@ ALGO_IMPL(MeanW2S2AVX)
 ALGO_IMPL(MeanW2S2SSE3)
 ALGO_IMPL(MaxW2S2SSE)
 ALGO_IMPL(MaxW3S3SSE)
+ALGO_IMPL(MaxS1NCHW88AVX)
 #if MEGDNN_X86_WITH_MKL_DNN
 ALGO_IMPL(MKLDNNNCHW)
 ALGO_IMPL(MKLDNNNCHW88)
 #endif
-
 #undef ALGO_IMPL
 
 class PoolingImpl::AlgoFallback final : public AlgoBase {
@@ -118,6 +119,7 @@ private:
     AlgoMKLDNNNCHW algo_mkldnn_nchw;
     AlgoMKLDNNNCHW88 algo_mkldnn_nchw88;
 #endif
+    AlgoMaxS1NCHW88AVX algo_max_w13s1_nchw88_avx;
     AlgoFallback algo_fallback;
 
 public:
diff --git a/dnn/src/x86/pooling/opr_impl.h b/dnn/src/x86/pooling/opr_impl.h
index 764c9895..6d0a43c2 100644
--- a/dnn/src/x86/pooling/opr_impl.h
+++ b/dnn/src/x86/pooling/opr_impl.h
@@ -21,6 +21,7 @@ private:
     class AlgoMeanW2S2SSE3;
     class AlgoMaxW2S2SSE;
    class AlgoMaxW3S3SSE;
+    class AlgoMaxS1NCHW88AVX;
 #if MEGDNN_X86_WITH_MKL_DNN
     class AlgoMKLDNNNCHW;
     class AlgoMKLDNNNCHW88;
diff --git a/dnn/test/x86/pooling.cpp b/dnn/test/x86/pooling.cpp
index 3abfd6f7..fa695d88 100644
--- a/dnn/test/x86/pooling.cpp
+++ b/dnn/test/x86/pooling.cpp
@@ -24,6 +24,70 @@ TEST_F(X86, POOLING) {
     }
 }
 
+TEST_F(X86, S1POOLING88) {
+    Checker<Pooling> checker(handle());
+    auto run = [&](size_t WH, size_t WW, size_t PH, size_t PW, size_t SH,
+                   size_t SW, size_t N, size_t C, size_t H, size_t W) {
+        Pooling::Param param;
+        param.format = param::Pooling::Format::NCHW88;
+        param.window_h = WH;
+        param.window_w = WW;
+        param.pad_h = PH;
+        param.pad_w = PW;
+        param.stride_w = SW;
+        param.stride_h = SH;
+        param.mode = param::Pooling::Mode::MAX;
+        checker.set_param(param);
+        checker.execs({{N, C, H, W, 8}, {}});
+    };
+
+    for (size_t wh = 10; wh < 15; ++wh) {
+        for (size_t ww = 10; ww < 15; ++ww) {
+            for (size_t n : {1, 2, 4}) {
+                for (size_t c : {1, 4}) {
+                    for (size_t h : {10, 13, 20}) {
+                        for (size_t w : {10, 13, 20}) {
+                            run(wh, ww, wh / 2, ww / 2, 1, 1, n, c, h, w);
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+TEST_F(X86_MULTI_THREADS, S1POOLING88) {
+    Checker<Pooling> checker(handle());
+    auto run = [&](size_t WH, size_t WW, size_t PH, size_t PW, size_t SH,
+                   size_t SW, size_t N, size_t C, size_t H, size_t W) {
+        Pooling::Param param;
+        param.format = param::Pooling::Format::NCHW88;
+        param.window_h = WH;
+        param.window_w = WW;
+        param.pad_h = PH;
+        param.pad_w = PW;
+        param.stride_w = SW;
+        param.stride_h = SH;
+        param.mode = param::Pooling::Mode::MAX;
+        checker.set_param(param);
+        checker.execs({{N, C, H, W, 8}, {}});
+    };
+
+    for (size_t wh = 10; wh < 15; ++wh) {
+        for (size_t ww = 10; ww < 15; ++ww) {
+            for (size_t n : {1, 2, 4}) {
+                for (size_t c : {1, 4}) {
+                    for (size_t h : {10, 13, 20}) {
+                        for (size_t w : {10, 13, 20}) {
+                            run(wh, ww, wh / 2, ww / 2, 1, 1, n, c, h, w);
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
 #if MEGDNN_X86_WITH_MKL_DNN
 TEST_F(X86, POOLING88) {
     Checker<Pooling> checker(handle());
@@ -104,6 +168,42 @@ TEST_F(X86, BENCHMARK_POOLING) {
     test_x86_megdnn_pooling(handle());
 }
 TEST_F(X86_MULTI_THREADS, BENCHMARK_POOLING) {
     test_x86_megdnn_pooling(handle());
 }
+TEST_F(X86, BENCHMARK_POOLING_MAX_S1_NCHW88) {
+    constexpr size_t RUNS = 50;
+    auto x86_handle = handle();
+    Benchmarker<Pooling> benchmarker_pooling(x86_handle);
+    benchmarker_pooling.set_times(RUNS);
+    auto run = [&](uint32_t pad, uint32_t stride, uint32_t window_size,
+                   size_t in_number, size_t in_channel, size_t in_height,
+                   size_t in_width) {
+        auto opr = x86_handle->create_operator<Pooling>();
+        opr->param() = {param::Pooling::Mode::MAX,
+                        pad,
+                        pad,
+                        stride,
+                        stride,
+                        window_size,
+                        window_size};
+        opr->param().format = param::Pooling::Format::NCHW88;
+
+        TensorShape shape{in_number, in_channel / 8, in_height, in_width, 8};
+        TensorLayout dst_layout;
+        opr->deduce_layout({shape, dtype::Float32()}, dst_layout);
+        float computation =
+                dst_layout.total_nr_elems() * window_size * window_size * 1e-9;
+
+        auto pooling_used = benchmarker_pooling.set_param(opr->param())
+                                    .exec(TensorShapeArray{shape, {}}) /
+                            RUNS;
+        float through_put = computation / pooling_used * 1e3;
+        printf("profiling max pooling NCHW88 {%zu,%zu,%zu,%zu,8}\nuse time : "
+               "%f ms\nthrough_put : %f Gflops\n",
+               in_number, in_channel / 8, in_height, in_width, pooling_used,
+               through_put);
+    };
+    run(6, 1, 13, 1, 32 * 8, 20, 20);
+}
+
 #endif
 #if MEGDNN_X86_WITH_MKL_DNN
 TEST_F(X86, POOLING_INT8) {
-- 
GitLab
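
Note for readers (not part of the patch itself): the new kernel computes the stride-1 max pooling in two separable passes, a horizontal sliding max over every input row written into dst, followed by a vertical sliding max over those intermediate rows. Reusing dst as the intermediate buffer is safe because is_available requires pad >= window / 2, which for stride 1 implies OH >= IH and OW >= IW. The sketch below shows the same scheme in plain scalar C++ for a single 2D float map; the function name and the plain 2D layout are illustrative assumptions only, while the actual kernel applies the identical data flow to 8-lane NCHW88 blocks held in __m256 registers.

// Scalar reference for the separable two-pass sliding-window max used by
// max_pooling_s1_nchw88_avx_kern, written for one plain IH x IW float map.
#include <algorithm>
#include <limits>
#include <vector>

void max_pool_s1_reference(const float* src, float* dst, int IH, int IW,
                           int OH, int OW, int PH, int PW, int WH, int WW) {
    const float kMin = -std::numeric_limits<float>::max();
    // Pass 1: horizontal sliding max of width WW for every input row.
    // The IH x OW intermediate fits inside dst because pad >= window / 2
    // gives OH >= IH and OW >= IW when the stride is 1.
    for (int r = 0; r < IH; ++r) {
        for (int oc = 0; oc < OW; ++oc) {
            float m = kMin;
            for (int k = 0; k < WW; ++k) {
                int ic = oc - PW + k;  // stride 1, left padding PW
                if (ic >= 0 && ic < IW)
                    m = std::max(m, src[r * IW + ic]);
            }
            dst[r * OW + oc] = m;
        }
    }
    // Pass 2: vertical sliding max of height WH over the intermediate rows.
    // The column is buffered here for clarity; the AVX kernel instead keeps
    // the active WH rows in ymm registers so it can overwrite dst in place.
    for (int oc = 0; oc < OW; ++oc) {
        std::vector<float> col(IH);
        for (int r = 0; r < IH; ++r)
            col[r] = dst[r * OW + oc];
        for (int orow = 0; orow < OH; ++orow) {
            float m = kMin;
            for (int k = 0; k < WH; ++k) {
                int ir = orow - PH + k;
                if (ir >= 0 && ir < IH)
                    m = std::max(m, col[ir]);
            }
            dst[orow * OW + oc] = m;
        }
    }
}

The S1POOLING88 tests above exercise exactly this window/pad relationship (pad = window / 2, stride 1) across several shapes, and the benchmark runs the same configuration with window 13, pad 6 on a 20x20 input.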