Commit a1cbd9bb authored by Megvii Engine Team

feat(dnn): add fp16 nchw88 pooling algo
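Adds an ARM fp16 pooling kernel for the NCHW88 layout (8 channels packed per
block): MAX and AVERAGE modes with square windows, filter 2/3/4/5 at stride 1
or 2, and filter 9/13 at stride 1.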

GitOrigin-RevId: 7a5e9c7df242fd5d7d7811b1af9213e58be20e91
Parent 951cc3b0
......@@ -124,6 +124,19 @@ public:
MEGDNN_DECL_ALGO_TYPE(ARM_Filter5ModexStridexNCHW44)
};
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class PoolingImpl::AlgoFilterxModexStridexNCHW88 final : public AlgoBase {
public:
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
const char* name() const override {
return "ARM_POOLING_FILTERX_MODEX_STRIDEX_NCHW88";
}
bool usable(const PoolingKernSizeParam& param) const override;
void exec(const PoolingKernParam& param) const override;
MEGDNN_DECL_ALGO_TYPE(ARM_Fp16FilterxModexStridexNCHW88)
};
#endif
class PoolingImpl::AlgoFallback final : public AlgoBase {
public:
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; };
......
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include "midout.h"
#include "src/arm_common/pooling/algo.h"
#include "src/arm_common/pooling/fp16/kern_fp16_nchw88_pooling.h"
MIDOUT_DECL(megdnn_arm_common_fp16_nchw88_pooling)
namespace megdnn {
namespace arm_common {
bool PoolingImpl::AlgoFilterxModexStridexNCHW88::usable(
const PoolingKernSizeParam& param) const {
uint32_t sh = param.stride[0];
uint32_t sw = param.stride[1];
uint32_t fh = param.filter[0];
uint32_t fw = param.filter[1];
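//! kernels are only instantiated for square windows and strides: filter 2-5
//! with stride 1 or 2, and filter 9/13 with stride 1.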
bool usable = param.src_type.enumv() == DTypeEnum::Float16 &&
param.format == param::Pooling::Format::NCHW88 &&
(param.mode == PoolingBase::Mode::MAX ||
param.mode == PoolingBase::Mode::AVERAGE) &&
fh == fw && sh == sw;
bool size_ok =
(((fh == 2 || fh == 3 || fh == 4 || fh == 5) && (sh == 1 || sh == 2)) ||
((fh == 9 || fh == 13) && (sh == 1)));
return usable && size_ok;
}
void PoolingImpl::AlgoFilterxModexStridexNCHW88::exec(
const PoolingKernParam& param) const {
int ih = param.isz[0];
int iw = param.isz[1];
int oh = param.osz[0];
int ow = param.osz[1];
int n = param.n;
int ic = param.ic;
int ph = param.padding[0];
int pw = param.padding[1];
int sh = param.stride[0];
int fh = param.filter[0];
auto src = param.src_ptr;
auto dst = param.dst_ptr;
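//! parallelize over (batch, channel-block) pairs: each task pools one block of
//! 8 packed channels over the whole spatial plane.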
#define DISPATCH_FUNC(filter, stride, mode) \
MIDOUT_BEGIN( \
megdnn_arm_common_fp16_nchw88_pooling, midout_iv(0), \
midout_iv(#filter #stride #mode##_hash)) { \
auto run = [=](size_t index, size_t) { \
const int c_idx = index; \
pooling_fp16_nchw88<filter, stride, mode>( \
static_cast<const __fp16*>(src.get_ptr()) + c_idx * ih * iw * 8, \
static_cast<__fp16*>(dst.get_ptr()) + c_idx * oh * ow * 8, ih, iw, \
oh, ow, ph, pw); \
}; \
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
static_cast<::megdnn::naive::HandleImpl*>(param.handle), n* ic, run); \
} \
MIDOUT_END();
#define DISPATCH_MODE(filter, stride) \
switch (param.mode) { \
case PoolingBase::Mode::MAX: \
DISPATCH_FUNC(filter, stride, PoolingBase::Mode::MAX); \
break; \
case PoolingBase::Mode::AVERAGE: \
DISPATCH_FUNC(filter, stride, PoolingBase::Mode::AVERAGE); \
break; \
default: \
megdnn_assert(0, "invalid mode %u", static_cast<uint32_t>(param.mode)); \
}
#define DISPATCH_STRIDE(filter) \
switch (sh) { \
case 1: \
DISPATCH_MODE(filter, 1); \
break; \
case 2: \
DISPATCH_MODE(filter, 2); \
break; \
default: \
megdnn_assert( \
0, \
"Invalid stride %d. When the filter size is 2, 3, 4 or 5, stride " \
"can only be 1 or 2.", \
sh); \
}
#define DISPATCH_STRIDE1(filter) \
switch (sh) { \
case 1: \
DISPATCH_MODE(filter, 1); \
break; \
default: \
megdnn_assert( \
0, \
"Invalid stride %d. When the filter size is 9 or 13, stride " \
"can only be 1.", \
sh); \
}
#define DISPATCH_FILTER() \
switch (fh) { \
case 2: \
DISPATCH_STRIDE(2); \
break; \
case 3: \
DISPATCH_STRIDE(3); \
break; \
case 4: \
DISPATCH_STRIDE(4); \
break; \
case 5: \
DISPATCH_STRIDE(5); \
break; \
case 9: \
DISPATCH_STRIDE1(9); \
break; \
case 13: \
DISPATCH_STRIDE1(13); \
break; \
}
DISPATCH_FILTER();
}
} // namespace arm_common
} // namespace megdnn
#endif
\ No newline at end of file
#pragma once
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#include <arm_neon.h>
#include <limits>
#include "megdnn/opr_param_defs.h"
#include "src/arm_common/intrinsic_helper.h"
#include "src/arm_common/neon_struct.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/pooling/opr_impl.h"
#include "src/common/unroll_macro.h"
namespace megdnn {
namespace arm_common {
namespace {
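//! OW_STEP is the number of output columns produced per kernel call; the
//! larger AArch64 register file allows a deeper unroll than 32-bit ARM.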
#if MEGDNN_AARCH64
#define OW_STEP 4
#else
#define OW_STEP 2
#endif
template <int filter, int stride, PoolingBase::Mode mode, typename T1, typename T2>
struct CalXsXNchw88 {
static void impl(T1 result, T2 src);
};
#define CAL_MAX_CB(step, ow_step) \
result[ow_step] = vmaxq_f16(result[ow_step], src[ow_step * stride + step]);
#define CAL_AVE_CB(step, ow_step) \
result[ow_step] = vaddq_f16(result[ow_step], src[ow_step * stride + step]);
#define INSTANCE_CAL(filter, ow_step) \
template <int stride, typename T1, typename T2> \
struct CalXsXNchw88<filter, stride, PoolingBase::Mode::MAX, T1, T2> { \
static void impl(T1 result, T2 src) { \
UNROLL_CALL_NOWRAPPER_D2(filter, ow_step, CAL_MAX_CB); \
} \
}; \
template <int stride, typename T1, typename T2> \
struct CalXsXNchw88<filter, stride, PoolingBase::Mode::AVERAGE, T1, T2> { \
static void impl(T1 result, T2 src) { \
UNROLL_CALL_NOWRAPPER_D2(filter, ow_step, CAL_AVE_CB); \
} \
};
INSTANCE_CAL(2, OW_STEP)
INSTANCE_CAL(3, OW_STEP)
INSTANCE_CAL(4, OW_STEP)
INSTANCE_CAL(5, OW_STEP)
INSTANCE_CAL(9, OW_STEP)
INSTANCE_CAL(13, OW_STEP)
#undef INSTANCE_CAL
#undef CALCULATE_AVG_CB
#undef CALCULATE_MAX_CB
#undef CAL_AVE_CB
#undef CAL_MAX_CB
template <int filter, int stride, PoolingBase::Mode mode, typename T1, typename T2>
void calculate_xsx_nchw88(T1 result, T2 src) {
CalXsXNchw88<filter, stride, mode, T1, T2>::impl(result, src);
}
template <int filter, int stride, PoolingBase::Mode mode>
struct KerPoolingFilterXStrideXNchw88 {
static void impl(const __fp16* src_ptr, __fp16* dst_ptr, size_t iw);
};
template <int filter, int stride>
struct KerPoolingFilterXStrideXNchw88<filter, stride, PoolingBase::Mode::MAX> {
static void impl(const __fp16* src_ptr, __fp16* dst_ptr, size_t iw) {
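//! number of 8-lane fp16 vectors loaded per input row: enough to cover
//! OW_STEP output columns at the given stride plus the window width.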
constexpr int src_reg_size = stride * (OW_STEP - 1) + filter;
constexpr int packed_ic = 8;
constexpr int simd_len = 8;
constexpr dt_float16 min_float16 = std::numeric_limits<dt_float16>::lowest();
float16x8_t result[OW_STEP], src[src_reg_size];
#define cb(i) result[i] = vdupq_n_f16(min_float16);
UNROLL_CALL_NOWRAPPER(OW_STEP, cb);
#undef cb
for (int fh_idx = 0; fh_idx < filter; ++fh_idx) {
auto src_base_ptr = src_ptr + fh_idx * iw * packed_ic;
rep(i, src_reg_size) { src[i] = vld1q_f16(src_base_ptr + i * simd_len); }
calculate_xsx_nchw88<filter, stride, PoolingBase::Mode::MAX>(result, src);
}
#define cb(i) vst1q_f16(dst_ptr + i * packed_ic, result[i]);
UNROLL_CALL_NOWRAPPER(OW_STEP, cb)
#undef cb
}
};
template <int filter, int stride>
struct KerPoolingFilterXStrideXNchw88<filter, stride, PoolingBase::Mode::AVERAGE> {
static void impl(const __fp16* src_ptr, __fp16* dst_ptr, size_t iw) {
constexpr int src_reg_size = stride * (OW_STEP - 1) + filter;
constexpr int packed_ic = 8;
constexpr int simd_len = 8;
const __fp16 zero = static_cast<__fp16>(0);
const __fp16 div_filter_pow = static_cast<__fp16>(1.0 / (filter * filter));
const float16x8_t div_filter_pow_vec = vdupq_n_f16(div_filter_pow);
float16x8_t result[OW_STEP], src[src_reg_size];
#define cb(i) result[i] = vdupq_n_f16(zero);
UNROLL_CALL_NOWRAPPER(OW_STEP, cb)
#undef cb
rep(fh, filter) {
auto src_base_ptr = src_ptr + fh * iw * packed_ic;
rep(i, src_reg_size) { src[i] = vld1q_f16(src_base_ptr + i * simd_len); }
calculate_xsx_nchw88<filter, stride, PoolingBase::Mode::AVERAGE>(
result, src);
}
#define cb(i) \
vst1q_f16(dst_ptr + i * simd_len, vmulq_f16(result[i], div_filter_pow_vec));
UNROLL_CALL_NOWRAPPER(OW_STEP, cb)
#undef cb
}
};
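//! compute a single output element whose pooling window overlaps the zero
//! padding: only the in-bounds rows and columns are read, one 8-channel
//! vector at a time.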
template <PoolingBase::Mode mode>
void kern_pooling_nchw88_remain_pad(
const __fp16* src, __fp16* dst, const int iw, const int pad_top,
const int pad_left, const int pad_bottom, const int pad_right,
const int filter);
template <>
void kern_pooling_nchw88_remain_pad<PoolingBase::Mode::MAX>(
const __fp16* src, __fp16* dst, const int iw, const int pad_top,
const int pad_left, const int pad_bottom, const int pad_right,
const int filter) {
constexpr int ic_step = 8;
const int fh_end = filter - pad_bottom;
const int fw_end = filter - pad_right;
float16x8_t result = vdupq_n_f16(std::numeric_limits<dt_float16>::lowest());
for (int fh_idx = pad_top; fh_idx < fh_end; ++fh_idx) {
for (int fw_idx = pad_left; fw_idx < fw_end; ++fw_idx) {
float16x8_t s = vld1q_f16(src + (fw_idx - pad_left) * ic_step);
result = vmaxq_f16(result, s);
}
src += iw * ic_step;
}
vst1q_f16(dst, result);
}
template <>
void kern_pooling_nchw88_remain_pad<PoolingBase::Mode::AVERAGE>(
const __fp16* src, __fp16* dst, const int iw, const int pad_top,
const int pad_left, const int pad_bottom, const int pad_right,
const int filter) {
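//! the divisor stays filter * filter even when part of the window lies in the
//! border, i.e. padded positions contribute zero to the average.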
constexpr int ic_step = 8;
const int fh_end = filter - pad_bottom;
const int fw_end = filter - pad_right;
float16x8_t result = vdupq_n_f16(static_cast<dt_float16>(0));
float16x8_t div_filter_pow_vec = vdupq_n_f16(1.0 / (filter * filter));
for (int fh_idx = pad_top; fh_idx < fh_end; ++fh_idx) {
for (int fw_idx = pad_left; fw_idx < fw_end; ++fw_idx) {
float16x8_t s = vld1q_f16(src + (fw_idx - pad_left) * ic_step);
result = vaddq_f16(result, s);
}
src += iw * ic_step;
}
vst1q_f16(dst, vmulq_f16(result, div_filter_pow_vec));
}
template <PoolingBase::Mode mode>
static inline void kern_pooling_with_pad_nchw88(
const __fp16* src, __fp16* dst, const int filter, const int ow_start,
const int ow_end, const int iw, const int ow, const int stride_w, const int pw,
const int real_ih_idx, const int oh_idx, const int pad_top,
const int pad_bottom) {
constexpr int ic_step = 8;
constexpr int oc_step = 8;
for (int ow_idx = ow_start; ow_idx < ow_end; ++ow_idx) {
const int iw_idx = ow_idx * stride_w;
const int real_iw_idx = std::max(0, iw_idx - pw);
const int pad_left = std::max(0, pw - iw_idx);
const int pad_right = std::max(0, iw_idx - pw + filter - iw);
const int src_offset = (real_ih_idx * iw + real_iw_idx) * ic_step;
const int dst_offset = (oh_idx * ow + ow_idx) * oc_step;
kern_pooling_nchw88_remain_pad<mode>(
src + src_offset, dst + dst_offset, iw, pad_top, pad_left, pad_bottom,
pad_right, filter);
}
}
template <int filter, int stride, PoolingBase::Mode mode>
static inline void pooling_fp16_nchw88_pad(
const __fp16* src, __fp16* dst, int ih, int iw, int oh, int ow, int ph,
int pw) {
constexpr int stride_h = stride;
constexpr int stride_w = stride;
constexpr int ic_step = 8;
constexpr int oc_step = 8;
constexpr int ow_step = OW_STEP;
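//! each output row is split into three ranges: columns whose window touches
//! the left border, a central run handled OW_STEP columns at a time by the
//! vectorized kernel, and columns whose window reaches the right border.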
const int ow_pad_left_end = div_ceil(pw, stride_w);
//! index of the first output column whose window extends past the right input border
const int ow_pad_right_start = (iw + pw - filter) / stride_w + 1;
const int ow_pad_right_step_end =
(ow_pad_right_start - ow_pad_left_end) / ow_step * ow_step +
ow_pad_left_end;
rep(oh_idx, oh) {
const int ih_idx = oh_idx * stride_h;
......@@ -218,24 +199,23 @@ static inline void pooling_fp32_nchw44_pad(
const int pad_top = std::max(0, ph - ih_idx);
const int pad_bottom = std::max(0, ih_idx - ph + filter - ih);
if (pad_top > 0 || pad_bottom > 0) {
kern_pooling_with_pad_nchw88<mode>(
src, dst, filter, 0, ow, iw, ow, stride_w, pw, real_ih_idx, oh_idx,
pad_top, pad_bottom);
} else {
kern_pooling_with_pad_nchw88<mode>(
src, dst, filter, 0, ow_pad_left_end, iw, ow, stride_w, pw,
real_ih_idx, oh_idx, pad_top, pad_bottom);
for (int ow_idx = ow_pad_left_end; ow_idx < ow_pad_right_step_end;
ow_idx += ow_step) {
const int iw_idx = ow_idx * stride_w;
const int real_iw_idx = std::max(0, iw_idx - pw);
const int src_offset = (real_ih_idx * iw + real_iw_idx) * ic_step;
const int dst_offset = (oh_idx * ow + ow_idx) * oc_step;
KerPoolingFilterXStrideXNchw88<filter, stride, mode>::impl(
src + src_offset, dst + dst_offset, iw);
}
kern_pooling_with_pad_nchw88<mode>(
src, dst, filter, ow_pad_right_step_end, ow, iw, ow, stride_w, pw,
real_ih_idx, oh_idx, pad_top, pad_bottom);
}
......@@ -243,29 +223,28 @@ static inline void pooling_fp32_nchw44_pad(
}
template <int filter, int stride, PoolingBase::Mode mode>
static inline void pooling_fp16_nchw88_no_pad(
const __fp16* src, __fp16* dst, const int iw, const int oh, const int ow) {
constexpr int stride_h = stride;
constexpr int stride_w = stride;
constexpr int ic_step = 8;
constexpr int oc_step = 8;
constexpr int ow_step = OW_STEP;
const int ow_end = ow / ow_step * ow_step;
const int ow_remain = ow - ow_end;
rep(oh_idx, oh) {
const int ih_idx = oh_idx * stride_h;
for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
const int iw_idx = ow_idx * stride_w;
const int src_offset = (ih_idx * iw + iw_idx) * ic_step;
const int dst_offset = (oh_idx * ow + ow_idx) * oc_step;
KerPoolingFilterXStrideXNchw88<filter, stride, mode>::impl(
src + src_offset, dst + dst_offset, iw);
}
if (ow_remain > 0) {
kern_pooling_with_pad_nchw88<mode>(
src, dst, filter, ow_end, ow, iw, ow, stride_w, 0, ih_idx, oh_idx,
0, 0);
}
......@@ -273,18 +252,18 @@ static inline void pooling_fp32_nchw44_no_pad(
}
template <int filter, int stride, PoolingBase::Mode mode>
static inline void pooling_fp16_nchw88(
const __fp16* src, __fp16* dst, const int ih, const int iw, const int oh,
const int ow, const int ph, const int pw) {
if (ph > 0 || pw > 0) {
pooling_fp16_nchw88_pad<filter, stride, mode>(src, dst, ih, iw, oh, ow, ph, pw);
} else {
pooling_fp16_nchw88_no_pad<filter, stride, mode>(src, dst, iw, oh, ow);
}
}
#undef OW_STEP
} // namespace
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
#endif
\ No newline at end of file
......@@ -22,6 +22,9 @@ private:
AlgoFilter3ModexStridexNCHW44 algo_filter3_modex_stridex_nchw4;
AlgoFilter4ModexStridexNCHW44 algo_filter4_modex_stridex_nchw4;
AlgoFilter5ModexStridexNCHW44 algo_filter5_modex_stridex_nchw4;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
AlgoFilterxModexStridexNCHW88 algo_fp16_filterx_modex_stridex_nchw88;
#endif
AlgoFallback algo_fallback;
public:
......@@ -38,6 +41,9 @@ public:
all_algos.emplace_back(&algo_filter2_modex_stridex_nchw4);
all_algos.emplace_back(&algo_filter4_modex_stridex_nchw4);
all_algos.emplace_back(&algo_filter5_modex_stridex_nchw4);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
all_algos.emplace_back(&algo_fp16_filterx_modex_stridex_nchw88);
#endif
all_algos.emplace_back(&algo_fallback);
for (auto&& algo : all_algos) {
......
......@@ -24,6 +24,9 @@ private:
class AlgoFilter3ModexStridexNCHW44;
class AlgoFilter4ModexStridexNCHW44;
class AlgoFilter5ModexStridexNCHW44;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class AlgoFilterxModexStridexNCHW88;
#endif
class AlgoFallback;
class AlgoPack;
static AlgoPack sm_algo_pack;
......@@ -56,6 +59,9 @@ public:
ARM_Filter3ModexStridexNCHW44,
ARM_Filter4ModexStridexNCHW44,
ARM_Filter5ModexStridexNCHW44,
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
ARM_Fp16FilterxModexStridexNCHW88,
#endif
ARM_Fp32ModexStridexNCHW44,
ARM_Fallback
};
......
......@@ -165,6 +165,10 @@
cb(8, 0, ##a) cb(8, 1, ##a) cb(8, 2, ##a) cb(8, 3, ##a) \
cb(8, 4, ##a) cb(8, 5, ##a) cb(8, 6, ##a) cb(8, 7, ##a) cb(8, 8, ##a)
#define UNROLL_RAW_3x2(cb, v0, a...) \
UNROLL_RAW_2x2(cb, v0, ##a) \
cb(2, 0, ##a) cb(2, 1, ##a)
#define UNROLL_RAW_4x2(cb, v0, a...) \
cb(0, 0, ##a) cb(0, 1, ##a) cb(1, 0, ##a) cb(1, 1, ##a) \
cb(2, 0, ##a) cb(2, 1, ##a) cb(3, 0, ##a) cb(3, 1, ##a)
......@@ -177,6 +181,19 @@
UNROLL_RAW_5x2(cb, v0, ##a) \
cb(5, 0, ##a) cb(5, 1, ##a)
#define UNROLL_RAW_9x2(cb, v0, a...) \
UNROLL_RAW_6x2(cb, v0, ##a) \
cb(6, 0, ##a) cb(6, 1, ##a) \
cb(7, 0, ##a) cb(7, 1, ##a) \
cb(8, 0, ##a) cb(8, 1, ##a)
#define UNROLL_RAW_13x2(cb, v0, a...) \
UNROLL_RAW_9x2(cb, v0, ##a) \
cb(9, 0, ##a) cb(9, 1, ##a) \
cb(10, 0, ##a) cb(10, 1, ##a) \
cb(11, 0, ##a) cb(11, 1, ##a) \
cb(12, 0, ##a) cb(12, 1, ##a)
#define UNROLL_RAW_4x6(cb, v0, a...) \
cb(0, 0, ##a) cb(0, 1, ##a) cb(0, 2, ##a) cb(0, 3, ##a) cb(0, 4, ##a) cb(0, 5, ##a) \
cb(1, 0, ##a) cb(1, 1, ##a) cb(1, 2, ##a) cb(1, 3, ##a) cb(1, 4, ##a) cb(1, 5, ##a) \
......@@ -186,6 +203,28 @@
UNROLL_RAW_4x6(cb, v0, ##a) \
cb(4, 0, ##a) cb(4, 1, ##a) cb(4, 2, ##a) cb(4, 3, ##a) cb(4, 4, ##a) cb(4, 5, ##a)
#define UNROLL_RAW_2x4(cb, v0, a...) \
cb(0, 0, ##a) cb(0, 1, ##a) cb(0, 2, ##a) cb(0, 3, ##a) \
cb(1, 0, ##a) cb(1, 1, ##a) cb(1, 2, ##a) cb(1, 3, ##a)
#define UNROLL_RAW_3x4(cb, v0, a...) \
UNROLL_RAW_2x4(cb, v0, ##a) \
cb(2, 0, ##a) cb(2, 1, ##a) cb(2, 2, ##a) cb(2, 3, ##a)
#define UNROLL_RAW_5x4(cb, v0, a...) \
UNROLL_RAW_4x4(cb, v0, ##a) \
cb(4, 0, ##a) cb(4, 1, ##a) cb(4, 2, ##a) cb(4, 3, ##a)
#define UNROLL_RAW_9x4(cb, v0, a...) \
UNROLL_RAW_5x4(cb, v0, ##a) \
cb(5, 0, ##a) cb(5, 1, ##a) cb(5, 2, ##a) cb(5, 3, ##a) \
cb(6, 0, ##a) cb(6, 1, ##a) cb(6, 2, ##a) cb(6, 3, ##a) \
cb(7, 0, ##a) cb(7, 1, ##a) cb(7, 2, ##a) cb(7, 3, ##a) \
cb(8, 0, ##a) cb(8, 1, ##a) cb(8, 2, ##a) cb(8, 3, ##a)
#define UNROLL_RAW_13x4(cb, v0, a...) \
UNROLL_RAW_9x4(cb, v0, ##a) \
cb(9, 0, ##a) cb(9, 1, ##a) cb(9, 2, ##a) cb(9, 3, ##a) \
cb(10, 0, ##a) cb(10, 1, ##a) cb(10, 2, ##a) cb(10, 3, ##a) \
cb(11, 0, ##a) cb(11, 1, ##a) cb(11, 2, ##a) cb(11, 3, ##a) \
cb(12, 0, ##a) cb(12, 1, ##a) cb(12, 2, ##a) cb(12, 3, ##a)
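//! these XxY tables feed the two-dimensional unroll helpers below, which
//! expand cb(i, j) for every i < step and j < step2; the fp16 NCHW88 pooling
//! kernels use them with step2 equal to OW_STEP.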
#define UNROLL_CALL0_D2(step, step2, cb, v...) \
UNROLL_RAW_##step##x##step2(cb, 0, ##v)
#define UNROLL_CALL1_D2(step, step2, cb, v...) \
......
......@@ -216,6 +216,48 @@ TEST_F(ARM_COMMON, POOLING_FP16) {
checker.set_param(param).exec({{2, 3, ih, iw}, {}});
}
}
TEST_F(ARM_COMMON, POOLING_FP16_NCHW88) {
Checker<Pooling> checker(handle());
checker.set_dtype(0, dtype::Float16{});
checker.set_dtype(1, dtype::Float16{});
checker.set_dtype(2, dtype::Float16{});
checker.set_dtype(4, dtype::Float16{});
checker.set_epsilon(0.003);
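// fp16 accumulation is noticeably less precise than fp32, so a relaxed
// epsilon is used.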
for (size_t ic : {1, 2, 3, 5, 7, 11})
for (size_t ih : {20, 15})
for (size_t iw : {15, 20, 27, 51, 76, 101, 256})
for (size_t pad : {2, 3, 5})
for (auto mode :
{param::Pooling::Mode::MAX, param::Pooling::Mode::AVERAGE}) {
param::Pooling param;
param.mode = mode;
param.format = param::Pooling::Format::NCHW88;
param.pad_h = pad;
param.pad_w = pad;
for (size_t kernel : {2, 3, 4, 5}) {
if (kernel > pad && ih + 2 * pad >= kernel &&
iw + 2 * pad >= kernel) {
param.window_h = param.window_w = kernel;
param.stride_h = param.stride_w = 1;
checker.set_param(param).exec(
TensorShapeArray{{2, ic, ih, iw, 8}, {}});
param.stride_h = param.stride_w = 2;
checker.set_param(param).exec(
TensorShapeArray{{2, ic, ih, iw, 8}, {}});
}
}
for (size_t kernel : {9, 13}) {
if (kernel > pad && ih + 2 * pad >= kernel &&
iw + 2 * pad >= kernel) {
param.window_h = param.window_w = kernel;
param.stride_h = param.stride_w = 1;
checker.set_param(param).exec(
TensorShapeArray{{2, ic, ih, iw, 8}, {}});
}
}
}
}
#endif
TEST_F(ARM_COMMON, POOLING_QUANTIZED) {
......@@ -367,6 +409,72 @@ TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_POOLING_NCHW44_FP32) {
benchmark_nchw44_fp32(handle());
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
void benchmark_nchw88_fp16(Handle* handle) {
using Param = param::Pooling;
auto run = [&](size_t n, size_t c, size_t h, size_t w, size_t filter, size_t stride,
size_t pad, Param::Mode mode) {
Param param;
param.window_h = param.window_w = filter;
param.stride_h = param.stride_w = stride;
param.pad_h = param.pad_w = pad;
param.format = Param::Format::NCHW44;
param.mode = mode;
TensorShape nchw44_shape = {n, c / 4, h, w, 4};
TensorShape nchw88_shape = {n, c / 8, h, w, 8};
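// both shapes describe the same logical NCHW tensor, so the fp32/NCHW44 and
// fp16/NCHW88 timings are directly comparable.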
TensorLayout dst_layout;
auto opr = handle->create_operator<Pooling>();
opr->param() = param;
opr->deduce_layout({nchw44_shape, dtype::Float32()}, dst_layout);
float calc_amount =
dst_layout.total_nr_elems() * param.window_h * param.window_w;
Benchmarker<Pooling> benchmarker_float16_nchw88(handle);
Benchmarker<Pooling> benchmarker_float32_nchw44(handle);
size_t RUN = 500;
auto t1 = benchmarker_float32_nchw44.set_display(false)
.set_times(RUN)
.set_param(param)
.exec({nchw44_shape, {}});
param.format = Param::Format::NCHW88;
auto t2 = benchmarker_float16_nchw88.set_display(false)
.set_dtype(0, dtype::Float16{})
.set_dtype(1, dtype::Float16{})
.set_dtype(2, dtype::Float16{})
.set_dtype(4, dtype::Float16{})
.set_times(RUN)
.set_param(param)
.exec({nchw88_shape, {}});
printf("{%zu %zu %zu %zu} filter = %zu, stride = %zu pad = %zu\n"
"nchw44_fp32={%.3f ms, %.3f Mflops}, "
"nchw88_fp16={%.3f ms, %.3f Mflops, speed_up %f}\n\n",
n, c, h, w, filter, stride, pad, t1 / RUN,
calc_amount / (t1 / RUN * 1000), t2 / RUN,
calc_amount / (t2 / RUN * 1000), t1 / t2);
};
// Resnet50
run(1, 64, 112, 112, 3, 2, 1, param::Pooling::Mode::MAX);
run(1, 2048, 7, 7, 7, 1, 0, param::Pooling::Mode::AVERAGE);
// VGG16
run(1, 64, 224, 224, 2, 2, 0, param::Pooling::Mode::MAX);
run(1, 128, 112, 112, 2, 2, 0, param::Pooling::Mode::MAX);
run(1, 256, 56, 56, 2, 2, 0, param::Pooling::Mode::MAX);
run(1, 512, 28, 28, 2, 2, 0, param::Pooling::Mode::MAX);
run(1, 512, 14, 14, 2, 2, 0, param::Pooling::Mode::MAX);
}
TEST_F(ARM_COMMON, BENCHMARK_POOLING_NCHW88_FP16) {
benchmark_nchw88_fp16(handle());
}
TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_POOLING_NCHW88_FP16) {
benchmark_nchw88_fp16(handle());
}
#endif
TEST_F(ARM_COMMON, BENCHMARK_POOLING_INT8_W3x3_S2x2) {
using Param = param::Pooling;
auto run = [&](const TensorShapeArray& shapes, Param param) {
......