提交 235d81dd 编写于 作者: M Megvii Engine Team

feat(dnn): add fp16 nchw88 im2col algo

GitOrigin-RevId: a6d6cb4fc7ddcaaa84763731943bde49580b2bfc
上级 f7d2017e
......@@ -407,6 +407,11 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_16x12x1::get_kern(
return kern_mk8_16x12x1;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(
AlgoF16MK8_16x12x1, megdnn_aarch64_matmul_kern, "AlogF16MK8_16x12x1Impl"_hash,
aarch64::matmul::hgemm_mk8_16x12, dt_float16, dt_float16, AlgoDataType::FLOAT16,
MK8);
#endif
#if MGB_ENABLE_DOT
......
......@@ -93,7 +93,7 @@ public:
bool usable(const KernSizeParam&) const override;
size_t get_workspace(const KernSizeParam&) const override;
kern_t get_kern(const KernSizeParam&) const override;
MEGDNN_OVERRIDE_MATMUL_DESC(16, 12, 1, 2, AlgoDataType::FLOAT16, MK8);
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
MEGDNN_DECL_ALGO_TYPE(AARCH64_F16_MK8_16X12X1);
};
......
......@@ -9,8 +9,8 @@
template <>
void matmul_mk8_16x12::kern<M_BLOCK, N_BLOCK>(
const dt_float16* packedA, const dt_float16* packedB, int K,
dt_float16* out, int LDC, bool is_first_k) {
const dt_float16* packedA, const dt_float16* packedB, int K, dt_float16* out,
int LDC, bool is_first_k) {
#define IF_M_GT(M, INSTRUC) ".if " STR(M_BLOCK) " > " #M "\n" INSTRUC ".endif\n"
#define IF_N_GT(N, INSTRUC) ".if " STR(N_BLOCK) " > " #N "\n" INSTRUC ".endif\n"
// clang-format off
......
......@@ -26,6 +26,8 @@ static fallback::MatrixMulImpl::KernSizeParam get_matmul_kern_param(
format = param::MatrixMul::Format::MK4;
} else if (param.filter_meta.format == param::ConvBias::Format::NCHW44_DOT) {
format = param::MatrixMul::Format::MK4_DOT;
} else if (param.filter_meta.format == param::ConvBias::Format::NCHW88) {
format = param::MatrixMul::Format::MK8;
}
size_t M = oc_tile_size;
size_t N = ohw_tile_size;
......@@ -329,9 +331,15 @@ bool ConvBiasImpl::AlgoIm2col::usable(
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
if (format != param::ConvBias::Format::NCHW &&
format != param::ConvBias::Format::NCHW44 &&
format != param::ConvBias::Format::NCHW44_DOT) {
format != param::ConvBias::Format::NCHW44_DOT &&
format != param::ConvBias::Format::NCHW88) {
return false;
}
if (format == param::ConvBias::Format::NCHW88) {
if (matmul_desc.packmode != Pack_Mode::DEFAULT) {
return false;
}
}
if (format == param::ConvBias::Format::NCHW44 ||
format == param::ConvBias::Format::NCHW44_DOT) {
//! current NCHW44 im2col only support DEFAULT mode matmul
......
......@@ -248,8 +248,18 @@ public:
break;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case StrategyType::FLOAT_FP16:
cb1(NCHW, DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT,
"DefaultStrategyType::FLOAT_FP16"_hash);
if (format == param::ConvBias::Format::NCHW) {
cb1(NCHW, DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT,
"DefaultStrategyType::FLOAT_FP16"_hash);
} else if (format == param::ConvBias::Format::NCHW88) {
cb1(NCHW88, DEFAULT, dt_float16, __fp16, PostprocessMode::FLOAT,
"DefaultStrategyTypeNCHW88::FLOAT_FP16"_hash);
} else {
megdnn_throw(ssprintf(
"Current only support layout NCHW/NCHW88 for im2col algo "
"of float 16, but got %d\n",
uint32_t(format)));
}
break;
#endif
#if !MEGDNN_DISABLE_FLOAT16
......
......@@ -343,6 +343,32 @@ public:
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
};
template <
typename src_ctype, typename bias_ctype, typename dst_ctype, typename op_ctype,
typename op_dtype, megdnn::PostprocessMode postprocess_mode>
class Strategy<
src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, postprocess_mode,
PackMode::DEFAULT, FormatMode::NCHW88>
: public Strategy<
src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::DEFAULT> {
public:
constexpr static size_t BUNDLE_PADDING_INDEX = 0;
constexpr static size_t BUNDLE_PACKA_INDEX = 1;
constexpr static size_t THREAD_BUNDLE_PACKB_INDEX = 0;
constexpr static size_t THREAD_BUNDLE_IM2COL_INDEX = 1;
constexpr static size_t THREAD_BUNDLE_BIAS_INDEX = 2;
Strategy() = default;
void exec_im2col(
const WorkspaceBundle& bundle, const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
};
template <
typename src_ctype, typename bias_ctype, typename dst_ctype, typename op_ctype,
typename op_dtype, megdnn::PostprocessMode postprocess_mode>
......
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/convolution/img2col_helper.h"
#if MEGDNN_X86
#include "src/x86/conv_bias/postprocess_helper.h"
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
#include "src/arm_common/conv_bias/postprocess_helper.h"
#else
#include "src/common/postprocess_helper.h"
#endif
using namespace megdnn;
#if MEGDNN_X86
using namespace x86;
#endif
namespace megdnn {
template <
typename src_ctype, typename bias_ctype, typename dst_ctype, typename op_ctype,
typename op_dtype, megdnn::PostprocessMode postprocess_mode>
void Strategy<
src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, postprocess_mode,
PackMode::DEFAULT, FormatMode::NCHW88>::
exec_im2col(
const WorkspaceBundle& bundle, const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernParam matmul_param,
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) {
size_t sh = param.filter_meta.stride[0];
size_t sw = param.filter_meta.stride[1];
size_t ow = param.osz[1];
size_t ic = param.filter_meta.icpg;
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
size_t iw = param.isz[1] + param.filter_meta.padding[1] * 2;
size_t fh = param.filter_meta.spatial[0];
size_t fw = param.filter_meta.spatial[1];
bool is_xcoor = !param.filter_meta.should_flip;
constexpr static size_t pack_size = 8;
size_t input_offset =
ic * ih * iw *
(sparam.group_id + param.filter_meta.group * sparam.batch_id) *
sizeof(src_ctype);
src_ctype* src = reinterpret_cast<src_ctype*>(
reinterpret_cast<uintptr_t>(bundle.get(BUNDLE_PADDING_INDEX)) +
input_offset);
bool is_phpwzero =
(param.filter_meta.padding[0] == 0 && param.filter_meta.padding[1] == 0);
if (is_phpwzero) {
src = const_cast<src_ctype*>(
param.src<src_ctype>(sparam.batch_id, sparam.group_id));
}
src_ctype* im2col_dst =
reinterpret_cast<src_ctype*>(bundle_thread.get(THREAD_BUNDLE_IM2COL_INDEX));
if (sh == 1 && sw == 1) {
if (is_xcoor) {
img2col_nchw8<true>(
src, im2col_dst, ow, ic, ih, iw, fh, fw, sparam.ohw_cur_index,
sparam.output_block_size);
} else {
img2col_nchw8<false>(
src, im2col_dst, ow, ic, ih, iw, fh, fw, sparam.ohw_cur_index,
sparam.output_block_size);
}
} else {
if (is_xcoor) {
img2col_stride_nchw8<true>(
src, im2col_dst, ow, ic, ih, iw, fh, fw, sh, sw,
sparam.ohw_cur_index, sparam.output_block_size);
} else {
img2col_stride_nchw8<false>(
src, im2col_dst, ow, ic, ih, iw, fh, fw, sh, sw,
sparam.ohw_cur_index, sparam.output_block_size);
}
}
matmul_param.M = sparam.output_block_oc_size;
matmul_param.N = sparam.output_block_size;
matmul_param.LDB = pack_size * sparam.output_block_size;
matmul_param.LDC = pack_size * sparam.output_block_size;
matmul_param.B_ptr = im2col_dst;
src_ctype* b_panel =
reinterpret_cast<src_ctype*>(bundle_thread.get(THREAD_BUNDLE_PACKB_INDEX));
matmul_algo->pack_B(matmul_param, b_panel, 0, matmul_param.N);
}
#define INSTANTIAL_CLASS( \
_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, _op_dtype, _postprocess_mode) \
template class Strategy< \
_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, _op_dtype, \
_postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW88>;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
INSTANTIAL_CLASS(
dt_float16, dt_float16, dt_float16, __fp16, __fp16,
megdnn::PostprocessMode::FLOAT);
#endif
#undef INSTANTIAL_CLASS
} // namespace megdnn
\ No newline at end of file
......@@ -347,6 +347,441 @@ void img2col_nchw4(
}
}
template <bool is_xcorr, typename dtype>
void img2col_nchw8(
const dtype* __restrict src, dtype* __restrict dst, const int OW, const int IC,
const int IH, const int IW, const int FH, const int FW, const int cur_index,
const int block_size) {
int start_h = cur_index / OW;
int cur_n_remain = cur_index % OW;
int end_h = (cur_index + block_size) / OW;
int end_n_remain = (cur_index + block_size) % OW;
bool same_line = (start_h == end_h);
int IC_div_8 = IC / 8;
if (sizeof(dtype) == 2) {
if (same_line) {
int dst_idx = 0;
rep(ic, IC_div_8) {
rep(fh, FH) {
rep(fw, FW) {
int fh2 = fh, fw2 = fw;
if (!is_xcorr) {
fh2 = FH - fh - 1;
fw2 = FW - fw - 1;
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
//! TODO: Substitute GI for arm intrinsic when GI supports FP16
//! data type.
int src_idx = 8 * (ic * IH * IW + (start_h + fh2) * IW +
cur_n_remain + fw2);
for (int w = cur_n_remain; w < end_n_remain; ++w) {
vst1q_f16(
reinterpret_cast<__fp16*>(dst) + dst_idx,
vld1q_f16(
reinterpret_cast<const __fp16*>(src) +
src_idx));
dst_idx += 8;
src_idx += 8;
}
#else
int src_idx = 2 * (ic * IH * IW + (start_h + fh2) * IW +
cur_n_remain + fw2);
uint64_t* u64_src = reinterpret_cast<uint64_t*>(src);
uint64_t* u64_dst = reinterpret_cast<uint64_t*>(dst);
for (int w = cur_n_remain; w < end_n_remain; w++) {
u64_dst[dst_idx] = u64_src[src_idx];
u64_dst[dst_idx + 1] = u64_src[src_idx + 1];
dst_idx += 2;
src_idx += 2;
}
#endif
}
}
}
} else {
int dst_idx = 0;
rep(ic, IC_div_8) {
rep(fh, FH) {
rep(fw, FW) {
int fh2 = fh, fw2 = fw;
if (!is_xcorr) {
fh2 = FH - fh - 1;
fw2 = FW - fw - 1;
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
int src_idx = 8 * (ic * IH * IW + (fh2 + start_h) * IW + fw2 +
cur_n_remain);
for (int w = cur_n_remain; w < OW; ++w) {
vst1q_f16(
reinterpret_cast<__fp16*>(dst) + dst_idx,
vld1q_f16(
reinterpret_cast<const __fp16*>(src) +
src_idx));
dst_idx += 8;
src_idx += 8;
}
src_idx = 8 * (ic * IH * IW + (fh2 + start_h + 1) * IW + fw2);
for (int h = start_h + 1; h < end_h; ++h) {
int _src_idx = src_idx;
rep(w, OW) {
vst1q_f16(
reinterpret_cast<__fp16*>(dst) + dst_idx,
vld1q_f16(
reinterpret_cast<const __fp16*>(src) +
_src_idx));
dst_idx += 8;
_src_idx += 8;
}
src_idx += IW * 8;
}
src_idx = 8 * (ic * IH * IW + (fh2 + end_h) * IW + fw2);
rep(w, end_n_remain) {
vst1q_f16(
reinterpret_cast<__fp16*>(dst) + dst_idx,
vld1q_f16(
reinterpret_cast<const __fp16*>(src) +
src_idx));
dst_idx += 8;
src_idx += 8;
}
#else
uint64_t* u64_src = reinterpret_cast<uint64_t*>(src);
uint64_t* u64_dst = reinterpret_cast<uint64_t*>(dst);
int src_idx = 2 * (ic * IH * IW + (fh2 + start_h) * IW + fw2 +
cur_n_remain);
for (int w = cur_n_remain; w < OW; ++w) {
u64_dst[dst_idx] = u64_src[src_idx];
u64_dst[dst_idx + 1] = u64_src[src_idx + 1];
dst_idx += 2;
src_idx += 2;
}
src_idx = 2 * (ic * IH * IW + (fh2 + start_h + 1) * IW + fw2);
for (int h = start_h + 1; h < end_h; ++h) {
int _src_idx = src_idx;
rep(w, OW) {
u64_dst[dst_idx] = u64_src[_src_idx];
u64_dst[dst_idx + 1] = u64_src[_src_idx + 1];
dst_idx += 2;
_src_idx += 2;
}
src_idx += IW * 2;
}
src_idx = 2 * (ic * IH * IW + (fh2 + end_h) * IW + fw2);
rep(w, end_n_remain) {
u64_dst[dst_idx] = u64_src[src_idx];
u64_dst[dst_idx + 1] = u64_src[src_idx + 1];
dst_idx += 2;
src_idx += 2;
}
#endif
}
}
}
}
} else {
if (same_line) {
int dst_idx = 0;
rep(ic, IC_div_8) {
rep(fh, FH) {
rep(fw, FW) {
int fh2 = fh, fw2 = fw;
if (!is_xcorr) {
fh2 = FH - fh - 1;
fw2 = FW - fw - 1;
}
int src_idx = 8 * (ic * IH * IW + (start_h + fh2) * IW + fw2 +
cur_n_remain);
for (int w = cur_n_remain; w < end_n_remain; ++w) {
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
}
}
}
}
} else {
int dst_idx = 0;
rep(ic, IC_div_8) {
rep(fh, FH) {
rep(fw, FW) {
int fh2 = fh, fw2 = fw;
if (!is_xcorr) {
fh2 = FH - fh - 1;
fw2 = FW - fw - 1;
}
int src_idx = 8 * (ic * IH * IW + (start_h + fh2) * IW + fw2 +
cur_n_remain);
for (int w = cur_n_remain; w < OW; ++w) {
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
}
src_idx = 8 * (ic * IH * IW + (start_h + 1 + fh2) * IW + fw2);
for (int h = start_h + 1; h < end_h; ++h) {
rep(w, OW) {
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
}
}
src_idx = 8 * (ic * IH * IW + (end_h + fh2) * IW + fw2);
rep(w, end_n_remain) {
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
dst[dst_idx++] = src[src_idx++];
}
}
}
}
}
}
}
template <bool is_xcorr, typename dtype>
void img2col_stride_nchw8(
const dtype* __restrict src, dtype* __restrict dst, const int OW, const int IC,
const int IH, const int IW, const int FH, const int FW, const int SH,
const int SW, const int cur_index, const int block_size) {
int start_h = cur_index / OW;
int cur_n_remain = cur_index % OW;
int end_h = (cur_index + block_size) / OW;
int end_n_remain = (cur_index + block_size) % OW;
bool same_line = (start_h == end_h);
int IC_div_8 = IC / 8;
if (sizeof(dtype) == 2) {
if (same_line) {
int dst_idx = 0;
rep(ic, IC_div_8) {
rep(fh, FH) {
rep(fw, FW) {
int fh2 = fh, fw2 = fw;
if (!is_xcorr) {
fh2 = FH - fh - 1;
fw2 = FW - fw - 1;
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
int src_idx = 8 * (ic * IH * IW + (start_h * SH + fh2) * IW +
cur_n_remain * SW + fw2);
for (int w = cur_n_remain; w < end_n_remain; ++w) {
vst1q_f16(
reinterpret_cast<__fp16*>(dst) + dst_idx,
vld1q_f16(
reinterpret_cast<const __fp16*>(src) +
src_idx));
dst_idx += 8;
src_idx += 8 * SW;
}
#else
int src_idx = 2 * (ic * IH * IW + (start_h * SH + fh2) * IW +
cur_n_remain * SW + fw2);
uint64_t* u64_src = reinterpret_cast<uint64_t*>(src);
uint64_t* u64_dst = reinterpret_cast<uint64_t*>(dst);
for (int w = cur_n_remain; w < end_n_remain; w++) {
u64_dst[dst_idx] = u64_src[src_idx];
u64_dst[dst_idx + 1] = u64_src[src_idx + 1];
dst_idx += 2;
src_idx += 2 * SW;
}
#endif
}
}
}
} else {
int dst_idx = 0;
rep(ic, IC_div_8) {
rep(fh, FH) {
rep(fw, FW) {
int fh2 = fh, fw2 = fw;
if (!is_xcorr) {
fh2 = FH - fh - 1;
fw2 = FW - fw - 1;
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
int src_idx = 8 * (ic * IH * IW + (fh2 + start_h * SH) * IW +
fw2 + cur_n_remain * SW);
for (int w = cur_n_remain; w < OW; ++w) {
vst1q_f16(
reinterpret_cast<__fp16*>(dst) + dst_idx,
vld1q_f16(
reinterpret_cast<const __fp16*>(src) +
src_idx));
dst_idx += 8;
src_idx += 8 * SW;
}
src_idx = 8 * (ic * IH * IW + (fh2 + (start_h + 1) * SH) * IW +
fw2);
for (int h = start_h + 1; h < end_h; ++h) {
int _src_idx = src_idx;
rep(w, OW) {
vst1q_f16(
reinterpret_cast<__fp16*>(dst) + dst_idx,
vld1q_f16(
reinterpret_cast<const __fp16*>(src) +
_src_idx));
dst_idx += 8;
_src_idx += 8 * SW;
}
src_idx += IW * 8 * SH;
}
src_idx = 8 * (ic * IH * IW + (fh2 + end_h * SH) * IW + fw2);
rep(w, end_n_remain) {
vst1q_f16(
reinterpret_cast<__fp16*>(dst) + dst_idx,
vld1q_f16(
reinterpret_cast<const __fp16*>(src) +
src_idx));
dst_idx += 8;
src_idx += 8 * SW;
}
#else
uint64_t* u64_src = reinterpret_cast<uint64_t*>(src);
uint64_t* u64_dst = reinterpret_cast<uint64_t*>(dst);
int src_idx = 2 * (ic * IH * IW + (fh2 + start_h * SH) * IW +
fw2 + cur_n_remain * SW);
for (int w = cur_n_remain; w < OW; ++w) {
u64_dst[dst_idx] = u64_src[src_idx];
u64_dst[dst_idx + 1] = u64_src[src_idx + 1];
dst_idx += 2;
src_idx += 2 * SW;
}
src_idx = 2 * (ic * IH * IW + (fh2 + (start_h + 1) * SH) * IW +
fw2);
for (int h = start_h + 1; h < end_h; ++h) {
int _src_idx = src_idx;
rep(w, OW) {
u64_dst[dst_idx] = u64_src[_src_idx];
u64_dst[dst_idx + 1] = u64_src[_src_idx + 1];
dst_idx += 2;
_src_idx += 2 * SW;
}
src_idx += IW * 2 * SH;
}
src_idx = 2 * (ic * IH * IW + (fh2 + end_h * SH) * IW + fw2);
rep(w, end_n_remain) {
u64_dst[dst_idx] = u64_src[src_idx];
u64_dst[dst_idx + 1] = u64_src[src_idx + 1];
dst_idx += 2;
src_idx += 2 * SW;
}
#endif
}
}
}
}
} else {
if (same_line) {
int dst_idx = 0;
rep(ic, IC_div_8) {
rep(fh, FH) {
rep(fw, FW) {
int fh2 = fh, fw2 = fw;
if (!is_xcorr) {
fh2 = FH - fh - 1;
fw2 = FW - fw - 1;
}
int src_idx = 8 * (ic * IH * IW + (start_h * SH + fh2) * IW +
fw2 + cur_n_remain * SW);
for (int w = cur_n_remain; w < end_n_remain; ++w) {
dst[dst_idx++] = src[src_idx];
dst[dst_idx++] = src[src_idx + 1];
dst[dst_idx++] = src[src_idx + 2];
dst[dst_idx++] = src[src_idx + 3];
dst[dst_idx++] = src[src_idx + 4];
dst[dst_idx++] = src[src_idx + 5];
dst[dst_idx++] = src[src_idx + 6];
dst[dst_idx++] = src[src_idx + 7];
src_idx += 8 * SW;
}
}
}
}
} else {
int dst_idx = 0;
rep(ic, IC_div_8) {
rep(fh, FH) {
rep(fw, FW) {
int fh2 = fh, fw2 = fw;
if (!is_xcorr) {
fh2 = FH - fh - 1;
fw2 = FW - fw - 1;
}
int src_idx = 8 * (ic * IH * IW + (start_h * SH + fh2) * IW +
fw2 + cur_n_remain * SW);
for (int w = cur_n_remain; w < OW; ++w) {
dst[dst_idx++] = src[src_idx];
dst[dst_idx++] = src[src_idx + 1];
dst[dst_idx++] = src[src_idx + 2];
dst[dst_idx++] = src[src_idx + 3];
dst[dst_idx++] = src[src_idx + 4];
dst[dst_idx++] = src[src_idx + 5];
dst[dst_idx++] = src[src_idx + 6];
dst[dst_idx++] = src[src_idx + 7];
src_idx += 8 * SW;
}
src_idx = 8 * (ic * IH * IW + ((start_h + 1) * SH + fh2) * IW +
fw2);
for (int h = start_h + 1; h < end_h; ++h) {
rep(w, OW) {
dst[dst_idx++] = src[src_idx];
dst[dst_idx++] = src[src_idx + 1];
dst[dst_idx++] = src[src_idx + 2];
dst[dst_idx++] = src[src_idx + 3];
dst[dst_idx++] = src[src_idx + 4];
dst[dst_idx++] = src[src_idx + 5];
dst[dst_idx++] = src[src_idx + 6];
dst[dst_idx++] = src[src_idx + 7];
src_idx += 8 * SW;
}
}
src_idx = 8 * (ic * IH * IW + (end_h * SH + fh2) * IW + fw2);
rep(w, end_n_remain) {
dst[dst_idx++] = src[src_idx];
dst[dst_idx++] = src[src_idx + 1];
dst[dst_idx++] = src[src_idx + 2];
dst[dst_idx++] = src[src_idx + 3];
dst[dst_idx++] = src[src_idx + 4];
dst[dst_idx++] = src[src_idx + 5];
dst[dst_idx++] = src[src_idx + 6];
dst[dst_idx++] = src[src_idx + 7];
src_idx += 8 * SW;
}
}
}
}
}
}
}
template <bool is_xcorr, typename dtype>
void img2col_stride(
const dtype* __restrict src, dtype* __restrict dst, const int OC, const int OH,
......
......@@ -68,6 +68,87 @@ void benchmark_impl(
multi_thread_config.nr_thread);
}
}
void benchmark_with_contrast(
const std::vector<conv_bias::TestArg>& args, const std::string algo_name,
std::vector<DType>& data_type,
const std::vector<conv_bias::TestArg>& args_contrast,
const std::string algo_name_contrast, std::vector<DType>& data_type_contrast,
size_t RUNS, TaskExecutorConfig&& single_thread_config) {
auto single_thread_handle = create_cpu_handle(0, true, &single_thread_config);
auto benchmarker = Benchmarker<ConvBias>(single_thread_handle.get());
auto benchmarker_contrast = Benchmarker<ConvBias>(single_thread_handle.get());
benchmarker.set_times(RUNS)
.set_display(false)
.set_dtype(0, data_type[0])
.set_dtype(1, data_type[1])
.set_dtype(2, data_type[2])
.set_dtype(4, data_type[3])
.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name.c_str()));
benchmarker_contrast.set_times(RUNS)
.set_display(false)
.set_dtype(0, data_type_contrast[0])
.set_dtype(1, data_type_contrast[1])
.set_dtype(2, data_type_contrast[2])
.set_dtype(4, data_type_contrast[3])
.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
algo_name_contrast.c_str()));
size_t arg_size = args.size(), arg_contrast_size = args_contrast.size();
megdnn_assert(arg_size == arg_contrast_size);
rep(i, arg_size) {
TensorLayout dst_layout, dst_layout_contrast;
auto opr = single_thread_handle.get()->create_operator<ConvBias>();
auto&& arg = args[i];
opr->param() = arg.param;
opr->deduce_layout(
{arg.src, data_type[0]}, {arg.filter, data_type[1]},
{arg.bias, data_type[2]}, {}, dst_layout);
float computation = (dst_layout.total_nr_elems() * arg.filter[1] *
arg.filter[2] * arg.filter[3] * arg.filter[4] * 2.0) /
(1024 * 1024 * 1024) * 1e3;
benchmarker.set_param(arg.param);
auto used = benchmarker.exec({arg.src, arg.filter, arg.bias, {}, {}}) / RUNS;
auto&& arg_contrast = args_contrast[i];
opr->param() = arg_contrast.param;
opr->deduce_layout(
{arg_contrast.src, data_type_contrast[0]},
{arg_contrast.filter, data_type_contrast[1]},
{arg_contrast.bias, data_type_contrast[2]}, {}, dst_layout_contrast);
float computation_contrast =
(dst_layout_contrast.total_nr_elems() * arg_contrast.filter[1] *
arg_contrast.filter[2] * arg_contrast.filter[3] *
arg_contrast.filter[4] * 2.0) /
(1024 * 1024 * 1024) * 1e3;
benchmarker_contrast.set_param(arg_contrast.param);
auto used_contrast = benchmarker_contrast.exec(
{arg_contrast.src,
arg_contrast.filter,
arg_contrast.bias,
{},
{}}) /
RUNS;
printf("Bench case: \n");
printf("padding: %u, stride: %u, nonline mode: %u\n", arg.param.pad_h,
arg.param.stride_h, arg.param.nonlineMode);
printf("%s %s %s\n", arg.src.to_string().c_str(),
arg.filter.to_string().c_str(), arg.bias.to_string().c_str());
printf("%s %s %s\n", arg_contrast.src.to_string().c_str(),
arg_contrast.filter.to_string().c_str(),
arg_contrast.bias.to_string().c_str());
printf("%s: %f gflops;\n%s: %f gflops\n"
"spead up = %f\n",
algo_name.c_str(), computation / used, algo_name_contrast.c_str(),
computation_contrast / used_contrast, used_contrast / used);
}
}
} // namespace
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
......@@ -1591,6 +1672,91 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_FP32) {
data_type);
shapes_and_computation.clear();
}
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_IM2COL_NCHW44_VS_NCHW88) {
constexpr size_t RUNS = 50;
using NLMode = param::ConvBias::NonlineMode;
std::vector<conv_bias::TestArg> args_nchw88, args_nchw44;
auto bench_case = [&](size_t N, size_t IC, size_t OC, size_t H, size_t W, size_t FS,
size_t group) {
param::ConvBias param_nchw88, param_nchw44;
param_nchw88.format = param::ConvBias::Format::NCHW88;
param_nchw44.format = param::ConvBias::Format::NCHW44;
for (size_t pad : {1, 2, 4}) {
for (size_t stride : {1, 2, 3}) {
for (auto nlmode :
{NLMode::RELU, NLMode::IDENTITY, NLMode::SIGMOID,
NLMode::H_SWISH}) {
param_nchw88.nonlineMode = nlmode;
param_nchw88.pad_h = pad;
param_nchw88.pad_w = pad;
param_nchw88.stride_h = stride;
param_nchw88.stride_w = stride;
param_nchw44.nonlineMode = nlmode;
param_nchw44.pad_h = pad;
param_nchw44.pad_w = pad;
param_nchw44.stride_h = stride;
param_nchw44.stride_w = stride;
args_nchw88.emplace_back(
param_nchw88, TensorShape{N, IC / 8, H, W, 8},
TensorShape{OC / 8, IC / group / 8, FS, FS, 8, 8},
TensorShape{1, OC / 8, 1, 1, 8});
args_nchw44.emplace_back(
param_nchw44, TensorShape{N, IC / 4, H, W, 4},
TensorShape{OC / 4, IC / group / 4, FS, FS, 4, 4},
TensorShape{1, OC / 4, 1, 1, 4});
}
}
}
};
std::vector<DType> data_type_fp16 = {
dtype::Float16(), dtype::Float16(), dtype::Float16(), dtype::Float16()};
std::vector<DType> data_type_fp32 = {
dtype::Float32(), dtype::Float32(), dtype::Float32(), dtype::Float32()};
bench_case(1, 32, 32, 300, 300, 3, 1);
bench_case(1, 32, 32, 400, 400, 3, 1);
bench_case(1, 32, 32, 100, 100, 3, 1);
bench_case(1, 32, 32, 80, 80, 3, 1);
bench_case(1, 32, 64, 200, 200, 3, 1);
bench_case(1, 32, 64, 128, 128, 3, 1);
bench_case(1, 32, 64, 100, 100, 3, 1);
bench_case(1, 32, 64, 80, 80, 3, 1);
bench_case(1, 32, 128, 200, 200, 3, 1);
bench_case(1, 32, 128, 128, 128, 3, 1);
bench_case(1, 32, 128, 100, 100, 3, 1);
bench_case(1, 32, 128, 80, 80, 3, 1);
bench_case(1, 64, 32, 7, 7, 3, 1);
bench_case(1, 64, 64, 7, 7, 3, 1);
bench_case(1, 64, 128, 7, 7, 3, 1);
bench_case(1, 64, 256, 7, 7, 3, 1);
bench_case(1, 64, 512, 7, 7, 3, 1);
bench_case(1, 64, 1024, 7, 7, 3, 1);
bench_case(1, 64, 32, 14, 14, 3, 1);
bench_case(1, 64, 64, 14, 14, 3, 1);
bench_case(1, 64, 128, 14, 14, 3, 1);
bench_case(1, 64, 256, 14, 14, 3, 1);
bench_case(1, 64, 512, 14, 14, 3, 1);
bench_case(1, 64, 1024, 14, 14, 3, 1);
bench_case(1, 128, 128, 14, 14, 3, 1);
bench_case(1, 128, 256, 14, 14, 3, 1);
bench_case(1, 512, 512, 14, 14, 3, 1);
bench_case(1, 256, 512, 14, 14, 3, 1);
bench_case(1, 512, 1024, 14, 14, 3, 1);
bench_case(1, 1024, 1024, 14, 14, 3, 1);
std::string algo_name_nchw88 = "IM2COLMATMUL:AARCH64_F16_MK8_16X12X1:96";
std::string algo_name_nchw44 = "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1:96";
benchmark_with_contrast(
args_nchw88, algo_name_nchw88, data_type_fp16, args_nchw44,
algo_name_nchw44, data_type_fp32, RUNS, {1, {4}});
}
TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
BENCHMARK_CHANNEL_WISE_INT8_INT8_INT8_STRIDE1) {
constexpr size_t RUNS = 50;
......
......@@ -362,6 +362,30 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP16) {
#endif
#undef cb
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_MK8_FP16) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args = get_nchw88_conv_bias_args(
{2, 3, 4, 5, 6, 7}, QUAN_NLMODE, BR_AND_NO_BIASMODE, 1);
auto args1 = get_nchw88_conv_bias_args(
{2, 3, 4, 5, 6, 7}, QUAN_NLMODE, BR_AND_BIAS_BIASMODE, 2, 3);
args.insert(args.begin(), args1.begin(), args1.begin());
args1 = get_nchw88_conv_bias_args(
{2, 3, 4, 5, 6, 7, 9}, QUAN_NLMODE, BR_AND_BIAS_BIASMODE, 3, 4);
args.insert(args.begin(), args1.begin(), args1.begin());
NormalRNG rng(1);
#define cb(name) \
checker_conv_bias_common( \
args, handle(), &rng, 0.03, dtype::Float16{}, dtype::Float16{}, \
dtype::Float16{}, dtype::Float16{}, name);
#if MEGDNN_AARCH64
cb("IM2COLMATMUL:AARCH64_F16_MK8_16X12X1");
#endif
#undef cb
}
#endif
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
......
......@@ -161,6 +161,24 @@ TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_INT8_INT16_INT32) {
run();
}
TEST_F(ARM_COMMON, ELEMWISE_SIGMOID) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle());
checker.set_epsilon(1e-3);
checker.set_dtype(0, dtype::Float16());
checker.set_param(Mode::SIGMOID);
for (size_t n : {1, 2, 3}) {
for (size_t ic : {8, 16, 24, 32}) {
for (size_t ih : {5, 10, 15, 20, 21, 37}) {
for (size_t iw : {7, 9, 11, 13, 14, 20, 35}) {
checker.exec({{n, ic, ih, iw}, {}});
}
}
}
}
}
TEST_F(ARM_COMMON, ELEMWISE_FORWARD_NCHW44_FP32) {
using Mode = ElemwiseForward::Param::Mode;
Checker<ElemwiseForward> checker(handle());
......
......@@ -98,6 +98,9 @@ TEST_F(ARM_COMMON, BENCHMARK_ELEMWISE_UNARY) {
BENCHMARK_CASES_INT(shape, dtype::Int16());
BENCHMARK_CASES_INT(shape, dtype::Int8());
BENCHMARK_CASES_FLOAT(shape, dtype::Float32());
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
BENCHMARK_CASES_FLOAT(shape, dtype::Float16());
#endif
#undef BENCHMARK_CASES_INT
#undef BENCHMARK_CASES_FLOAT
#undef RUN
......
......@@ -1580,17 +1580,19 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
std::vector<conv_bias::TestArg> get_nchw88_conv_bias_args(
std::vector<size_t> kernel_vec,
std::vector<param::ConvBias::NonlineMode> nlmode_vec,
std::vector<megdnn::BiasMode> biasmode_vec, size_t stride) {
std::vector<megdnn::BiasMode> biasmode_vec, size_t stride, int pad) {
using namespace conv_bias;
using NLMode = param::ConvBias::NonlineMode;
std::vector<TestArg> args;
auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, size_t kernel,
size_t stride, size_t group, NLMode nlmode,
size_t stride, int pad, size_t group, NLMode nlmode,
megdnn::BiasMode bias_mode) {
constexpr int pack_c = 8;
const size_t pad = kernel / 2;
if (pad == -1) {
pad = kernel / 2;
}
auto oc_per_group = oc / group;
auto ic_per_group = ic / group;
......@@ -1651,8 +1653,8 @@ std::vector<conv_bias::TestArg> get_nchw88_conv_bias_args(
if (kernel < h || kernel < w) {
continue;
}
pack(n, oc, ic, h, w, kernel, stride, group,
nlmode, bias);
pack(n, oc, ic, h, w, kernel, stride, pad,
group, nlmode, bias);
}
}
return args;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册