diff --git a/dnn/src/aarch64/matrix_mul/algos.cpp b/dnn/src/aarch64/matrix_mul/algos.cpp index d6aa3b5ede1f657e79bc115debe5f65dbb96b6de..115829891dbe1140d6121731004b227839960f45 100644 --- a/dnn/src/aarch64/matrix_mul/algos.cpp +++ b/dnn/src/aarch64/matrix_mul/algos.cpp @@ -407,6 +407,11 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_16x12x1::get_kern( return kern_mk8_16x12x1; } +MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( + AlgoF16MK8_16x12x1, megdnn_aarch64_matmul_kern, "AlogF16MK8_16x12x1Impl"_hash, + aarch64::matmul::hgemm_mk8_16x12, dt_float16, dt_float16, AlgoDataType::FLOAT16, + MK8); + #endif #if MGB_ENABLE_DOT diff --git a/dnn/src/aarch64/matrix_mul/algos.h b/dnn/src/aarch64/matrix_mul/algos.h index 22611c945c4a52ab30ea10de12b78f7b27cdf20d..0d84aaf0bff61d5c927c4c49d995a5c85bfa7843 100644 --- a/dnn/src/aarch64/matrix_mul/algos.h +++ b/dnn/src/aarch64/matrix_mul/algos.h @@ -93,7 +93,7 @@ public: bool usable(const KernSizeParam&) const override; size_t get_workspace(const KernSizeParam&) const override; kern_t get_kern(const KernSizeParam&) const override; - MEGDNN_OVERRIDE_MATMUL_DESC(16, 12, 1, 2, AlgoDataType::FLOAT16, MK8); + MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_MK8_16X12X1); }; diff --git a/dnn/src/aarch64/matrix_mul/fp16/mk8_16x12_kern.inc b/dnn/src/aarch64/matrix_mul/fp16/mk8_16x12_kern.inc index 17f9638254c288deac72c12d6d43080c2672a663..67c8048673c57ea27fd819722ca26f8f611c39eb 100644 --- a/dnn/src/aarch64/matrix_mul/fp16/mk8_16x12_kern.inc +++ b/dnn/src/aarch64/matrix_mul/fp16/mk8_16x12_kern.inc @@ -9,8 +9,8 @@ template <> void matmul_mk8_16x12::kern( - const dt_float16* packedA, const dt_float16* packedB, int K, - dt_float16* out, int LDC, bool is_first_k) { + const dt_float16* packedA, const dt_float16* packedB, int K, dt_float16* out, + int LDC, bool is_first_k) { #define IF_M_GT(M, INSTRUC) ".if " STR(M_BLOCK) " > " #M "\n" INSTRUC ".endif\n" #define IF_N_GT(N, INSTRUC) ".if " STR(N_BLOCK) " > " #N "\n" INSTRUC ".endif\n" // clang-format off diff --git a/dnn/src/fallback/conv_bias/im2col/algos.cpp b/dnn/src/fallback/conv_bias/im2col/algos.cpp index 61d568bc3c612dd3a778dc694e24bc1ecd47995d..cdad66974c69b84cc105fc2638bcb7f05f565e8c 100644 --- a/dnn/src/fallback/conv_bias/im2col/algos.cpp +++ b/dnn/src/fallback/conv_bias/im2col/algos.cpp @@ -26,6 +26,8 @@ static fallback::MatrixMulImpl::KernSizeParam get_matmul_kern_param( format = param::MatrixMul::Format::MK4; } else if (param.filter_meta.format == param::ConvBias::Format::NCHW44_DOT) { format = param::MatrixMul::Format::MK4_DOT; + } else if (param.filter_meta.format == param::ConvBias::Format::NCHW88) { + format = param::MatrixMul::Format::MK8; } size_t M = oc_tile_size; size_t N = ohw_tile_size; @@ -329,9 +331,15 @@ bool ConvBiasImpl::AlgoIm2col::usable( #if MEGDNN_AARCH64 || MEGDNN_ARMV7 if (format != param::ConvBias::Format::NCHW && format != param::ConvBias::Format::NCHW44 && - format != param::ConvBias::Format::NCHW44_DOT) { + format != param::ConvBias::Format::NCHW44_DOT && + format != param::ConvBias::Format::NCHW88) { return false; } + if (format == param::ConvBias::Format::NCHW88) { + if (matmul_desc.packmode != Pack_Mode::DEFAULT) { + return false; + } + } if (format == param::ConvBias::Format::NCHW44 || format == param::ConvBias::Format::NCHW44_DOT) { //! current NCHW44 im2col only support DEFAULT mode matmul diff --git a/dnn/src/fallback/conv_bias/im2col/factory.h b/dnn/src/fallback/conv_bias/im2col/factory.h index f23581f8719e23b33d6a6bb0a018ade36ca7afbe..0d4b05cc8fed68fc6563b12b2ef68a4e0969d666 100644 --- a/dnn/src/fallback/conv_bias/im2col/factory.h +++ b/dnn/src/fallback/conv_bias/im2col/factory.h @@ -248,8 +248,18 @@ public: break; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case StrategyType::FLOAT_FP16: - cb1(NCHW, DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT, - "DefaultStrategyType::FLOAT_FP16"_hash); + if (format == param::ConvBias::Format::NCHW) { + cb1(NCHW, DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT, + "DefaultStrategyType::FLOAT_FP16"_hash); + } else if (format == param::ConvBias::Format::NCHW88) { + cb1(NCHW88, DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT, + "DefaultStrategyTypeNCHW88::FLOAT_FP16"_hash); + } else { + megdnn_throw(ssprintf( + "Current only support layout NCHW/NCHW88 for im2col algo " + "of float 16, but got %d\n", + uint32_t(format))); + } break; #endif #if !MEGDNN_DISABLE_FLOAT16 diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_base.h b/dnn/src/fallback/conv_bias/im2col/strategy_base.h index 7d3475b84dc86826a182ddc63a536736a3135150..8f5b30439c5fa917bd5ca9a91d7bee65b491efc9 100644 --- a/dnn/src/fallback/conv_bias/im2col/strategy_base.h +++ b/dnn/src/fallback/conv_bias/im2col/strategy_base.h @@ -343,6 +343,32 @@ public: const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; }; +template < + typename src_ctype, typename bias_ctype, typename dst_ctype, typename op_ctype, + typename op_dtype, megdnn::PostprocessMode postprocess_mode> +class Strategy< + src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, postprocess_mode, + PackMode::DEFAULT, FormatMode::NCHW88> + : public Strategy< + src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, + postprocess_mode, PackMode::DEFAULT> { +public: + constexpr static size_t BUNDLE_PADDING_INDEX = 0; + constexpr static size_t BUNDLE_PACKA_INDEX = 1; + constexpr static size_t THREAD_BUNDLE_PACKB_INDEX = 0; + constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1; + constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2; + + Strategy() = default; + + void exec_im2col( + const WorkspaceBundle& bundle, const WorkspaceBundle& bundle_thread, + const StrategyParam& sparam, + const fallback::ConvBiasImpl::NCBKernParam& param, + fallback::MatrixMulImpl::KernParam matmul_param, + const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; +}; + template < typename src_ctype, typename bias_ctype, typename dst_ctype, typename op_ctype, typename op_dtype, megdnn::PostprocessMode postprocess_mode> diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw88.cpp b/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw88.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9466a631c7e335707a8f4253db87887936bd36c2 --- /dev/null +++ b/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw88.cpp @@ -0,0 +1,98 @@ +#include "src/fallback/conv_bias/im2col/strategy_base.h" +#include "src/fallback/convolution/img2col_helper.h" +#if MEGDNN_X86 +#include "src/x86/conv_bias/postprocess_helper.h" +#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) +#include "src/arm_common/conv_bias/postprocess_helper.h" +#else +#include "src/common/postprocess_helper.h" +#endif + +using namespace megdnn; +#if MEGDNN_X86 +using namespace x86; +#endif + +namespace megdnn { +template < + typename src_ctype, typename bias_ctype, typename dst_ctype, typename op_ctype, + typename op_dtype, megdnn::PostprocessMode postprocess_mode> +void Strategy< + src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, postprocess_mode, + PackMode::DEFAULT, FormatMode::NCHW88>:: + exec_im2col( + const WorkspaceBundle& bundle, const WorkspaceBundle& bundle_thread, + const StrategyParam& sparam, + const fallback::ConvBiasImpl::NCBKernParam& param, + fallback::MatrixMulImpl::KernParam matmul_param, + const fallback::MatrixMulImpl::AlgoBase* matmul_algo) { + size_t sh = param.filter_meta.stride[0]; + size_t sw = param.filter_meta.stride[1]; + size_t ow = param.osz[1]; + size_t ic = param.filter_meta.icpg; + size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; + size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2; + size_t fh = param.filter_meta.spatial[0]; + size_t fw = param.filter_meta.spatial[1]; + bool is_xcoor = !param.filter_meta.should_flip; + constexpr static size_t pack_size = 8; + size_t input_offset = + ic * ih * iw * + (sparam.group_id + param.filter_meta.group * sparam.batch_id) * + sizeof(src_ctype); + src_ctype* src = reinterpret_cast( + reinterpret_cast(bundle.get(BUNDLE_PADDING_INDEX)) + + input_offset); + bool is_phpwzero = + (param.filter_meta.padding[0] == 0 && param.filter_meta.padding[1] == 0); + if (is_phpwzero) { + src = const_cast( + param.src(sparam.batch_id, sparam.group_id)); + } + src_ctype* im2col_dst = + reinterpret_cast(bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX)); + if (sh == 1 && sw == 1) { + if (is_xcoor) { + img2col_nchw8( + src, im2col_dst, ow, ic, ih, iw, fh, fw, sparam.ohw_cur_index, + sparam.output_block_size); + } else { + img2col_nchw8( + src, im2col_dst, ow, ic, ih, iw, fh, fw, sparam.ohw_cur_index, + sparam.output_block_size); + } + } else { + if (is_xcoor) { + img2col_stride_nchw8( + src, im2col_dst, ow, ic, ih, iw, fh, fw, sh, sw, + sparam.ohw_cur_index, sparam.output_block_size); + } else { + img2col_stride_nchw8( + src, im2col_dst, ow, ic, ih, iw, fh, fw, sh, sw, + sparam.ohw_cur_index, sparam.output_block_size); + } + } + matmul_param.M = sparam.output_block_oc_size; + matmul_param.N = sparam.output_block_size; + matmul_param.LDB = pack_size * sparam.output_block_size; + matmul_param.LDC = pack_size * sparam.output_block_size; + matmul_param.B_ptr = im2col_dst; + + src_ctype* b_panel = + reinterpret_cast(bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX)); + matmul_algo->pack_B(matmul_param, b_panel, 0, matmul_param.N); +} + +#define INSTANTIAL_CLASS( \ + _src_ctype, _bias_ctype, _dst_ctype, _op_ctype, _op_dtype, _postprocess_mode) \ + template class Strategy< \ + _src_ctype, _bias_ctype, _dst_ctype, _op_ctype, _op_dtype, \ + _postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW88>; + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +INSTANTIAL_CLASS( + dt_float16, dt_float16, dt_float16, __fp16, __fp16, + megdnn::PostprocessMode::FLOAT); +#endif +#undef INSTANTIAL_CLASS +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/fallback/convolution/img2col_helper.h b/dnn/src/fallback/convolution/img2col_helper.h index dfce3b2ae42e7eedb92d312eb20b52e0f6fea6d9..0721052ef58bb3e01e7cd33746bde5d83a22e476 100644 --- a/dnn/src/fallback/convolution/img2col_helper.h +++ b/dnn/src/fallback/convolution/img2col_helper.h @@ -347,6 +347,441 @@ void img2col_nchw4( } } +template +void img2col_nchw8( + const dtype* __restrict src, dtype* __restrict dst, const int OW, const int IC, + const int IH, const int IW, const int FH, const int FW, const int cur_index, + const int block_size) { + int start_h = cur_index / OW; + int cur_n_remain = cur_index % OW; + int end_h = (cur_index + block_size) / OW; + int end_n_remain = (cur_index + block_size) % OW; + bool same_line = (start_h == end_h); + + int IC_div_8 = IC / 8; + + if (sizeof(dtype) == 2) { + if (same_line) { + int dst_idx = 0; + rep(ic, IC_div_8) { + rep(fh, FH) { + rep(fw, FW) { + int fh2 = fh, fw2 = fw; + if (!is_xcorr) { + fh2 = FH - fh - 1; + fw2 = FW - fw - 1; + } +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + //! TODO: Substitute GI for arm intrinsic when GI supports FP16 + //! data type. + int src_idx = 8 * (ic * IH * IW + (start_h + fh2) * IW + + cur_n_remain + fw2); + for (int w = cur_n_remain; w < end_n_remain; ++w) { + vst1q_f16( + reinterpret_cast<__fp16*>(dst) + dst_idx, + vld1q_f16( + reinterpret_cast(src) + + src_idx)); + dst_idx += 8; + src_idx += 8; + } +#else + int src_idx = 2 * (ic * IH * IW + (start_h + fh2) * IW + + cur_n_remain + fw2); + uint64_t* u64_src = reinterpret_cast(src); + uint64_t* u64_dst = reinterpret_cast(dst); + for (int w = cur_n_remain; w < end_n_remain; w++) { + u64_dst[dst_idx] = u64_src[src_idx]; + u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; + dst_idx += 2; + src_idx += 2; + } +#endif + } + } + } + } else { + int dst_idx = 0; + rep(ic, IC_div_8) { + rep(fh, FH) { + rep(fw, FW) { + int fh2 = fh, fw2 = fw; + if (!is_xcorr) { + fh2 = FH - fh - 1; + fw2 = FW - fw - 1; + } +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + int src_idx = 8 * (ic * IH * IW + (fh2 + start_h) * IW + fw2 + + cur_n_remain); + for (int w = cur_n_remain; w < OW; ++w) { + vst1q_f16( + reinterpret_cast<__fp16*>(dst) + dst_idx, + vld1q_f16( + reinterpret_cast(src) + + src_idx)); + dst_idx += 8; + src_idx += 8; + } + src_idx = 8 * (ic * IH * IW + (fh2 + start_h + 1) * IW + fw2); + for (int h = start_h + 1; h < end_h; ++h) { + int _src_idx = src_idx; + rep(w, OW) { + vst1q_f16( + reinterpret_cast<__fp16*>(dst) + dst_idx, + vld1q_f16( + reinterpret_cast(src) + + _src_idx)); + dst_idx += 8; + _src_idx += 8; + } + src_idx += IW * 8; + } + src_idx = 8 * (ic * IH * IW + (fh2 + end_h) * IW + fw2); + rep(w, end_n_remain) { + vst1q_f16( + reinterpret_cast<__fp16*>(dst) + dst_idx, + vld1q_f16( + reinterpret_cast(src) + + src_idx)); + dst_idx += 8; + src_idx += 8; + } +#else + uint64_t* u64_src = reinterpret_cast(src); + uint64_t* u64_dst = reinterpret_cast(dst); + int src_idx = 2 * (ic * IH * IW + (fh2 + start_h) * IW + fw2 + + cur_n_remain); + for (int w = cur_n_remain; w < OW; ++w) { + u64_dst[dst_idx] = u64_src[src_idx]; + u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; + dst_idx += 2; + src_idx += 2; + } + src_idx = 2 * (ic * IH * IW + (fh2 + start_h + 1) * IW + fw2); + for (int h = start_h + 1; h < end_h; ++h) { + int _src_idx = src_idx; + rep(w, OW) { + u64_dst[dst_idx] = u64_src[_src_idx]; + u64_dst[dst_idx + 1] = u64_src[_src_idx + 1]; + dst_idx += 2; + _src_idx += 2; + } + src_idx += IW * 2; + } + src_idx = 2 * (ic * IH * IW + (fh2 + end_h) * IW + fw2); + rep(w, end_n_remain) { + u64_dst[dst_idx] = u64_src[src_idx]; + u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; + dst_idx += 2; + src_idx += 2; + } +#endif + } + } + } + } + } else { + if (same_line) { + int dst_idx = 0; + rep(ic, IC_div_8) { + rep(fh, FH) { + rep(fw, FW) { + int fh2 = fh, fw2 = fw; + if (!is_xcorr) { + fh2 = FH - fh - 1; + fw2 = FW - fw - 1; + } + int src_idx = 8 * (ic * IH * IW + (start_h + fh2) * IW + fw2 + + cur_n_remain); + for (int w = cur_n_remain; w < end_n_remain; ++w) { + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + } + } + } + } + } else { + int dst_idx = 0; + rep(ic, IC_div_8) { + rep(fh, FH) { + rep(fw, FW) { + int fh2 = fh, fw2 = fw; + if (!is_xcorr) { + fh2 = FH - fh - 1; + fw2 = FW - fw - 1; + } + + int src_idx = 8 * (ic * IH * IW + (start_h + fh2) * IW + fw2 + + cur_n_remain); + for (int w = cur_n_remain; w < OW; ++w) { + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + } + + src_idx = 8 * (ic * IH * IW + (start_h + 1 + fh2) * IW + fw2); + for (int h = start_h + 1; h < end_h; ++h) { + rep(w, OW) { + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + } + } + + src_idx = 8 * (ic * IH * IW + (end_h + fh2) * IW + fw2); + rep(w, end_n_remain) { + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + dst[dst_idx++] = src[src_idx++]; + } + } + } + } + } + } +} + +template +void img2col_stride_nchw8( + const dtype* __restrict src, dtype* __restrict dst, const int OW, const int IC, + const int IH, const int IW, const int FH, const int FW, const int SH, + const int SW, const int cur_index, const int block_size) { + int start_h = cur_index / OW; + int cur_n_remain = cur_index % OW; + int end_h = (cur_index + block_size) / OW; + int end_n_remain = (cur_index + block_size) % OW; + bool same_line = (start_h == end_h); + + int IC_div_8 = IC / 8; + + if (sizeof(dtype) == 2) { + if (same_line) { + int dst_idx = 0; + rep(ic, IC_div_8) { + rep(fh, FH) { + rep(fw, FW) { + int fh2 = fh, fw2 = fw; + if (!is_xcorr) { + fh2 = FH - fh - 1; + fw2 = FW - fw - 1; + } +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + int src_idx = 8 * (ic * IH * IW + (start_h * SH + fh2) * IW + + cur_n_remain * SW + fw2); + for (int w = cur_n_remain; w < end_n_remain; ++w) { + vst1q_f16( + reinterpret_cast<__fp16*>(dst) + dst_idx, + vld1q_f16( + reinterpret_cast(src) + + src_idx)); + dst_idx += 8; + src_idx += 8 * SW; + } +#else + int src_idx = 2 * (ic * IH * IW + (start_h * SH + fh2) * IW + + cur_n_remain * SW + fw2); + uint64_t* u64_src = reinterpret_cast(src); + uint64_t* u64_dst = reinterpret_cast(dst); + for (int w = cur_n_remain; w < end_n_remain; w++) { + u64_dst[dst_idx] = u64_src[src_idx]; + u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; + dst_idx += 2; + src_idx += 2 * SW; + } +#endif + } + } + } + } else { + int dst_idx = 0; + rep(ic, IC_div_8) { + rep(fh, FH) { + rep(fw, FW) { + int fh2 = fh, fw2 = fw; + if (!is_xcorr) { + fh2 = FH - fh - 1; + fw2 = FW - fw - 1; + } +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + int src_idx = 8 * (ic * IH * IW + (fh2 + start_h * SH) * IW + + fw2 + cur_n_remain * SW); + for (int w = cur_n_remain; w < OW; ++w) { + vst1q_f16( + reinterpret_cast<__fp16*>(dst) + dst_idx, + vld1q_f16( + reinterpret_cast(src) + + src_idx)); + dst_idx += 8; + src_idx += 8 * SW; + } + src_idx = 8 * (ic * IH * IW + (fh2 + (start_h + 1) * SH) * IW + + fw2); + for (int h = start_h + 1; h < end_h; ++h) { + int _src_idx = src_idx; + rep(w, OW) { + vst1q_f16( + reinterpret_cast<__fp16*>(dst) + dst_idx, + vld1q_f16( + reinterpret_cast(src) + + _src_idx)); + dst_idx += 8; + _src_idx += 8 * SW; + } + src_idx += IW * 8 * SH; + } + src_idx = 8 * (ic * IH * IW + (fh2 + end_h * SH) * IW + fw2); + rep(w, end_n_remain) { + vst1q_f16( + reinterpret_cast<__fp16*>(dst) + dst_idx, + vld1q_f16( + reinterpret_cast(src) + + src_idx)); + dst_idx += 8; + src_idx += 8 * SW; + } +#else + uint64_t* u64_src = reinterpret_cast(src); + uint64_t* u64_dst = reinterpret_cast(dst); + int src_idx = 2 * (ic * IH * IW + (fh2 + start_h * SH) * IW + + fw2 + cur_n_remain * SW); + for (int w = cur_n_remain; w < OW; ++w) { + u64_dst[dst_idx] = u64_src[src_idx]; + u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; + dst_idx += 2; + src_idx += 2 * SW; + } + src_idx = 2 * (ic * IH * IW + (fh2 + (start_h + 1) * SH) * IW + + fw2); + for (int h = start_h + 1; h < end_h; ++h) { + int _src_idx = src_idx; + rep(w, OW) { + u64_dst[dst_idx] = u64_src[_src_idx]; + u64_dst[dst_idx + 1] = u64_src[_src_idx + 1]; + dst_idx += 2; + _src_idx += 2 * SW; + } + src_idx += IW * 2 * SH; + } + src_idx = 2 * (ic * IH * IW + (fh2 + end_h * SH) * IW + fw2); + rep(w, end_n_remain) { + u64_dst[dst_idx] = u64_src[src_idx]; + u64_dst[dst_idx + 1] = u64_src[src_idx + 1]; + dst_idx += 2; + src_idx += 2 * SW; + } +#endif + } + } + } + } + } else { + if (same_line) { + int dst_idx = 0; + rep(ic, IC_div_8) { + rep(fh, FH) { + rep(fw, FW) { + int fh2 = fh, fw2 = fw; + if (!is_xcorr) { + fh2 = FH - fh - 1; + fw2 = FW - fw - 1; + } + int src_idx = 8 * (ic * IH * IW + (start_h * SH + fh2) * IW + + fw2 + cur_n_remain * SW); + for (int w = cur_n_remain; w < end_n_remain; ++w) { + dst[dst_idx++] = src[src_idx]; + dst[dst_idx++] = src[src_idx + 1]; + dst[dst_idx++] = src[src_idx + 2]; + dst[dst_idx++] = src[src_idx + 3]; + dst[dst_idx++] = src[src_idx + 4]; + dst[dst_idx++] = src[src_idx + 5]; + dst[dst_idx++] = src[src_idx + 6]; + dst[dst_idx++] = src[src_idx + 7]; + src_idx += 8 * SW; + } + } + } + } + } else { + int dst_idx = 0; + rep(ic, IC_div_8) { + rep(fh, FH) { + rep(fw, FW) { + int fh2 = fh, fw2 = fw; + if (!is_xcorr) { + fh2 = FH - fh - 1; + fw2 = FW - fw - 1; + } + + int src_idx = 8 * (ic * IH * IW + (start_h * SH + fh2) * IW + + fw2 + cur_n_remain * SW); + for (int w = cur_n_remain; w < OW; ++w) { + dst[dst_idx++] = src[src_idx]; + dst[dst_idx++] = src[src_idx + 1]; + dst[dst_idx++] = src[src_idx + 2]; + dst[dst_idx++] = src[src_idx + 3]; + dst[dst_idx++] = src[src_idx + 4]; + dst[dst_idx++] = src[src_idx + 5]; + dst[dst_idx++] = src[src_idx + 6]; + dst[dst_idx++] = src[src_idx + 7]; + src_idx += 8 * SW; + } + + src_idx = 8 * (ic * IH * IW + ((start_h + 1) * SH + fh2) * IW + + fw2); + for (int h = start_h + 1; h < end_h; ++h) { + rep(w, OW) { + dst[dst_idx++] = src[src_idx]; + dst[dst_idx++] = src[src_idx + 1]; + dst[dst_idx++] = src[src_idx + 2]; + dst[dst_idx++] = src[src_idx + 3]; + dst[dst_idx++] = src[src_idx + 4]; + dst[dst_idx++] = src[src_idx + 5]; + dst[dst_idx++] = src[src_idx + 6]; + dst[dst_idx++] = src[src_idx + 7]; + src_idx += 8 * SW; + } + } + + src_idx = 8 * (ic * IH * IW + (end_h * SH + fh2) * IW + fw2); + rep(w, end_n_remain) { + dst[dst_idx++] = src[src_idx]; + dst[dst_idx++] = src[src_idx + 1]; + dst[dst_idx++] = src[src_idx + 2]; + dst[dst_idx++] = src[src_idx + 3]; + dst[dst_idx++] = src[src_idx + 4]; + dst[dst_idx++] = src[src_idx + 5]; + dst[dst_idx++] = src[src_idx + 6]; + dst[dst_idx++] = src[src_idx + 7]; + src_idx += 8 * SW; + } + } + } + } + } + } +} + template void img2col_stride( const dtype* __restrict src, dtype* __restrict dst, const int OC, const int OH, diff --git a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp index c3409918708a1464206ee83b546530c82ba37374..48c03314466096ddd09920212664282c7b33cb3f 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp @@ -68,6 +68,87 @@ void benchmark_impl( multi_thread_config.nr_thread); } } + +void benchmark_with_contrast( + const std::vector& args, const std::string algo_name, + std::vector& data_type, + const std::vector& args_contrast, + const std::string algo_name_contrast, std::vector& data_type_contrast, + size_t RUNS, TaskExecutorConfig&& single_thread_config) { + auto single_thread_handle = create_cpu_handle(0, true, &single_thread_config); + + auto benchmarker = Benchmarker(single_thread_handle.get()); + auto benchmarker_contrast = Benchmarker(single_thread_handle.get()); + + benchmarker.set_times(RUNS) + .set_display(false) + .set_dtype(0, data_type[0]) + .set_dtype(1, data_type[1]) + .set_dtype(2, data_type[2]) + .set_dtype(4, data_type[3]) + .set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker(algo_name.c_str())); + benchmarker_contrast.set_times(RUNS) + .set_display(false) + .set_dtype(0, data_type_contrast[0]) + .set_dtype(1, data_type_contrast[1]) + .set_dtype(2, data_type_contrast[2]) + .set_dtype(4, data_type_contrast[3]) + .set_before_exec_callback(conv_bias::ConvBiasAlgoChecker( + algo_name_contrast.c_str())); + + size_t arg_size = args.size(), arg_contrast_size = args_contrast.size(); + megdnn_assert(arg_size == arg_contrast_size); + rep(i, arg_size) { + TensorLayout dst_layout, dst_layout_contrast; + auto opr = single_thread_handle.get()->create_operator(); + + auto&& arg = args[i]; + opr->param() = arg.param; + opr->deduce_layout( + {arg.src, data_type[0]}, {arg.filter, data_type[1]}, + {arg.bias, data_type[2]}, {}, dst_layout); + float computation = (dst_layout.total_nr_elems() * arg.filter[1] * + arg.filter[2] * arg.filter[3] * arg.filter[4] * 2.0) / + (1024 * 1024 * 1024) * 1e3; + benchmarker.set_param(arg.param); + auto used = benchmarker.exec({arg.src, arg.filter, arg.bias, {}, {}}) / RUNS; + + auto&& arg_contrast = args_contrast[i]; + opr->param() = arg_contrast.param; + opr->deduce_layout( + {arg_contrast.src, data_type_contrast[0]}, + {arg_contrast.filter, data_type_contrast[1]}, + {arg_contrast.bias, data_type_contrast[2]}, {}, dst_layout_contrast); + float computation_contrast = + (dst_layout_contrast.total_nr_elems() * arg_contrast.filter[1] * + arg_contrast.filter[2] * arg_contrast.filter[3] * + arg_contrast.filter[4] * 2.0) / + (1024 * 1024 * 1024) * 1e3; + benchmarker_contrast.set_param(arg_contrast.param); + auto used_contrast = benchmarker_contrast.exec( + {arg_contrast.src, + arg_contrast.filter, + arg_contrast.bias, + {}, + {}}) / + RUNS; + + printf("Bench case: \n"); + printf("padding: %u, stride: %u, nonline mode: %u\n", arg.param.pad_h, + arg.param.stride_h, arg.param.nonlineMode); + printf("%s %s %s\n", arg.src.to_string().c_str(), + arg.filter.to_string().c_str(), arg.bias.to_string().c_str()); + printf("%s %s %s\n", arg_contrast.src.to_string().c_str(), + arg_contrast.filter.to_string().c_str(), + arg_contrast.bias.to_string().c_str()); + + printf("%s: %f gflops;\n%s: %f gflops\n" + "spead up = %f\n", + algo_name.c_str(), computation / used, algo_name_contrast.c_str(), + computation_contrast / used_contrast, used_contrast / used); + } +} } // namespace #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -1591,6 +1672,91 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_FP32) { data_type); shapes_and_computation.clear(); } + +TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_NCHW44_VS_NCHW88) { + constexpr size_t RUNS = 50; + using NLMode = param::ConvBias::NonlineMode; + + std::vector args_nchw88, args_nchw44; + auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS, + size_t group) { + param::ConvBias param_nchw88, param_nchw44; + param_nchw88.format = param::ConvBias::Format::NCHW88; + param_nchw44.format = param::ConvBias::Format::NCHW44; + for (size_t pad : {1, 2, 4}) { + for (size_t stride : {1, 2, 3}) { + for (auto nlmode : + {NLMode::RELU, NLMode::IDENTITY, NLMode::SIGMOID, + NLMode::H_SWISH}) { + param_nchw88.nonlineMode = nlmode; + param_nchw88.pad_h = pad; + param_nchw88.pad_w = pad; + param_nchw88.stride_h = stride; + param_nchw88.stride_w = stride; + + param_nchw44.nonlineMode = nlmode; + param_nchw44.pad_h = pad; + param_nchw44.pad_w = pad; + param_nchw44.stride_h = stride; + param_nchw44.stride_w = stride; + + args_nchw88.emplace_back( + param_nchw88, TensorShape{N, IC / 8, H, W, 8}, + TensorShape{OC / 8, IC / group / 8, FS, FS, 8, 8}, + TensorShape{1, OC / 8, 1, 1, 8}); + args_nchw44.emplace_back( + param_nchw44, TensorShape{N, IC / 4, H, W, 4}, + TensorShape{OC / 4, IC / group / 4, FS, FS, 4, 4}, + TensorShape{1, OC / 4, 1, 1, 4}); + } + } + } + }; + std::vector data_type_fp16 = { + dtype::Float16(), dtype::Float16(), dtype::Float16(), dtype::Float16()}; + std::vector data_type_fp32 = { + dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()}; + bench_case(1, 32, 32, 300, 300, 3, 1); + bench_case(1, 32, 32, 400, 400, 3, 1); + bench_case(1, 32, 32, 100, 100, 3, 1); + bench_case(1, 32, 32, 80, 80, 3, 1); + bench_case(1, 32, 64, 200, 200, 3, 1); + bench_case(1, 32, 64, 128, 128, 3, 1); + bench_case(1, 32, 64, 100, 100, 3, 1); + bench_case(1, 32, 64, 80, 80, 3, 1); + bench_case(1, 32, 128, 200, 200, 3, 1); + bench_case(1, 32, 128, 128, 128, 3, 1); + bench_case(1, 32, 128, 100, 100, 3, 1); + bench_case(1, 32, 128, 80, 80, 3, 1); + + bench_case(1, 64, 32, 7, 7, 3, 1); + bench_case(1, 64, 64, 7, 7, 3, 1); + bench_case(1, 64, 128, 7, 7, 3, 1); + bench_case(1, 64, 256, 7, 7, 3, 1); + bench_case(1, 64, 512, 7, 7, 3, 1); + bench_case(1, 64, 1024, 7, 7, 3, 1); + + bench_case(1, 64, 32, 14, 14, 3, 1); + bench_case(1, 64, 64, 14, 14, 3, 1); + bench_case(1, 64, 128, 14, 14, 3, 1); + bench_case(1, 64, 256, 14, 14, 3, 1); + bench_case(1, 64, 512, 14, 14, 3, 1); + + bench_case(1, 64, 1024, 14, 14, 3, 1); + bench_case(1, 128, 128, 14, 14, 3, 1); + bench_case(1, 128, 256, 14, 14, 3, 1); + bench_case(1, 512, 512, 14, 14, 3, 1); + bench_case(1, 256, 512, 14, 14, 3, 1); + bench_case(1, 512, 1024, 14, 14, 3, 1); + bench_case(1, 1024, 1024, 14, 14, 3, 1); + std::string algo_name_nchw88 = "IM2COLMATMUL:AARCH64_F16_MK8_16X12X1:96"; + std::string algo_name_nchw44 = "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1:96"; + + benchmark_with_contrast( + args_nchw88, algo_name_nchw88, data_type_fp16, args_nchw44, + algo_name_nchw44, data_type_fp32, RUNS, {1, {4}}); +} + TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CHANNEL_WISE_INT8_INT8_INT8_STRIDE1) { constexpr size_t RUNS = 50; diff --git a/dnn/test/arm_common/conv_bias_multi_thread_im2col.cpp b/dnn/test/arm_common/conv_bias_multi_thread_im2col.cpp index 57562db2b6a8bd0a2379d2edaee6590d051043ae..abd7310727784b0a5b969da2a5f8a372387e1a84 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread_im2col.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread_im2col.cpp @@ -362,6 +362,30 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP16) { #endif #undef cb } + +TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_MK8_FP16) { + using namespace conv_bias; + + std::vector args = get_nchw88_conv_bias_args( + {2, 3, 4, 5, 6, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1); + auto args1 = get_nchw88_conv_bias_args( + {2, 3, 4, 5, 6, 7}, QUAN_NLMODE, BR_AND_BIAS_BIASMODE, 2, 3); + args.insert(args.begin(), args1.begin(), args1.begin()); + args1 = get_nchw88_conv_bias_args( + {2, 3, 4, 5, 6, 7, 9}, QUAN_NLMODE, BR_AND_BIAS_BIASMODE, 3, 4); + args.insert(args.begin(), args1.begin(), args1.begin()); + + NormalRNG rng(1); +#define cb(name) \ + checker_conv_bias_common( \ + args, handle(), &rng, 0.03, dtype::Float16{}, dtype::Float16{}, \ + dtype::Float16{}, dtype::Float16{}, name); + +#if MEGDNN_AARCH64 + cb("IM2COLMATMUL:AARCH64_F16_MK8_16X12X1"); +#endif +#undef cb +} #endif #if MEGDNN_AARCH64 || MEGDNN_ARMV7 diff --git a/dnn/test/arm_common/elemwise.cpp b/dnn/test/arm_common/elemwise.cpp index f192f17808f3704b969ec7e426e84aa7ebaeee26..6563fd2331f1eb637ba0baf923aee2affb1fa0c4 100644 --- a/dnn/test/arm_common/elemwise.cpp +++ b/dnn/test/arm_common/elemwise.cpp @@ -161,6 +161,24 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_INT8_INT16_INT32) { run(); } +TEST_F(ARM_COMMON, ELEMWISE_SIGMOID) { + using Mode = ElemwiseForward::Param::Mode; + Checker checker(handle()); + + checker.set_epsilon(1e-3); + checker.set_dtype(0, dtype::Float16()); + checker.set_param(Mode::SIGMOID); + for (size_t n : {1, 2, 3}) { + for (size_t ic : {8, 16, 24, 32}) { + for (size_t ih : {5, 10, 15, 20, 21, 37}) { + for (size_t iw : {7, 9, 11, 13, 14, 20, 35}) { + checker.exec({{n, ic, ih, iw}, {}}); + } + } + } + } +} + TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_FP32) { using Mode = ElemwiseForward::Param::Mode; Checker checker(handle()); diff --git a/dnn/test/arm_common/elemwise_benchmark.cpp b/dnn/test/arm_common/elemwise_benchmark.cpp index a8f06b7032649ddf767e78e13e39b75ca131f42e..5e021784b66968b1ee393dfa7af6cd6da010433d 100644 --- a/dnn/test/arm_common/elemwise_benchmark.cpp +++ b/dnn/test/arm_common/elemwise_benchmark.cpp @@ -98,6 +98,9 @@ TEST_F(ARM_COMMON, BENCHMARK_ELEMWISE_UNARY) { BENCHMARK_CASES_INT(shape, dtype::Int16()); BENCHMARK_CASES_INT(shape, dtype::Int8()); BENCHMARK_CASES_FLOAT(shape, dtype::Float32()); +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + BENCHMARK_CASES_FLOAT(shape, dtype::Float16()); +#endif #undef BENCHMARK_CASES_INT #undef BENCHMARK_CASES_FLOAT #undef RUN diff --git a/dnn/test/common/conv_bias.cpp b/dnn/test/common/conv_bias.cpp index 0130db006de48099689ab6826a90e39877ad4166..a09ebd44224c3a00cb3a40361441c94503472db7 100644 --- a/dnn/test/common/conv_bias.cpp +++ b/dnn/test/common/conv_bias.cpp @@ -1580,17 +1580,19 @@ std::vector get_nchw44_conv_bias_args( std::vector get_nchw88_conv_bias_args( std::vector kernel_vec, std::vector nlmode_vec, - std::vector biasmode_vec, size_t stride) { + std::vector biasmode_vec, size_t stride, int pad) { using namespace conv_bias; using NLMode = param::ConvBias::NonlineMode; std::vector args; auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, size_t kernel, - size_t stride, size_t group, NLMode nlmode, + size_t stride, int pad, size_t group, NLMode nlmode, megdnn::BiasMode bias_mode) { constexpr int pack_c = 8; - const size_t pad = kernel / 2; + if (pad == -1) { + pad = kernel / 2; + } auto oc_per_group = oc / group; auto ic_per_group = ic / group; @@ -1651,8 +1653,8 @@ std::vector get_nchw88_conv_bias_args( if (kernel < h || kernel < w) { continue; } - pack(n, oc, ic, h, w, kernel, stride, group, - nlmode, bias); + pack(n, oc, ic, h, w, kernel, stride, pad, + group, nlmode, bias); } } return args;