Commit 270b7488 authored by Megvii Engine Team

feat(dnn/fallback): support mk4 fp32 im2col

GitOrigin-RevId: 178d7231726c18bbc2586797f2b14e4ef3fdb969
Parent 45e2beea
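
For context on the feature: NCHW44 (the MK4 layout targeted by the fp32 im2col kernels below) stores a float32 tensor as (N, C/4, H, W, 4), i.e. channels are grouped into blocks of four that stay contiguous in memory, which is what lets the MK4 matmul kernels load a whole channel block per vector instruction. The repack routine below is an illustrative sketch only, not part of this commit; it assumes a dense NCHW source with C divisible by 4, and the function name is made up for the example.

#include <cstddef>
#include <vector>

// Illustrative only: repack a dense NCHW float tensor into an NCHW44-style
// layout (N, C/4, H, W, 4). C must be a multiple of 4 in this sketch.
static std::vector<float> nchw_to_nchw44(const std::vector<float>& src,
                                         size_t N, size_t C, size_t H,
                                         size_t W) {
    std::vector<float> dst(src.size());
    const size_t CB = C / 4;  // number of 4-channel blocks
    for (size_t n = 0; n < N; ++n)
        for (size_t cb = 0; cb < CB; ++cb)
            for (size_t h = 0; h < H; ++h)
                for (size_t w = 0; w < W; ++w)
                    for (size_t c4 = 0; c4 < 4; ++c4) {
                        const size_t c = cb * 4 + c4;
                        const size_t src_idx = ((n * C + c) * H + h) * W + w;
                        const size_t dst_idx =
                                (((n * CB + cb) * H + h) * W + w) * 4 + c4;
                        dst[dst_idx] = src[src_idx];
                    }
    return dst;
}

int main() {
    const size_t N = 1, C = 8, H = 2, W = 2;
    std::vector<float> nchw(N * C * H * W);
    for (size_t i = 0; i < nchw.size(); ++i)
        nchw[i] = float(i);
    std::vector<float> nchw44 = nchw_to_nchw44(nchw, N, C, H, W);
    // channel 5 at (h=1, w=0) must land in block 1, slot 1 of that position
    const size_t dst_idx = (((0 * 2 + 1) * H + 1) * W + 0) * 4 + 1;
    const size_t src_idx = (5 * H + 1) * W + 0;
    return nchw44[dst_idx] == nchw[src_idx] ? 0 : 1;
}
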
@@ -749,7 +749,7 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr,
         const int8_t* inptr1 = inptr0 + ldin;
         const int8_t* inptr2 = inptr1 + ldin;
         const int8_t* inptr3 = inptr2 + ldin;
-        int8_t* output = outptr + start_y * out_offset;
+        int8_t* output = outptr + (y - y0) / 4 * out_offset;
         prefetch_2x(inptr0);
         prefetch_2x(inptr1);
         prefetch_2x(inptr2);
@@ -776,7 +776,7 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr,
     }
     for (; y + 3 < ymax; y += 4, start_y++) {
         const int8_t* inptr0 = inptr + start_y * ldin + k0 * 4;
-        int8_t* output = outptr + start_y * out_offset;
+        int8_t* output = outptr + (y - y0) / 4 * out_offset;
         prefetch_2x(inptr0);
         int K = kmax - k0;
         for (; K > 15; K -= 16) {
...
@@ -227,7 +227,7 @@ static void gemm_mk4_s8_4x2_pack_A(dt_int8* outptr, const dt_int8* inptr,
         const int8_t* inptr1 = inptr0 + ldin;
         const int8_t* inptr2 = inptr1 + ldin;
         const int8_t* inptr3 = inptr2 + ldin;
-        int8_t* output = outptr + start_y * out_offset;
+        int8_t* output = outptr + (y - y0) / 4 * out_offset;
         prefetch_2x(inptr0);
         prefetch_2x(inptr1);
         prefetch_2x(inptr2);
@@ -254,7 +254,7 @@ static void gemm_mk4_s8_4x2_pack_A(dt_int8* outptr, const dt_int8* inptr,
     }
     for (; y + 3 < ymax; y += 4, start_y++) {
         const int8_t* inptr0 = inptr + start_y * ldin + k0 * 4;
-        int8_t* output = outptr + start_y * out_offset;
+        int8_t* output = outptr + (y - y0) / 4 * out_offset;
         prefetch_2x(inptr0);
         int K = kmax - k0;
         for (; K > 15; K -= 16) {
...
@@ -22,20 +22,6 @@ namespace conv1x1 {
 namespace {
-size_t get_format_pack_size(param::ConvBias::Format format) {
-    switch(format){
-        case param::ConvBias::Format::NCHW44:
-        case param::ConvBias::Format::NCHW4:
-            return 4_z;
-        case param::ConvBias::Format::NCHW88:
-            return 8_z;
-        case param::ConvBias::Format::NCHW:
-            return 1_z;
-        default:
-            megdnn_throw("unknow pack size of the format");
-    }
-}
 struct StrategyHashParam {
     ConvBiasImpl::NCBKernSizeParam param;
     param::ConvBias::Format format;
...
@@ -125,13 +125,10 @@ public:
                                 size_t oc_tile_size) {
         size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0],
                FW = param.filter_meta.spatial[1];
-        size_t pack_oc_size = 1;
+        size_t pack_oc_size = get_format_pack_size(param.filter_meta.format);
         size_t im2col = 0, packb = 0, bias_temp = 0;
         bool default_pack = matmul_algo->packmode() == Pack_Mode::DEFAULT;
         megdnn_assert(default_pack, "only support default packa");
-        if (param.filter_meta.format == param::ConvBias::Format::NCHW44) {
-            pack_oc_size = 4;
-        }
         size_t im2col_dst_size =
                 IC * FH * FW * ohw_tile_size * sizeof(param.src_type);
         size_t matmul_dst_size = pack_oc_size * oc_tile_size * ohw_tile_size *
@@ -321,14 +318,17 @@ fallback::MatrixMulImpl::KernSizeParam
 ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param,
                                                  size_t ohw_tile_size,
                                                  size_t oc_tile_size) const {
-    bool is_nchw44 =
-            param.filter_meta.format == param::ConvBias::Format::NCHW44;
+    auto format = param::MatrixMul::Format::DEFAULT;
+    size_t pack_oc_size = get_format_pack_size(param.filter_meta.format);
+    if (param.filter_meta.format == param::ConvBias::Format::NCHW44) {
+        format = param::MatrixMul::Format::MK4;
+    }
     size_t M = oc_tile_size;
     size_t N = ohw_tile_size;
     size_t K = param.filter_meta.icpg * param.filter_meta.spatial[0] *
                param.filter_meta.spatial[1];
-    size_t pack_oc_size = is_nchw44 ? 4 : 1;
-    size_t LDA = pack_oc_size * K, LDB = pack_oc_size * N, LDC = N;
+    size_t LDA = pack_oc_size * K, LDB = pack_oc_size * N,
+           LDC = N * pack_oc_size;
     bool is_dst_8bit = (param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
                         param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
                        (param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
@@ -345,8 +345,7 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param,
             false,
             false,
             param::MatrixMul::ComputeMode::DEFAULT,
-            is_nchw44 ? param::MatrixMul::Format::MK4
-                      : param::MatrixMul::Format::DEFAULT};
+            format};
 }
 void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block(
@@ -356,11 +355,7 @@ void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block(
     size_t nr_threads = param.nr_threads;
     size_t OC = param.filter_meta.ocpg;
     size_t ohw = param.osz[0] * param.osz[1];
-    //! pay attention please, should not change the 2 line code,
-    //! the opr use the same im2col algo, via choice_ohw_oc_block may change the
-    //! m_ohw_tile_size and m_oc_tile_size, if the two value changed, the
-    //! workspace size may change, will ocur workspace not match problem, so
-    //! should use the original data init them to avoid the problem
     oc_tile_size = DEFAULT_OC_TILE_SIZE;
     ohw_tile_size = m_ohw_tile_size;
@@ -505,14 +500,13 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
     size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size);
     size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
     size_t packa_parallel_times = 0;
-    size_t pack_oc_size =
-            (param.filter_meta.format == param::ConvBias::Format::NCHW ? 1
-                                                                       : 4);
+    size_t pack_oc_size = get_format_pack_size(param.filter_meta.format);
     if (only_packA) {
         packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
     } else if (default_pack) {
         packa_parallel_times = div_ceil<size_t>(
-                OC, m_matmul_algo->get_inner_block_size().m * pack_oc_size);
+                OC, m_matmul_algo->get_inner_block_size().m);
     }
     auto matmul_param = get_matmul_kern_param(
@@ -659,12 +653,16 @@ bool ConvBiasImpl::AlgoIm2col::usable(
         param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
         return false;
     }
-    //! current now im2col only support int8 quantized s8 nchw44
-    if (opr->param().format == param::ConvBias::Format::NCHW44 &&
-        (param.src_type.enumv() == param.filter_type.enumv() &&
-         (param.src_type.enumv() != DTypeEnum::Int8) &&
-         (param.src_type.enumv() != DTypeEnum::QuantizedS8))) {
-        return false;
+    if (opr->param().format == param::ConvBias::Format::NCHW44) {
+        //! current NCHW44 im2col only support DEFAULT mode matmul
+        if(m_matmul_algo->packmode() != Pack_Mode::DEFAULT) {
+            return false;
+            //! nchw44 hybird mode and channel wise is not support
+        } else if (param.filter_meta.icpg < 4_z ||
+                   param.filter_meta.icpg == 1 ||
+                   param.filter_meta.ocpg == 1) {
+            return false;
+        }
     }
     size_t oc_tile_size = 0, ohw_tile_size = 0;
...
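
To summarize the hunks above: when the filter format is NCHW44, the matmul kernel-size param is switched to MK4 and every leading dimension is scaled by the pack size, so each matrix row stride covers a whole 4-channel block. The sketch below mirrors that arithmetic with simplified names (MatmulShape and make_im2col_matmul_shape are stand-ins for this example, not the real KernSizeParam setup):

#include <cstddef>
#include <cstdio>

// Simplified mirror of get_matmul_kern_param's shape/stride setup.
// pack_oc_size is 4 for NCHW44 (MK4 matmul) and 1 for plain NCHW.
struct MatmulShape {
    size_t M, N, K, LDA, LDB, LDC;
    bool mk4;
};

static MatmulShape make_im2col_matmul_shape(size_t oc_tile, size_t ohw_tile,
                                            size_t IC, size_t FH, size_t FW,
                                            size_t pack_oc_size) {
    MatmulShape s;
    s.M = oc_tile;               // output channels handled per tile
    s.N = ohw_tile;              // output spatial positions per tile
    s.K = IC * FH * FW;          // im2col reduction length
    s.LDA = pack_oc_size * s.K;  // packed filter row stride
    s.LDB = pack_oc_size * s.N;  // im2col matrix row stride
    s.LDC = s.N * pack_oc_size;  // destination row stride
    s.mk4 = (pack_oc_size == 4);
    return s;
}

int main() {
    // Example: 3x3 conv, 16 input channels, NCHW44 -> pack_oc_size = 4.
    MatmulShape s = make_im2col_matmul_shape(24, 192, 16, 3, 3, 4);
    std::printf("M=%zu N=%zu K=%zu LDA=%zu LDB=%zu LDC=%zu mk4=%d\n",
                s.M, s.N, s.K, s.LDA, s.LDB, s.LDC, (int)s.mk4);
    return 0;
}
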
@@ -221,8 +221,17 @@ public:
         param::ConvBias::Format format = param.filter_meta.format;
     switch (strategytype) {
         case StrategyType::FLOAT:
-            cb1(NCHW, DEFAULT, dt_float32, dt_float32,
-                PostprocessMode::FLOAT, "DefaultStrategyType::FLOAT"_hash);
+            if (format == param::ConvBias::Format::NCHW) {
+                cb1(NCHW, DEFAULT, dt_float32, dt_float32,
+                    PostprocessMode::FLOAT,
+                    "DefaultStrategyType::FLOAT"_hash);
+            } else if (format == param::ConvBias::Format::NCHW44) {
+                cb1(NCHW44, DEFAULT, dt_float32, dt_float32,
+                    PostprocessMode::FLOAT,
+                    "DefaultStrategyTypeNCHW44::FLOAT"_hash);
+            } else {
+                megdnn_throw("not support format except nchw44 and nchw\n");
+            }
             break;
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         case StrategyType::FLOAT_FP16:
...
@@ -75,15 +75,14 @@ public:
 template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode, PackMode packmode,
-          FormatMode format>
+          FormatMode format = FormatMode::NCHW>
 class Strategy;
 template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-               postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>
-        : public StrategyBase {
+               postprocess_mode, PackMode::DEFAULT> : public StrategyBase {
 public:
     constexpr static size_t BUNDLE_PADDING_INDEX = 0;
     constexpr static size_t BUNDLE_PACKA_INDEX = 1;
@@ -142,8 +141,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
 class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
                postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW44>
         : public Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-                          postprocess_mode, PackMode::DEFAULT,
-                          FormatMode::NCHW> {
+                          postprocess_mode, PackMode::DEFAULT> {
 public:
     const size_t BUNDLE_PADDING_INDEX = 0;
     const size_t BUNDLE_PACKA_INDEX = 1;
@@ -164,8 +162,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-               postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>
-        : public StrategyBase {
+               postprocess_mode, PackMode::NO_PACK> : public StrategyBase {
 public:
     constexpr static size_t BUNDLE_PADDING_INDEX = 0;
     constexpr static size_t BUNDLE_PACKA_INDEX = 1;
@@ -231,8 +228,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-               postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>
-        : public StrategyBase {
+               postprocess_mode, PackMode::ONLY_PACKA> : public StrategyBase {
 public:
     constexpr static size_t BUNDLE_PADDING_INDEX = 0;
     constexpr static size_t BUNDLE_PACKA_INDEX = 1;
...
@@ -26,7 +26,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-              postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
+              postprocess_mode, PackMode::DEFAULT>::
         copy_padding_kern(WorkspaceBundle bundle,
                           const fallback::ConvBiasImpl::NCBKernParam& param,
                           const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
@@ -93,13 +93,13 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-              postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
+              postprocess_mode, PackMode::DEFAULT>::
         packA_kern(WorkspaceBundle bundle,
                    const fallback::ConvBiasImpl::NCBKernParam& param,
                    fallback::MatrixMulImpl::KernSizeParam matmulparam,
                    fallback::MatrixMulImpl::AlgoBase* matmul_algo,
                    const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
-                   size_t pack_oc_size) {
+                   size_t) {
     bundle.set(param.workspace_ptr);
     fallback::MatrixMulImpl::KernParam matmul_param;
     size_t group_id = ncb_index.ndrange_id[0];
@@ -112,19 +112,18 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
                         matmul_algo->get_packA_type_size();
     size_t a_panel_offset = ncb_index.ndrange_id[1] * packed_per_oc_block_size;
     int8_t* a_panel = static_cast<int8_t*>(bundle.get(BUNDLE_PACKA_INDEX)) +
-                      group_id * packA_group_size +
-                      (pack_oc_size == 4 ? 0 : a_panel_offset);
+                      group_id * packA_group_size + a_panel_offset;
     matmul_param.A_ptr =
             const_cast<src_ctype*>(param.filter<src_ctype>(group_id));
     matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[1],
-                        matmul_algo->get_inner_block_size().m * pack_oc_size);
+                        matmul_algo->get_inner_block_size().m);
 }
 template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-              postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
+              postprocess_mode, PackMode::DEFAULT>::
         exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
                     const StrategyParam& sparam,
                     const fallback::ConvBiasImpl::NCBKernParam& param,
@@ -193,7 +192,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-               postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
+               postprocess_mode, PackMode::DEFAULT>::
         get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
                            const WorkspaceBundle& bundle_thread,
                            const StrategyParam& sparam) {
@@ -212,7 +211,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-              postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
+              postprocess_mode, PackMode::DEFAULT>::
         exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param,
                     const StrategyParam& sparam, WorkspaceBundle bundle,
                     WorkspaceBundle bundle_thread,
@@ -249,7 +248,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-              postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
+              postprocess_mode, PackMode::DEFAULT>::
         exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
                          const StrategyParam& sparam,
                          WorkspaceBundle bundle_thread) {
@@ -264,12 +263,12 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
                     ? bias_temp_ptr
                     : static_cast<void*>(const_cast<bias_ctype*>(
                               bias_ptr + sparam.oc_cur_index)));
+    size_t pack_oc_size = sparam.pack_oc_size;
     PostProcess<op_ctype, op_dtype, postprocess_mode>::run(
             matmul_dst, bias_preprocess_ptr, matmul_dst, param.bias_mode,
             param.nonlineMode, param.bias_type, param.dst_type, 1_z,
-            sparam.output_block_oc_size, 1_z, sparam.output_block_size,
-            sparam.pack_oc_size);
+            sparam.output_block_oc_size / pack_oc_size, 1_z,
+            sparam.output_block_size, pack_oc_size);
     copy_dst(param, matmul_dst, sparam);
 }
@@ -277,7 +276,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-              postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
+              postprocess_mode, PackMode::DEFAULT>::
         copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
                  const void* matmul_dst, const StrategyParam& sparam) {
     if (!sparam.skip_copy_dst) {
@@ -303,7 +302,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-               postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
+               postprocess_mode, PackMode::DEFAULT>::
         get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
                           const WorkspaceBundle& bundle_thread) {
     bias_ctype* bias_tmp_ptr =
@@ -318,7 +317,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-              postprocess_mode, PackMode::DEFAULT, FormatMode::NCHW>::
+              postprocess_mode, PackMode::DEFAULT>::
         copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param,
                   WorkspaceBundle bundle_thread, const StrategyParam& sparam) {
     const bias_ctype* bias_ptr = static_cast<const bias_ctype*>(
@@ -342,8 +341,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
 #define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
                          _op_dtype, _postprocess_mode) \
     template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
-                            _op_dtype, _postprocess_mode, PackMode::DEFAULT, \
-                            FormatMode::NCHW>;
+                            _op_dtype, _postprocess_mode, PackMode::DEFAULT>;
 INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32,
                  megdnn::PostprocessMode::FLOAT)
...
@@ -12,10 +12,9 @@
 #include "src/fallback/convolution/img2col_helper.h"
 #if MEGDNN_X86
 #include "src/x86/conv_bias/postprocess_helper.h"
-#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
-#include "src/arm_common/conv_bias/postprocess_helper.h"
 #endif
 using namespace megdnn;
 #if MEGDNN_X86
 using namespace x86;
@@ -101,23 +100,12 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
 INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32,
                  megdnn::PostprocessMode::FLOAT)
-#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16,
-                 megdnn::PostprocessMode::FLOAT)
-#else
 #if !MEGDNN_DISABLE_FLOAT16
 INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
                  megdnn::PostprocessMode::NO_PROCESS)
 #endif
-#endif
-#if MEGDNN_AARCH64 || MEGDNN_ARMV7
-//! x86 do not have uint8 matmul so only armv7 armv8 support uint8
-INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8,
-                 megdnn::PostprocessMode::QUANTIZED)
-INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32,
-                 megdnn::PostprocessMode::NO_PROCESS)
-#endif
 INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8,
                  megdnn::PostprocessMode::QUANTIZED)
...
@@ -27,7 +27,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-              postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
+              postprocess_mode, PackMode::NO_PACK>::
         copy_padding_kern(WorkspaceBundle bundle,
                           const fallback::ConvBiasImpl::NCBKernParam& param,
                           const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
@@ -90,7 +90,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-              postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
+              postprocess_mode, PackMode::NO_PACK>::
         packA_kern(WorkspaceBundle bundle,
                    const fallback::ConvBiasImpl::NCBKernParam& param,
                    fallback::MatrixMulImpl::KernSizeParam matmulparam,
@@ -110,7 +110,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-               postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
+               postprocess_mode, PackMode::NO_PACK>::
         get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
                            const WorkspaceBundle& bundle_thread,
                            const StrategyParam& sparam) {
@@ -129,7 +129,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-              postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
+              postprocess_mode, PackMode::NO_PACK>::
         exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param,
                     const StrategyParam& sparam, WorkspaceBundle bundle,
                     WorkspaceBundle bundle_thread,
@@ -162,7 +162,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-              postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
+              postprocess_mode, PackMode::NO_PACK>::
         exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
                     const StrategyParam& sparam,
                     const fallback::ConvBiasImpl::NCBKernParam& param,
@@ -224,7 +224,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-              postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
+              postprocess_mode, PackMode::NO_PACK>::
         exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
                          const StrategyParam& sparam,
                          WorkspaceBundle bundle_thread) {
@@ -252,7 +252,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-              postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
+              postprocess_mode, PackMode::NO_PACK>::
         copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
                  const void* matmul_dst, const StrategyParam& sparam) {
     if (!sparam.skip_copy_dst) {
@@ -274,7 +274,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-              postprocess_mode, PackMode::NO_PACK, FormatMode::NCHW>::
+              postprocess_mode, PackMode::NO_PACK>::
         copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param,
                   WorkspaceBundle bundle_thread, const StrategyParam& sparam) {
     const bias_ctype* bias_ptr = static_cast<const bias_ctype*>(
@@ -298,8 +298,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
 #define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
                          _op_dtype, _postprocess_mode) \
     template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
-                            _op_dtype, _postprocess_mode, PackMode::NO_PACK, \
-                            FormatMode::NCHW>;
+                            _op_dtype, _postprocess_mode, PackMode::NO_PACK>;
 INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32,
                  megdnn::PostprocessMode::FLOAT)
...
@@ -27,7 +27,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-              postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
+              postprocess_mode, PackMode::ONLY_PACKA>::
         copy_padding_kern(WorkspaceBundle bundle,
                           const fallback::ConvBiasImpl::NCBKernParam& param,
                           const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
@@ -90,7 +90,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-              postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
+              postprocess_mode, PackMode::ONLY_PACKA>::
         packA_kern(WorkspaceBundle bundle,
                    const fallback::ConvBiasImpl::NCBKernParam& param,
                    fallback::MatrixMulImpl::KernSizeParam matmulparam,
@@ -124,7 +124,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-               postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
+               postprocess_mode, PackMode::ONLY_PACKA>::
         get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
                            const WorkspaceBundle& bundle_thread,
                            const StrategyParam& sparam) {
@@ -143,7 +143,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-              postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
+              postprocess_mode, PackMode::ONLY_PACKA>::
         exec_matmul(const fallback::ConvBiasImpl::NCBKernParam& param,
                     const StrategyParam& sparam, WorkspaceBundle bundle,
                     WorkspaceBundle bundle_thread,
@@ -181,7 +181,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-              postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
+              postprocess_mode, PackMode::ONLY_PACKA>::
         exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
                     const StrategyParam& sparam,
                     const fallback::ConvBiasImpl::NCBKernParam& param,
@@ -242,7 +242,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-              postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
+              postprocess_mode, PackMode::ONLY_PACKA>::
         exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
                          const StrategyParam& sparam,
                          WorkspaceBundle bundle_thread) {
@@ -283,7 +283,7 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
           typename op_ctype, typename op_dtype,
           megdnn::PostprocessMode postprocess_mode>
 void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
-              postprocess_mode, PackMode::ONLY_PACKA, FormatMode::NCHW>::
+              postprocess_mode, PackMode::ONLY_PACKA>::
         copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
                  const void* matmul_dst, const StrategyParam& sparam) {
     if (!sparam.skip_copy_dst) {
@@ -305,7 +305,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
                          _op_dtype, _postprocess_mode) \
     template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
                             _op_dtype, _postprocess_mode, \
-                            PackMode::ONLY_PACKA, FormatMode::NCHW>;
+                            PackMode::ONLY_PACKA>;
 INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32,
                  megdnn::PostprocessMode::FLOAT)
...
@@ -26,6 +26,18 @@
 using namespace megdnn;
 using namespace fallback;
+size_t megdnn::fallback::get_format_pack_size(param::ConvBias::Format format) {
+    switch(format){
+        case param::ConvBias::Format::NCHW44:
+        case param::ConvBias::Format::NCHW4:
+            return 4_z;
+        case param::ConvBias::Format::NCHW88:
+            return 8_z;
+        default:
+            return 1_z;
+    }
+}
 namespace {
 template <typename T>
 void incr_ptr(T*& dst, ptrdiff_t delta) {
...
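
The helper added above replaces the scattered NCHW44 checks in the im2col algorithm; the returned pack size then scales buffers such as the per-tile matmul destination (see the `matmul_dst_size = pack_oc_size * oc_tile_size * ohw_tile_size * ...` line in the earlier workspace hunk). Below is a rough, self-contained illustration of that scaling only; the names are hypothetical and it is not the real WorkspaceBundle code.

#include <cstddef>

// Illustrative only: how a format-dependent pack size scales the per-tile
// matmul destination buffer, mirroring
//   matmul_dst_size = pack_oc_size * oc_tile_size * ohw_tile_size * element size
// from the im2col workspace computation.
static size_t matmul_dst_bytes(size_t pack_oc_size, size_t oc_tile_size,
                               size_t ohw_tile_size, size_t elem_size) {
    return pack_oc_size * oc_tile_size * ohw_tile_size * elem_size;
}

int main() {
    const size_t oc_tile = 24, ohw_tile = 192;
    // pack size is 1 for NCHW, 4 for NCHW44/NCHW4, 8 for NCHW88
    size_t nchw = matmul_dst_bytes(1, oc_tile, ohw_tile, sizeof(float));
    size_t nchw44 = matmul_dst_bytes(4, oc_tile, ohw_tile, sizeof(float));
    // with identical tile sizes, the NCHW44 tile reserves pack_oc_size times
    // the NCHW buffer
    return (nchw44 == 4 * nchw) ? 0 : 1;
}
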
@@ -21,6 +21,11 @@
 namespace megdnn {
 namespace fallback {
+/*!
+ * \brief get the pack_size according to the format
+ * */
+size_t get_format_pack_size(param::ConvBias::Format format);
 /*!
  * \brief fallback conv bias forward impl
  *
...
@@ -9,9 +9,8 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
 #include "src/common/utils.h"
-#if MEGDNN_ARMV7 || MEGDNN_AARCH64
-#include "src/arm_common/simd_macro/marm_neon.h"
-#endif
 namespace {
 template <bool is_xcorr, typename dtype>
@@ -268,12 +267,13 @@ void img2col_nchw4(const dtype* __restrict src, dtype* __restrict dst,
                     }
                     for (int w = cur_remain_w; w < OW; w++) {
-                        size_t index = ic * IH * IW + (start_h + fh2) * IW +
-                                       (w + fw2);
-                        dst[i++] = src[4 * index];
-                        dst[i++] = src[4 * index + 1];
-                        dst[i++] = src[4 * index + 2];
-                        dst[i++] = src[4 * index + 3];
+                        size_t index =
+                                4 * (ic * IH * IW + (start_h + fh2) * IW +
+                                     (w + fw2));
+                        dst[i++] = src[index];
+                        dst[i++] = src[index + 1];
+                        dst[i++] = src[index + 2];
+                        dst[i++] = src[index + 3];
                     }
                     for (int h = start_h + 1; h < end_h; h++) {
@@ -317,26 +317,11 @@ void img2col_nchw4(const dtype* __restrict src, dtype* __restrict dst,
                         fh2 = FH - fh - 1;
                         fw2 = FW - fw - 1;
                     }
-#if MEGDNN_ARMV7 || MEGDNN_AARCH64
-                    int w = cur_remain_w;
-                    size_t index = (ic * IH * IW + (start_h + fh2) * IW +
-                                    (w + fw2));
-                    for (; w + 3 < end_remain_w; w += 4) {
-                        vst1q_u32(&output[i],
-                                  vld1q_u32(&uint32_src[index]));
-                        i += 4;
-                        index += 4;
-                    }
-                    for (; w < end_remain_w; w++) {
-                        output[i++] = uint32_src[index];
-                    }
-#else
                     for (int w = cur_remain_w; w < end_remain_w; w++) {
                         size_t index = (ic * IH * IW +
                                         (start_h + fh2) * IW + (w + fw2));
                         output[i++] = uint32_src[index];
                     }
-#endif
                 }
             }
         }
@@ -360,27 +345,11 @@ void img2col_nchw4(const dtype* __restrict src, dtype* __restrict dst,
                     }
                     for (int h = start_h + 1; h < end_h; h++) {
-#if MEGDNN_ARMV7 || MEGDNN_AARCH64
-                        int ow = 0;
-                        size_t index = (ic * IH * IW + (h + fh2) * IW +
-                                        (ow + fw2));
-                        for (; ow + 3 < OW; ow += 4) {
-                            vst1q_u32(&output[i],
-                                      vld1q_u32(&uint32_src[index]));
-                            i += 4;
-                            index += 4;
-                        }
-                        for (; ow < OW; ow++) {
-                            output[i++] = uint32_src[index++];
-                        }
-#else
                         rep(ow, OW) {
                             size_t index = (ic * IH * IW + (h + fh2) * IW +
                                             (ow + fw2));
                             output[i++] = uint32_src[index];
                         }
-#endif
                     }
                     for (int w = 0; w < end_remain_w; w++) {
...
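
In the img2col_nchw4 hunks above, the x4 channel-block stride is now folded into the index itself, and the NEON-specialized copy loops were dropped in favour of the portable scalar loop. The standalone sketch below illustrates the same indexing for one output row; the names and the simplified signature are hypothetical (it is not the real img2col_nchw4): each output position copies one contiguous 4-element channel block from a (IC/4, IH, IW, 4) source into the im2col matrix.

#include <cstddef>
#include <cstring>

// Illustrative inner step of an NCHW4-style im2col: src is laid out as
// (IC/4, IH, IW, 4); for channel block ic, kernel offset (fh, fw) and output
// row h, copy the 4-wide blocks of one output row into dst, advancing i.
static void copy_one_row_nchw4(const float* src, float* dst, size_t& i,
                               size_t ic, size_t IH, size_t IW, size_t h,
                               size_t fh, size_t fw, size_t OW) {
    for (size_t ow = 0; ow < OW; ++ow) {
        // same shape as the patched code: the index is pre-multiplied by 4
        size_t index = 4 * (ic * IH * IW + (h + fh) * IW + (ow + fw));
        std::memcpy(dst + i, src + index, 4 * sizeof(float));
        i += 4;
    }
}

int main() {
    // toy 2-block (8-channel) 4x4 input, kernel offset (0, 0), one output row
    const size_t IH = 4, IW = 4, OW = 2;
    float src[2 * IH * IW * 4];
    for (size_t k = 0; k < sizeof(src) / sizeof(src[0]); ++k)
        src[k] = float(k);
    float dst[OW * 4];
    size_t i = 0;
    copy_one_row_nchw4(src, dst, i, /*ic=*/1, IH, IW, /*h=*/0, /*fh=*/0,
                       /*fw=*/0, OW);
    return (i == OW * 4 && dst[0] == src[4 * (1 * IH * IW)]) ? 0 : 1;
}
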
@@ -1173,10 +1173,10 @@ void checker_conv_bias_mul_int8x8x32(std::vector<conv_bias::TestArg> args,
 #if MEGDNN_AARCH64 || MEGDNN_ARMV7
 #if !__ARM_FEATURE_DOTPROD
-TEST_F(ARM_COMMON, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44) {
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) {
     using namespace conv_bias;
     std::vector<conv_bias::TestArg> args =
-            get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true);
+            get_nchw44_conv_bias_args({2, 5, 7}, 2, false, true, true);
 #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
 #if MEGDNN_AARCH64
@@ -1187,10 +1187,10 @@ TEST_F(ARM_COMMON, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44) {
 #undef cb
 }
-TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_MULTI) {
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1) {
     using namespace conv_bias;
     std::vector<conv_bias::TestArg> args =
-            get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true);
+            get_nchw44_conv_bias_args({3, 4, 6}, 1, false, true, true);
 #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
 #if MEGDNN_AARCH64
@@ -1202,12 +1202,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_MULTI) {
 #undef cb
 }
-TEST_F(ARM_COMMON, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44) {
+TEST_F(ARM_COMMON_MULTI_THREADS,
+       CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_S2) {
     UniformIntRNG rng{-50, 50};
 #define cb(name) \
-    checker_conv_bias(get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1), \
-                      handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
+    checker_conv_bias(get_nchw44_conv_bias_args({3, 4, 6}, 2), handle(), &rng, \
+                      epsilon, dtype::QuantizedS8(2.5f), \
                       dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
                       dtype::QuantizedS8(60.25f), name);
     float epsilon = 0.001;
@@ -1220,12 +1221,12 @@ TEST_F(ARM_COMMON, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44) {
 }
 TEST_F(ARM_COMMON_MULTI_THREADS,
-       CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_MULTI) {
+       CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44_S1) {
     UniformIntRNG rng{-50, 50};
 #define cb(name) \
-    checker_conv_bias(get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1), \
-                      handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
+    checker_conv_bias(get_nchw44_conv_bias_args({2, 5, 7}, 1), handle(), &rng, \
+                      epsilon, dtype::QuantizedS8(2.5f), \
                       dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
                       dtype::QuantizedS8(60.25f), name);
     float epsilon = 0.001;
@@ -1286,6 +1287,24 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
 #undef cb
 }
+#if MEGDNN_AARCH64
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) {
+    using namespace conv_bias;
+    std::vector<conv_bias::TestArg> args =
+            get_nchw44_conv_bias_args({2, 4, 7}, 1);
+    check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
+}
+#endif
+#if MEGDNN_AARCH64
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) {
+    using namespace conv_bias;
+    std::vector<conv_bias::TestArg> args =
+            get_nchw44_conv_bias_args({3, 5, 6}, 2);
+    check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
+}
+#endif
 /***************************** Conv1x1 Algo Test ***********************/
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
     using namespace conv_bias;
...