提交 ab401aba 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(dnn/arm): add arm nchw44 filter2x2 stride1 and stride2 max pooling

GitOrigin-RevId: 42d144a8139de203d87f1d5753487e1020b14dca
上级 b336db65
......@@ -11,9 +11,10 @@
*/
#include "src/arm_common/pooling/algo.h"
#include "megdnn/opr_param_defs.h"
#include "src/arm_common/pooling/do_max_pooling_2x2_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_int8.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s2x2_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_3x3_s1x1_nchw44.h"
#include "src/arm_common/pooling/do_max_pooling_w2x2_s2x2.h"
#include "src/arm_common/pooling/do_max_pooling_w4x4_s2x2.h"
......@@ -666,6 +667,75 @@ void PoolingImpl::AlgoFilter3MaxStride1NCHW44::exec(
#undef DISPATCH_FUNC
}
/**
 * Report whether this algorithm can serve the given pooling problem.
 *
 * Accepted problems are QuantizedS8 input in NCHW44 format, MAX mode, a 2x2
 * window, equal strides of 1 or 2 in both directions, and zero padding.
 */
bool PoolingImpl::AlgoFilter2MaxStridexNCHW44::usable(
        const PoolingKernSizeParam& param) const {
    // Data type / layout / mode requirements.
    if (param.src_type.enumv() != DTypeEnum::QuantizedS8) {
        return false;
    }
    if (param.format != Param::Format::NCHW44 || param.mode != Mode::MAX) {
        return false;
    }

    // Only a square 2x2 window is implemented.
    if (param.filter[0] != 2 || param.filter[1] != 2) {
        return false;
    }

    // Strides must agree in both directions and be either 1 or 2.
    const auto stride_h = param.stride[0];
    const auto stride_w = param.stride[1];
    if (stride_h != stride_w || !(stride_w == 1 || stride_w == 2)) {
        return false;
    }

    // The kernels do not handle padding.
    return param.padding[0] == 0 && param.padding[1] == 0;
}
/**
 * Run 2x2 max pooling by dispatching one kernel call per (batch, channel)
 * index pair.
 *
 * The work is split into N * C independent items; item `index` maps to
 * n = index / C and c = index % C. Each item processes IH * IW * 4 source
 * elements and produces OH * OW * 4 destination elements.
 * NOTE(review): the factor 4 is presumably the NCHW44 channel packing and C
 * the number of packed channel groups - confirm against PoolingKernParam.
 */
void PoolingImpl::AlgoFilter2MaxStridexNCHW44::exec(
const PoolingKernParam& param) const {
auto IH = param.isz[0], IW = param.isz[1];
auto OH = param.osz[0], OW = param.osz[1];
auto N = param.n, C = param.ic;
auto PH = param.padding[0];
auto PW = param.padding[1];
// Read from stride[0] although named SW; usable() only accepts SH == SW, so
// either index yields the same value.
auto SW = param.stride[0];
void* src_ptr = param.src_ptr;
void* dst_ptr = param.dst_ptr;
// DISPATCH_FUNC(type, func, midout_type_id, i): builds a per-index lambda that
// calls do_max_pooling_2x2_stride<i>_<func>_nchw44_NEON on the slice belonging
// to (n, c), then hands the lambda and the item count N * C to
// MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN. The whole thing is wrapped in a
// MIDOUT region keyed by midout_type_id.
#define DISPATCH_FUNC(type, func, midout_type_id, i) \
MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), \
midout_iv(midout_type_id)) { \
auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr]( \
size_t index, size_t thread_id) { \
MEGDNN_MARK_USED_VAR(thread_id); \
size_t n = index / C; \
size_t c = index % C; \
do_max_pooling_2x2_stride##i##_##func##_nchw44_NEON( \
static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
c * IH * IW * 4, \
static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
c * OH * OW * 4, \
IH, IW, OH, OW, PH, PW); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, \
run); \
} \
MIDOUT_END();
// DISPATCH_STRIDE(type, func, midout_type_id): selects the stride-1 or
// stride-2 kernel at runtime; any other stride trips the assertion.
#define DISPATCH_STRIDE(type, func, midout_type_id) \
switch (SW) { \
case 1: { \
DISPATCH_FUNC(type, func, midout_type_id, 1); \
break; \
} \
case 2: { \
DISPATCH_FUNC(type, func, midout_type_id, 2); \
break; \
} \
default: \
megdnn_assert(0, "unsupport stride size"); \
}
// Only the int8 kernels exist for this algorithm; 10 is the MIDOUT type id.
DISPATCH_STRIDE(int8_t, int8, 10);
#undef DISPATCH_STRIDE
#undef DISPATCH_FUNC
}
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
......
......@@ -99,6 +99,14 @@ public:
void exec(const PoolingKernParam& param) const override;
};
//! Max pooling algorithm for a 2x2 window with stride 1 or 2 on NCHW44 data.
class PoolingImpl::AlgoFilter2MaxStridexNCHW44 final : public AlgoBase {
public:
    //! Identifier under which the algorithm is reported.
    const char* name() const override {
        return "ARM_POOLING_FILTER2_MAX_STRIDEX_NCHW44";
    }

    //! This algorithm always reports itself as reproducible.
    bool is_reproducible() const override {
        return true;
    }

    //! Run the pooling described by \p param.
    void exec(const PoolingKernParam& param) const override;

    //! Tell whether the problem described by \p param is supported.
    bool usable(const PoolingKernSizeParam& param) const override;
};
WorkspaceBundle get_bundle(const PoolingImpl::PoolingKernSizeParam& param);
} // namespace arm_common
......
/**
* \file dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/arm_common/pooling/do_max_pooling_2x2_nchw44.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/unroll_macro.h"
namespace megdnn {
namespace arm_common {
/**
 * \brief 2x2 max pooling with stride 1 over one block of packed int8 data.
 *
 * Each pixel occupies 4 consecutive int8 values; input rows hold IW pixels
 * and output rows hold OW pixels. Output pixel (oh, ow) is the element-wise
 * maximum of input pixels (oh, ow), (oh, ow + 1), (oh + 1, ow) and
 * (oh + 1, ow + 1).
 *
 * \param src first element of the input block
 * \param dst first element of the output block
 * \param IH input height in pixels (not read by this kernel)
 * \param IW input width in pixels
 * \param OH output height in pixels
 * \param OW output width in pixels
 * \param PH vertical padding (not read; no padding handling is done)
 * \param PW horizontal padding (not read; no padding handling is done)
 */
void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
                                                 size_t IH, size_t IW,
                                                 size_t OH, size_t OW,
                                                 size_t PH, size_t PW) {
    // Part of the common kernel signature but unused here.
    (void)IH;
    (void)PH;
    (void)PW;

    size_t oh = 0;
    for (; oh < OH; ++oh) {
        size_t ih = oh;
        const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4;
        const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4;
        int8_t* __restrict dptr = dst + oh * OW * 4;
        size_t ow = 0;
        // Main loop: four output pixels (16 bytes) per iteration. The second
        // load is shifted by one pixel, so it touches pixels ow+1..ow+4, which
        // stay inside the row because ow + 3 < OW.
        for (; ow + 3 < OW; ow += 4) {
            int8x16_t src0123 = vld1q_s8(sptr0);
            int8x16_t src1234 = vld1q_s8(sptr0 + 4);
            int8x16_t max0 = vmaxq_s8(src0123, src1234);
            src0123 = vld1q_s8(sptr1);
            src1234 = vld1q_s8(sptr1 + 4);
            int8x16_t max1 = vmaxq_s8(src0123, src1234);
            int8x16_t max_out = vmaxq_s8(max0, max1);
            vst1q_s8(dptr, max_out);
            sptr0 += 16;
            sptr1 += 16;
            dptr += 16;
        }
        // Tail: one output pixel per iteration. Only pixels ow and ow+1 of
        // each row are loaded (8 bytes), so nothing beyond the window is
        // read; a second load at +4 would run 4 bytes past the last row.
        for (; ow < OW; ++ow) {
            int8x8_t rows_max = vmax_s8(vld1_s8(sptr0), vld1_s8(sptr1));
            // Bring pixel ow+1 (upper 4 lanes) down onto pixel ow and reduce.
            int8x8_t mat_out =
                    vmax_s8(rows_max, vext_s8(rows_max, rows_max, 4));
#define store(i) *(dptr + i) = mat_out[i];
            UNROLL_CALL_NOWRAPPER(4, store)
#undef store
            sptr0 += 4;
            sptr1 += 4;
            dptr += 4;
        }
    }
}
/**
 * \brief 2x2 max pooling with stride 2 over one block of packed int8 data.
 *
 * Each pixel occupies 4 consecutive int8 values; input rows hold IW pixels
 * and output rows hold OW pixels. Output pixel (oh, ow) is the element-wise
 * maximum of input pixels (2*oh, 2*ow), (2*oh, 2*ow + 1), (2*oh + 1, 2*ow)
 * and (2*oh + 1, 2*ow + 1).
 *
 * \param src first element of the input block
 * \param dst first element of the output block
 * \param IH input height in pixels (not read by this kernel)
 * \param IW input width in pixels
 * \param OH output height in pixels
 * \param OW output width in pixels
 * \param PH vertical padding (not read; no padding handling is done)
 * \param PW horizontal padding (not read; no padding handling is done)
 */
void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
                                                 size_t IH, size_t IW,
                                                 size_t OH, size_t OW,
                                                 size_t PH, size_t PW) {
    // Part of the common kernel signature but unused here.
    (void)IH;
    (void)PH;
    (void)PW;

    size_t oh = 0;
    for (; oh < OH; ++oh) {
        size_t ih = oh << 1;
        const int8_t* __restrict sptr0 = src + (ih + 0) * IW * 4;
        const int8_t* __restrict sptr1 = src + (ih + 1) * IW * 4;
        int8_t* __restrict dptr = dst + oh * OW * 4;
        size_t ow = 0;
        // Main loop: eight input pixels per row yield four output pixels.
        // Treating every pixel as one 32-bit lane, vuzpq_s32 separates the
        // even-indexed pixels from the odd-indexed ones so that each window's
        // two columns end up in matching lanes.
        for (; ow + 3 < OW; ow += 4) {
            int8x16_t src00 = vld1q_s8(sptr0);
            int8x16_t src04 = vld1q_s8(sptr0 + 4 * 4);
            int32x4x2_t src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00),
                                            vreinterpretq_s32_s8(src04));
            int32x4_t src0246 = src_tmp.val[0];
            int32x4_t src1357 = src_tmp.val[1];
            int8x16_t max0 = vmaxq_s8(vreinterpretq_s8_s32(src0246),
                                      vreinterpretq_s8_s32(src1357));
            src00 = vld1q_s8(sptr1);
            src04 = vld1q_s8(sptr1 + 4 * 4);
            src_tmp = vuzpq_s32(vreinterpretq_s32_s8(src00),
                                vreinterpretq_s32_s8(src04));
            src0246 = src_tmp.val[0];
            src1357 = src_tmp.val[1];
            int8x16_t max1 = vmaxq_s8(vreinterpretq_s8_s32(src0246),
                                      vreinterpretq_s8_s32(src1357));
            int8x16_t max_out = vmaxq_s8(max0, max1);
            vst1q_s8(dptr, max_out);
            sptr0 += 32;
            sptr1 += 32;
            dptr += 16;
        }
        // Tail: one output pixel per iteration. Only the two pixels of the
        // window are loaded from each row (8 bytes); a second load at +4
        // would read pixel 2*ow+2, which lies past the row when IW is even.
        for (; ow < OW; ++ow) {
            int8x8_t rows_max = vmax_s8(vld1_s8(sptr0), vld1_s8(sptr1));
            // Bring the odd pixel (upper 4 lanes) down onto the even one.
            int8x8_t mat_out =
                    vmax_s8(rows_max, vext_s8(rows_max, rows_max, 4));
#define store(i) *(dptr + i) = mat_out[i];
            UNROLL_CALL_NOWRAPPER(4, store)
#undef store
            sptr0 += 8;
            sptr1 += 8;
            dptr += 4;
        }
    }
}
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/arm_common/pooling/do_max_pooling_2x2_nchw44.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "src/common/utils.h"
namespace megdnn {
namespace arm_common {
/**
 * \brief 2x2 max pooling with stride 1 on one block of int8 data in which
 * every pixel occupies 4 consecutive values.
 *
 * \param src input block; rows are IW * 4 elements apart
 * \param dst output block; rows are OW * 4 elements apart
 * \param IH input height in pixels
 * \param IW input width in pixels
 * \param OH output height in pixels
 * \param OW output width in pixels
 * \param PH vertical padding
 * \param PW horizontal padding
 *
 * NOTE(review): the implementation does no padding handling, so callers are
 * expected to pass PH == PW == 0 - confirm at call sites.
 */
void do_max_pooling_2x2_stride1_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW,
size_t OH, size_t OW,
size_t PH, size_t PW);
/**
 * \brief 2x2 max pooling with stride 2 on one block of int8 data in which
 * every pixel occupies 4 consecutive values.
 *
 * \param src input block; rows are IW * 4 elements apart
 * \param dst output block; rows are OW * 4 elements apart
 * \param IH input height in pixels
 * \param IW input width in pixels
 * \param OH output height in pixels
 * \param OW output width in pixels
 * \param PH vertical padding
 * \param PW horizontal padding
 *
 * NOTE(review): the implementation does no padding handling, so callers are
 * expected to pass PH == PW == 0 - confirm at call sites.
 */
void do_max_pooling_2x2_stride2_int8_nchw44_NEON(const int8_t* src, int8_t* dst,
size_t IH, size_t IW,
size_t OH, size_t OW,
size_t PH, size_t PW);
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -27,6 +27,7 @@ class PoolingImpl::AlgoPack : NonCopyableObj {
AlgoInt8Filter3MaxStride2 algo_int8_filter3_max_stride2;
AlgoFilter3MaxStride2NCHW44 algo_filter3_max_stride2_nchw4;
AlgoFilter3MaxStride1NCHW44 algo_filter3_max_stride1_nchw4;
AlgoFilter2MaxStridexNCHW44 algo_filter2_max_stridex_nchw4;
public:
AlgoPack() {
......@@ -40,6 +41,7 @@ public:
all_algos.emplace_back(&algo_int8_filter3_max_stride2);
all_algos.emplace_back(&algo_filter3_max_stride2_nchw4);
all_algos.emplace_back(&algo_filter3_max_stride1_nchw4);
all_algos.emplace_back(&algo_filter2_max_stridex_nchw4);
}
SmallVector<AlgoBase*> all_algos;
};
......
......@@ -85,6 +85,7 @@ private:
class AlgoInt8Filter3MaxStride2;
class AlgoFilter3MaxStride2NCHW44;
class AlgoFilter3MaxStride1NCHW44;
class AlgoFilter2MaxStridexNCHW44;
class AlgoPack;
};
} // namespace arm_common
......
......@@ -154,6 +154,57 @@ TEST_F(ARM_COMMON, POOLING_MAX_W3x3_S1x1_NCHW44)
// clang-format on
}
// Checks 2x2 / stride-1 max pooling on QuantizedS8 NCHW44 tensors over a grid
// of input sizes, without padding.
TEST_F(ARM_COMMON, POOLING_MAX_W2x2_S1x1_NCHW44) {
    // Padding is fixed to zero for this configuration.
    const size_t ph = 0;
    const size_t pw = 0;

    for (size_t ih : {2, 5, 10, 17}) {
        for (size_t iw : {2, 6, 8, 16, 26}) {
            // Skip inputs smaller than the pooling window.
            if (ih + 2 * ph < 2 || iw + 2 * pw < 2) {
                continue;
            }

            param::Pooling param;
            param.mode = param::Pooling::Mode::MAX;
            param.format = param::Pooling::Format::NCHW44;
            param.pad_h = ph;
            param.pad_w = pw;
            param.stride_h = param.stride_w = 1;
            param.window_h = param.window_w = 2;

            UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
            Checker<Pooling> checker(handle());
            checker.set_dtype(0, dtype::QuantizedS8(1.1f));
            checker.set_rng(0, &rng);
            checker.set_param(param).exec(
                    TensorShapeArray{{2, 2, ih, iw, 4}, {}});
        }
    }
}
// Checks 2x2 / stride-2 max pooling on QuantizedS8 NCHW44 tensors over a grid
// of input sizes, without padding.
TEST_F(ARM_COMMON, POOLING_MAX_W2x2_S2x2_NCHW44) {
    // Padding is fixed to zero for this configuration.
    const size_t ph = 0;
    const size_t pw = 0;

    for (size_t ih : {2, 5, 10, 17}) {
        for (size_t iw : {2, 6, 8, 16, 26}) {
            // Skip inputs smaller than the pooling window.
            if (ih + 2 * ph < 2 || iw + 2 * pw < 2) {
                continue;
            }

            param::Pooling param;
            param.mode = param::Pooling::Mode::MAX;
            param.format = param::Pooling::Format::NCHW44;
            param.pad_h = ph;
            param.pad_w = pw;
            param.stride_h = param.stride_w = 2;
            param.window_h = param.window_w = 2;

            UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
            Checker<Pooling> checker(handle());
            checker.set_dtype(0, dtype::QuantizedS8(1.1f));
            checker.set_rng(0, &rng);
            checker.set_param(param).exec(
                    TensorShapeArray{{2, 2, ih, iw, 4}, {}});
        }
    }
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON, POOLING_FP16) {
Checker<Pooling> checker(handle());
......
......@@ -104,6 +104,57 @@ TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W3x3_S1x1_NCHW44)
// clang-format on
}
// Multi-threaded check of 2x2 / stride-1 max pooling on QuantizedS8 NCHW44
// tensors over a grid of input sizes, without padding.
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_S1x1_NCHW44)
{
    // clang-format off
    for (size_t ih: {2, 5, 10, 17})
    for (size_t iw: {2, 6, 8, 16, 26})
    for (size_t ph: {0})
    for (size_t pw: {0})
    // The window is 2x2, so the padded input only has to be at least 2 wide
    // and 2 high; a larger bound would silently drop the ih == 2 and iw == 2
    // cases from the sweep.
    if (ih+2*ph >= 2 && iw+2*pw >= 2)
    {
        UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
        Checker<Pooling> checker(handle());
        checker.set_dtype(0, dtype::QuantizedS8(1.1f));
        checker.set_rng(0,&rng);

        param::Pooling param;
        param.mode = param::Pooling::Mode::MAX;
        param.format = param::Pooling::Format::NCHW44;
        param.pad_h = ph;
        param.pad_w = pw;
        param.stride_h = param.stride_w = 1;
        param.window_h = param.window_w = 2;
        checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
    }
    // clang-format on
}
// Multi-threaded check of 2x2 / stride-2 max pooling on QuantizedS8 NCHW44
// tensors over a grid of input sizes, without padding.
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_MAX_W2x2_S2x2_NCHW44)
{
    // clang-format off
    for (size_t ih: {2, 5, 10, 17})
    for (size_t iw: {2, 6, 8, 16, 26})
    for (size_t ph: {0})
    for (size_t pw: {0})
    // The window is 2x2, so the padded input only has to be at least 2 wide
    // and 2 high; a larger bound would silently drop the ih == 2 and iw == 2
    // cases from the sweep.
    if (ih+2*ph >= 2 && iw+2*pw >= 2)
    {
        UniformIntRNG rng{INT8_MIN >> 1, INT8_MAX >> 1};
        Checker<Pooling> checker(handle());
        checker.set_dtype(0, dtype::QuantizedS8(1.1f));
        checker.set_rng(0,&rng);

        param::Pooling param;
        param.mode = param::Pooling::Mode::MAX;
        param.format = param::Pooling::Format::NCHW44;
        param.pad_h = ph;
        param.pad_w = pw;
        param.stride_h = param.stride_w = 2;
        param.window_h = param.window_w = 2;
        checker.set_param(param).exec(TensorShapeArray{{2, 2, ih, iw, 4}, {}});
    }
    // clang-format on
}
TEST_F(ARM_COMMON_MULTI_THREADS, POOLING_INT8_W3x3_S2x2)
{
for (size_t ih: {2, 3, 7, 13, 52, 53, 54, 55})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册