Commit 10bcf757 authored by Megvii Engine Team

feat(dnn/x86): add x86 max-pooling algo for window sizes from 10 to 15 with stride 1 under NCHW88

GitOrigin-RevId: 613a18dd916f575fcd29448c2048dde9dae70d1e
Parent ddba5c96
...@@ -20,6 +20,8 @@
#include "src/x86/pooling/pooling_special_cases.h"
#include "src/x86/utils.h"
#include "src/x86/avx_helper.h"
using namespace megdnn;
using namespace x86;
...@@ -65,6 +67,7 @@ PoolingImpl::AlgoPack::AlgoPack() {
all_algos.push_back(&algo_mean_w2s2_sse3);
all_algos.push_back(&algo_max_w2s2_sse);
all_algos.push_back(&algo_max_w3s3_sse);
all_algos.push_back(&algo_max_w13s1_nchw88_avx);
#if MEGDNN_X86_WITH_MKL_DNN
all_algos.push_back(&algo_mkldnn_nchw);
all_algos.push_back(&algo_mkldnn_nchw88);
...@@ -362,4 +365,136 @@ void PoolingImpl::AlgoMKLDNNNCHW88::exec(const ExecArgs& args) const {
MEGDNN_DISPATCH_CPU_KERN_OPR(run());
}
#endif
\ No newline at end of file
namespace {
MEGDNN_ATTRIBUTE_TARGET("avx")
void max_pooling_s1_nchw88_avx_kern(const float* src, float* dst, int IH,
int IW, int OH, int OW, int PH, int PW,
int WH, int WW) {
static float min_float = -std::numeric_limits<float>::max();
static int VECSIZE = 8;
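    //! ymm[] buffers one sliding window of 8-lane channel vectors; the
    //! window_h/window_w <= 15 restriction keeps every index within 16 slots.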
__m256 ymm[16];
const float* psrc = src;
float* pdst = dst;
    //! row pass: horizontal sliding-window max along each input row
for (int row = 0; row < IH; ++row) {
for (int j = 0; j < PW; ++j) {
ymm[j] = _mm256_set1_ps(min_float);
}
int col_end = WW - PW < IW ? WW - PW : IW;
for (int j = 0; j < col_end; ++j) {
ymm[j + PW] = _mm256_loadu_ps(psrc + j * VECSIZE);
}
for (int j = col_end + PW; j < WW; ++j) {
ymm[j] = _mm256_set1_ps(min_float);
}
int col_next = WW - PW;
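        //! Sliding-window max: the backward cascade below leaves the max of
        //! the current WW-column window in ymm[0]; the window then shifts
        //! left by one and the next column (or min_float padding) is appended.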
for (int j = 0; j < OW; ++j) {
for (int i = WW - 2; i >= 0; --i) {
ymm[i] = _mm256_max_ps(ymm[i], ymm[i + 1]);
}
_mm256_storeu_ps(pdst, ymm[0]);
pdst += VECSIZE;
for (int i = 0; i < WW - 1; ++i) {
ymm[i] = ymm[i + 1];
}
if (col_next < IW) {
ymm[WW - 1] = _mm256_loadu_ps(psrc + col_next * VECSIZE);
col_next++;
} else {
ymm[WW - 1] = _mm256_set1_ps(min_float);
}
}
psrc += IW * VECSIZE;
}
    //! column pass: vertical sliding-window max over the row-pass result
float* src1 = dst;
for (int col = 0; col < OW; ++col) {
for (int j = 0; j < PH; ++j) {
ymm[j] = _mm256_set1_ps(min_float);
}
int row_end = WH - PH < IH ? WH - PH : IH;
for (int j = 0; j < row_end; ++j) {
ymm[j + PH] = _mm256_loadu_ps(src1 + j * OW * VECSIZE);
}
for (int j = row_end + PH; j < WH; ++j) {
ymm[j] = _mm256_set1_ps(min_float);
}
int row_next = WH - PH;
pdst = src1;
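        //! Same cascade as the row pass; keeping the in-flight window in ymm
        //! registers makes it safe to overwrite rows of dst that the window
        //! has already consumed, which is what lets dst double as workspace.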
for (int j = 0; j < OH; ++j) {
for (int i = WH - 2; i >= 0; --i) {
ymm[i] = _mm256_max_ps(ymm[i], ymm[i + 1]);
}
_mm256_storeu_ps(pdst, ymm[0]);
pdst += OW * VECSIZE;
for (int i = 0; i < WH - 1; ++i) {
ymm[i] = ymm[i + 1];
}
if (row_next < IH) {
ymm[WH - 1] = _mm256_loadu_ps(src1 + row_next * OW * VECSIZE);
row_next++;
} else {
ymm[WH - 1] = _mm256_set1_ps(min_float);
}
}
src1 += VECSIZE;
}
}
} // namespace
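For orientation: with stride 1, max pooling is separable, so a WH x WW window
max equals a 1 x WW row max followed by a WH x 1 column max over the row
result, which is exactly the two passes in the kernel above. Below is a
minimal scalar sketch of that decomposition for a single channel;
max_pool_s1_ref is a hypothetical reference helper, not part of this commit,
and it uses an explicit temporary buffer instead of reusing dst as workspace
the way the AVX kernel does.

#include <algorithm>
#include <cfloat>
#include <vector>

static void max_pool_s1_ref(const float* src, float* dst, int IH, int IW,
                            int OH, int OW, int PH, int PW, int WH, int WW) {
    std::vector<float> tmp(IH * OW);
    //! row pass: 1 x WW max along each input row, padding reads as -FLT_MAX
    for (int r = 0; r < IH; ++r)
        for (int c = 0; c < OW; ++c) {
            float m = -FLT_MAX;
            for (int k = 0; k < WW; ++k) {
                int ic = c - PW + k;  // input column, may fall in padding
                if (ic >= 0 && ic < IW) m = std::max(m, src[r * IW + ic]);
            }
            tmp[r * OW + c] = m;
        }
    //! column pass: WH x 1 max over the row-pass result
    for (int c = 0; c < OW; ++c)
        for (int r = 0; r < OH; ++r) {
            float m = -FLT_MAX;
            for (int k = 0; k < WH; ++k) {
                int ir = r - PH + k;  // row of the row-pass result
                if (ir >= 0 && ir < IH) m = std::max(m, tmp[ir * OW + c]);
            }
            dst[r * OW + c] = m;
        }
}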
bool PoolingImpl::AlgoMaxS1NCHW88AVX::is_available(const SizeArgs& args) const {
bool is_dtype_ok = args.layout_src.dtype == dtype::Float32();
bool is_mode_ok = args.opr->param().mode == Mode::MAX;
bool is_format_ok = args.opr->param().format == Param::Format::NCHW88;
bool is_shape_ok = args.opr->param().window_h >= 10 &&
args.opr->param().window_h <= 15 &&
args.opr->param().window_w >= 10 &&
args.opr->param().window_w <= 15;
bool is_stride_ok =
args.opr->param().stride_h == 1 && args.opr->param().stride_w == 1;
    //! These conditions guarantee that dst's memory is large enough, because
    //! dst is reused as workspace to store the intermediate row-pass result:
    //! with stride 1, OH = IH + 2 * PH - WH + 1, so pad_h >= window_h / 2
    //! implies OH >= IH, and likewise OW >= IW.
    bool is_pad_ok =
            args.opr->param().pad_h >= args.opr->param().window_h / 2 &&
            args.opr->param().pad_w >= args.opr->param().window_w / 2;
bool is_ins_ok = is_supported(SIMDType::AVX);
return is_dtype_ok && is_mode_ok && is_format_ok && is_shape_ok &&
is_pad_ok && is_stride_ok && is_ins_ok;
}
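The pad requirement above is what makes the workspace trick legal: with
stride 1 the output plane is at least as large as the input plane, so dst can
hold the IH x OW row-pass result. A quick sketch of that arithmetic
(dst_holds_row_pass is a hypothetical helper, not in the commit):

static bool dst_holds_row_pass(int IH, int IW, int PH, int PW, int WH, int WW) {
    int OH = IH + 2 * PH - WH + 1;  // stride-1 output height
    int OW = IW + 2 * PW - WW + 1;  // stride-1 output width
    // PH >= WH / 2 gives 2 * PH >= WH - 1, hence OH >= IH; likewise for OW.
    return OH >= IH && OW >= IW;
}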
void PoolingImpl::AlgoMaxS1NCHW88AVX::exec(const ExecArgs& args) const {
auto handle = args.handle;
size_t N = args.layout_src.shape[0];
static size_t VECSIZE = 8;
size_t PH = args.opr->param().pad_h;
size_t PW = args.opr->param().pad_w;
size_t WH = args.opr->param().window_h;
size_t WW = args.opr->param().window_w;
size_t IC = args.layout_src.shape[1];
size_t IH = args.layout_src.shape[2];
size_t IW = args.layout_src.shape[3];
size_t OH = args.layout_dst.shape[2];
size_t OW = args.layout_dst.shape[3];
float* src_ptr = reinterpret_cast<float*>(args.src_tensor->raw_ptr);
float* dst_ptr = reinterpret_cast<float*>(args.dst_tensor->raw_ptr);
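    //! Split the work into N * IC independent tasks; each task pools one
    //! 8-channel NCHW88 block, so the kernel never crosses block boundaries.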
auto run = [IC, src_ptr, dst_ptr, IH, IW, OH, OW, PH, PW, WH, WW](
size_t index, size_t) {
size_t n = index / IC;
size_t c = index % IC;
float* src =
src_ptr + n * IH * IW * IC * VECSIZE + IH * IW * c * VECSIZE;
float* dst =
dst_ptr + n * OH * OW * IC * VECSIZE + OH * OW * c * VECSIZE;
max_pooling_s1_nchw88_avx_kern(src, dst, IH, IW, OH, OW, PH, PW, WH,
WW);
};
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN(handle, N * IC, run);
}
...@@ -29,6 +29,7 @@ public:
X86_MeanW2S2SSE3,
X86_MaxW2S2SSE,
X86_MaxW3S3SSE,
X86_MaxS1NCHW88AVX,
#if MEGDNN_X86_WITH_MKL_DNN
X86_MKLDNNNCHW,
X86_MKLDNNNCHW88,
...@@ -87,11 +88,11 @@ ALGO_IMPL(MeanW2S2AVX)
ALGO_IMPL(MeanW2S2SSE3)
ALGO_IMPL(MaxW2S2SSE)
ALGO_IMPL(MaxW3S3SSE)
ALGO_IMPL(MaxS1NCHW88AVX)
#if MEGDNN_X86_WITH_MKL_DNN
ALGO_IMPL(MKLDNNNCHW)
ALGO_IMPL(MKLDNNNCHW88)
#endif
#undef ALGO_IMPL
class PoolingImpl::AlgoFallback final : public AlgoBase {
...@@ -118,6 +119,7 @@ private:
AlgoMKLDNNNCHW algo_mkldnn_nchw;
AlgoMKLDNNNCHW88 algo_mkldnn_nchw88;
#endif
AlgoMaxS1NCHW88AVX algo_max_w13s1_nchw88_avx;
AlgoFallback algo_fallback;
public:
...
...@@ -21,6 +21,7 @@ private:
class AlgoMeanW2S2SSE3;
class AlgoMaxW2S2SSE;
class AlgoMaxW3S3SSE;
class AlgoMaxS1NCHW88AVX;
#if MEGDNN_X86_WITH_MKL_DNN
class AlgoMKLDNNNCHW;
class AlgoMKLDNNNCHW88;
...
...@@ -24,6 +24,70 @@ TEST_F(X86, POOLING) {
}
}
TEST_F(X86, S1POOLING88) {
Checker<Pooling> checker(handle());
auto run = [&](size_t WH, size_t WW, size_t PH, size_t PW, size_t SH,
size_t SW, size_t N, size_t C, size_t H, size_t W) {
Pooling::Param param;
param.format = param::Pooling::Format::NCHW88;
param.window_h = WH;
param.window_w = WW;
param.pad_h = PH;
param.pad_w = PW;
param.stride_w = SW;
param.stride_h = SH;
param.mode = param::Pooling::Mode::MAX;
checker.set_param(param);
checker.execs({{N, C, H, W, 8}, {}});
};
for (size_t wh = 10; wh < 15; ++wh) {
for (size_t ww = 10; ww < 15; ++ww) {
for (size_t n : {1, 2, 4}) {
for (size_t c : {1, 4}) {
for (size_t h : {10, 13, 20}) {
for (size_t w : {10, 13, 20}) {
run(wh, ww, wh / 2, ww / 2, 1, 1, n, c, h, w);
}
}
}
}
}
}
}
TEST_F(X86_MULTI_THREADS, S1POOLING88) {
Checker<Pooling> checker(handle());
auto run = [&](size_t WH, size_t WW, size_t PH, size_t PW, size_t SH,
size_t SW, size_t N, size_t C, size_t H, size_t W) {
Pooling::Param param;
param.format = param::Pooling::Format::NCHW88;
param.window_h = WH;
param.window_w = WW;
param.pad_h = PH;
param.pad_w = PW;
param.stride_w = SW;
param.stride_h = SH;
param.mode = param::Pooling::Mode::MAX;
checker.set_param(param);
checker.execs({{N, C, H, W, 8}, {}});
};
for (size_t wh = 10; wh < 15; ++wh) {
for (size_t ww = 10; ww < 15; ++ww) {
for (size_t n : {1, 2, 4}) {
for (size_t c : {1, 4}) {
for (size_t h : {10, 13, 20}) {
for (size_t w : {10, 13, 20}) {
run(wh, ww, wh / 2, ww / 2, 1, 1, n, c, h, w);
}
}
}
}
}
}
}
#if MEGDNN_X86_WITH_MKL_DNN
TEST_F(X86, POOLING88) {
Checker<Pooling> checker(handle());
...@@ -104,6 +168,42 @@ TEST_F(X86, BENCHMARK_POOLING) {
TEST_F(X86_MULTI_THREADS, BENCHMARK_POOLING) {
test_x86_megdnn_pooling(handle());
}
TEST_F(X86, BENCHMARK_POOLING_MAX_S1_NCHW88) {
constexpr size_t RUNS = 50;
auto x86_handle = handle();
Benchmarker<Pooling> benchmarker_pooling(x86_handle);
benchmarker_pooling.set_times(RUNS);
auto run = [&](uint32_t pad, uint32_t stride, uint32_t window_size,
size_t in_number, size_t in_channel, size_t in_height,
size_t in_width) {
auto opr = x86_handle->create_operator<Pooling>();
opr->param() = {param::Pooling::Mode::MAX,
pad,
pad,
stride,
stride,
window_size,
window_size};
opr->param().format = param::Pooling::Format::NCHW88;
TensorShape shape{in_number, in_channel / 8, in_height, in_width, 8};
TensorLayout dst_layout;
opr->deduce_layout({shape, dtype::Float32()}, dst_layout);
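        // count one comparison per window element per output element (Gflop)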
float computation =
dst_layout.total_nr_elems() * window_size * window_size * 1e-9;
auto pooling_used = benchmarker_pooling.set_param(opr->param())
.exec(TensorShapeArray{shape, {}}) /
RUNS;
        float throughput = computation / pooling_used * 1e3;
        printf("profiling max pooling NCHW88 {%zu,%zu,%zu,%zu,8}\nuse time : "
               "%f ms\nthroughput : %f Gflops\n",
               in_number, in_channel / 8, in_height, in_width, pooling_used,
               throughput);
};
run(6, 1, 13, 1, 32 * 8, 20, 20);
}
#endif
#if MEGDNN_X86_WITH_MKL_DNN
TEST_F(X86, POOLING_INT8) {
...