Commit 48ac1e1a authored by Megvii Engine Team, committed by Xu Xinran

feat(dnn/fallback): delete unneeded dtypes from the nopack/onlypacka strategies, and add

im2col and conv1x1 mk4_dot support

GitOrigin-RevId: 096b16a3abb8fa259db77e792bd286dacf3fd8c3
Parent 3117bfb7
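For orientation before the hunks: the commit threads the NCHW44_DOT layout through the fallback conv1x1 and im2col paths. A minimal sketch of the two decisions that recur below, using names exactly as they appear in the diff (an illustration, not code from the commit):

    // NCHW44 and NCHW44_DOT both pack 4 channels per spatial position,
    // so pack_size(format) yields 4 for them and 1 for plain NCHW.
    size_t pack_c_size = pack_size(param.filter_meta.format);
    // NCHW44 maps to the MK4 matmul packing; NCHW44_DOT maps to its
    // dot-product-instruction variant MK4_DOT.
    auto format = param::MatrixMul::Format::DEFAULT;
    if (param.filter_meta.format == param::ConvBias::Format::NCHW44) {
        format = param::MatrixMul::Format::MK4;
    } else if (param.filter_meta.format ==
               param::ConvBias::Format::NCHW44_DOT) {
        format = param::MatrixMul::Format::MK4_DOT;
    }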
......@@ -913,10 +913,10 @@ static void gemm_mk4_s8_8x12_pack_B(dt_int8* out, const dt_int8* in, int ldin,
*outptr++ = *inptr++;
}
for (; i < 4; i++) {
*outptr++ = *inptr++;
*outptr++ = *inptr++;
*outptr++ = *inptr++;
*outptr++ = *inptr++;
*outptr++ = 0;
*outptr++ = 0;
*outptr++ = 0;
*outptr++ = 0;
}
}
......
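Both pack_B hunks (this one and the gemm_dots8_8x6 one below) fix tail handling the same way: when fewer than four elements remain, the packed panel is zero-filled instead of reading past the end of the input. A standalone illustration of the pattern with a hypothetical helper (not the real kernel):

    #include <cstdint>
    // Pack a tail of `remain` (< 4) valid int8 values into a fixed
    // 4-slot group, zero-filling the missing slots so the matmul
    // kernel never consumes out-of-bounds data.
    inline void pack_tail_of_4(int8_t*& outptr, const int8_t*& inptr,
                               int remain) {
        int i = 0;
        for (; i < remain; ++i)
            *outptr++ = *inptr++;  // copy the valid elements
        for (; i < 4; ++i)
            *outptr++ = 0;         // zero-pad instead of over-reading
    }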
......@@ -39,7 +39,7 @@ namespace {
megdnn::arm_common::OpCallerUnary<_op<ctype>, megdnn::arm_common::VEC>:: \
run(static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, dst_type, \
N* OC* OH* OW);
N* OC* OH* OW* pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
megdnn::arm_common:: \
......@@ -63,7 +63,7 @@ namespace {
static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
dst_type, N* OC* OH* OW);
dst_type, N* OC* OH* OW* pack_oc_size);
#define FOR_BIAS(_mode) \
switch (_mode) { \
......@@ -113,7 +113,6 @@ struct PostProcess {
megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode,
megdnn::DType bias_type, megdnn::DType dst_type, size_t N,
size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) {
MEGDNN_MARK_USED_VAR(pack_oc_size);
FOR_BIAS(bias_mode)
}
};
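The "* pack_oc_size" factor added in these macros accounts for the packed layout: assuming, as in the do_postprocess call later in this diff, that OC here counts packed channel groups, the flat element count handed to the elementwise post-process must grow by pack_oc_size. A worked check with hypothetical sizes:

    // Hypothetical shapes: N = 1, OC = 2 packed groups, OH = OW = 8,
    // pack_oc_size = 4 (NCHW44/NCHW44_DOT).
    // N * OC * OH * OW * pack_oc_size = 1 * 2 * 8 * 8 * 4 = 512 scalars,
    // i.e. 8 real output channels of an 8x8 feature map.
    size_t total = size_t(1) * 2 * 8 * 8 * 4;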
......@@ -155,7 +154,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
_op<opctype, opdtype>, \
megdnn::arm_common::VEC>::run(static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), \
bias_type, dst_type, N* OC* OH* OW);
bias_type, dst_type, \
N* OC* OH* OW* pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, \
......@@ -173,8 +173,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
dst_type, N, OC, OH* OW, pack_oc_size);
#define HANDLE_IDENTITY(_caller, _op) \
case megdnn::NonlineMode::IDENTITY: \
#define HANDLE_IDENTITY(_caller, _op) \
case megdnn::NonlineMode::IDENTITY: \
_caller(_op) break;
#define FOR_NONLINEAR(_caller) \
......
......@@ -729,10 +729,10 @@ static void gemm_dots8_8x6_pack_B(dt_int8* out, const dt_int8* in, int ldin,
*outptr++ = *inptr++;
}
for (; i < 4; i++) {
*outptr++ = *inptr++;
*outptr++ = *inptr++;
*outptr++ = *inptr++;
*outptr++ = *inptr++;
*outptr++ = 0;
*outptr++ = 0;
*outptr++ = 0;
*outptr++ = 0;
}
}
outptr_base += 24;
......
......@@ -187,7 +187,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr,
AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_fallback_conv1x1, 0, 2) {
if (opr->param().format != param::ConvBias::Format::NCHW &&
opr->param().format != param::ConvBias::Format::NCHW44)
opr->param().format != param::ConvBias::Format::NCHW44 &&
opr->param().format != param::ConvBias::Format::NCHW44_DOT)
return false;
size_t FH = param.filter_meta.spatial[0],
......@@ -219,8 +220,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr,
param.nonlineMode != megdnn::NonlineMode::IDENTITY)
return false;
if (opr->param().format == param::ConvBias::Format::NCHW44) {
//! nchw44 hybrid mode and channel wise are not supported
if (opr->param().format == param::ConvBias::Format::NCHW44 ||
opr->param().format == param::ConvBias::Format::NCHW44_DOT) {
if (param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 ||
param.filter_meta.ocpg == 1) {
return false;
......
......@@ -73,32 +73,34 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy(
const ConvBiasImpl::NCBKernSizeParam& param,
MatrixMulImpl::AlgoBase::PackMode pack_mode,
param::ConvBias::Format format) {
size_t pack_size = get_format_pack_size(format);
#define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \
midout_iv(_midout_tag)) { \
if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \
return std::make_unique< \
Conv1x1Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \
_postprocess_mode, _packmode>>(pack_size); \
} \
} \
size_t pack_c_size = pack_size(format);
#define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \
midout_iv(_midout_tag)) { \
if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \
return std::make_unique< \
Conv1x1Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \
_postprocess_mode, _packmode>>( \
pack_c_size); \
} \
} \
MIDOUT_END()
#define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \
midout_iv(_midout_tag)) { \
if (param.filter_type.enumv() == param.src_type.enumv() && \
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \
return std::make_unique< \
Conv1x1Strategy<_src_ctype, _bias_ctype, _dst_ctype, \
DTypeTrait<_i_bias_type>::ctype, \
DTypeTrait<_i_dst_type>::ctype, \
_postprocess_mode, _packmode>>(pack_size); \
} \
} \
#define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \
midout_iv(_midout_tag)) { \
if (param.filter_type.enumv() == param.src_type.enumv() && \
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \
return std::make_unique< \
Conv1x1Strategy<_src_ctype, _bias_ctype, _dst_ctype, \
DTypeTrait<_i_bias_type>::ctype, \
DTypeTrait<_i_dst_type>::ctype, \
_postprocess_mode, _packmode>>( \
pack_c_size); \
} \
} \
MIDOUT_END()
switch (pack_mode) {
......
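For readability, a sketch of what one cb2 instantiation expands to, taking the QuantizedS8 -> QuantizedS8 case with the DEFAULT pack mode as an example (the real expansion is additionally wrapped in the MIDOUT_BEGIN/MIDOUT_END markers shown above):

    if (param.filter_type.enumv() == param.src_type.enumv() &&
        param.src_type.enumv() == DTypeTrait<dtype::QuantizedS8>::enumv &&
        param.dst_type.enumv() == DTypeTrait<dtype::QuantizedS8>::enumv) {
        return std::make_unique<Conv1x1Strategy<
                dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8,
                megdnn::PostprocessMode::QUANTIZED,
                MatrixMulImpl::AlgoBase::PackMode::DEFAULT>>(pack_c_size);
    }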
......@@ -12,7 +12,6 @@
#pragma once
#include "megdnn/opr_param_defs.h"
#include "src/fallback/conv_bias/opr_impl.h"
#if MEGDNN_X86
#include "src/x86/conv_bias/postprocess_helper.h"
......@@ -41,12 +40,15 @@ MatrixMulImpl::KernSizeParam get_matmul_kern_param(
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
size_t pack_c_size = 1_z;
size_t pack_c_size = pack_size(param.filter_meta.format);
auto format = param::MatrixMul::Format::DEFAULT;
if(param.filter_meta.format == param::ConvBias::Format::NCHW44){
pack_c_size = 4_z;
if (param.filter_meta.format == param::ConvBias::Format::NCHW44) {
format = param::MatrixMul::Format::MK4;
} else if (param.filter_meta.format ==
param::ConvBias::Format::NCHW44_DOT) {
format = param::MatrixMul::Format::MK4_DOT;
}
return {param.filter_type,
param.src_type,
is_dst_8bit ? param.bias_type : param.dst_type,
......
......@@ -15,7 +15,6 @@
#include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/winograd/strategy.h"
#include "src/naive/convolution/helper.h"
#include "midout.h"
......@@ -125,7 +124,7 @@ public:
size_t oc_tile_size) {
size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0],
FW = param.filter_meta.spatial[1];
size_t pack_oc_size = get_format_pack_size(param.filter_meta.format);
size_t pack_oc_size = pack_size(param.filter_meta.format);
size_t im2col = 0, packb = 0, bias_temp = 0;
bool default_pack = matmul_algo->packmode() == Pack_Mode::DEFAULT;
megdnn_assert(default_pack, "only support default packa");
......@@ -319,9 +318,11 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param,
size_t ohw_tile_size,
size_t oc_tile_size) const {
auto format = param::MatrixMul::Format::DEFAULT;
size_t pack_oc_size = get_format_pack_size(param.filter_meta.format);
size_t pack_oc_size = pack_size(param.filter_meta.format);
if (param.filter_meta.format == param::ConvBias::Format::NCHW44) {
format = param::MatrixMul::Format::MK4;
} else if (param.filter_meta.format == param::ConvBias::Format::NCHW44_DOT) {
format = param::MatrixMul::Format::MK4_DOT;
}
size_t M = oc_tile_size;
size_t N = ohw_tile_size;
......@@ -351,11 +352,10 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param,
void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block(
const NCBKernSizeParam& param, size_t& oc_tile_size,
size_t& ohw_tile_size, size_t block_m, size_t block_n,
bool need_pack) const {
fallback::MatrixMulImpl::AlgoBase::PackMode pack_mode) const {
size_t nr_threads = param.nr_threads;
size_t OC = param.filter_meta.ocpg;
size_t ohw = param.osz[0] * param.osz[1];
oc_tile_size = DEFAULT_OC_TILE_SIZE;
ohw_tile_size = m_ohw_tile_size;
......@@ -376,7 +376,8 @@ void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block(
}
}
} else {
if (!need_pack) { //! no pack ,usually in x86 save memroy
//! in no_pack mode, don't do the blocking operation when using a single thread
if (pack_mode == fallback::MatrixMulImpl::AlgoBase::PackMode::NO_PACK) {
ohw_tile_size = ohw;
oc_tile_size = OC;
}
......@@ -406,7 +407,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
if (need_pack || only_packA) {
auto inner_block = m_matmul_algo->get_inner_block_size();
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, inner_block.m,
inner_block.n, need_pack);
inner_block.n, m_matmul_algo->packmode());
auto im2col_kern_param = get_matmul_kern_param(
param, ohw_tile_size, only_packA ? oc_tile_size : OC);
size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
......@@ -418,7 +419,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
size_t nopack_default_blockn = 16;
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
nopack_default_blockm, nopack_default_blockn,
need_pack);
m_matmul_algo->packmode());
packa_group_size = 0;
}
......@@ -488,19 +489,20 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
if (default_pack || only_packA) {
auto inner_block = m_matmul_algo->get_inner_block_size();
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
inner_block.m, inner_block.n, default_pack);
} else { //! not support pack,not need pack
inner_block.m, inner_block.n,
m_matmul_algo->packmode());
} else { //! nopack_mode
size_t nopack_default_blockm = 8;
size_t nopack_default_blockn = 16;
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
nopack_default_blockm, nopack_default_blockn,
no_pack);
m_matmul_algo->packmode());
}
size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size);
size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
size_t packa_parallel_times = 0;
size_t pack_oc_size = get_format_pack_size(param.filter_meta.format);
size_t pack_oc_size = pack_size(param.filter_meta.format);
if (only_packA) {
packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
......@@ -639,9 +641,15 @@ bool ConvBiasImpl::AlgoIm2col::usable(
ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 2) {
if (opr->param().format != param::ConvBias::Format::NCHW &&
opr->param().format != param::ConvBias::Format::NCHW44_DOT &&
opr->param().format != param::ConvBias::Format::NCHW44) {
return false;
}
//! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode is
//! identity otherwise return false mean that 8x8x32 and 8x8x16 not support
//! PostProcess
//! identity; otherwise return false, meaning that 8x8x32 and 8x8x16 do not
//! support PostProcess
if (param.src_type.enumv() == param.filter_type.enumv() &&
((param.src_type.enumv() == DTypeEnum::Int8 &&
(param.dst_type.enumv() == DTypeEnum::Int16 ||
......@@ -653,9 +661,10 @@ bool ConvBiasImpl::AlgoIm2col::usable(
param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
return false;
}
if (opr->param().format == param::ConvBias::Format::NCHW44) {
if (opr->param().format == param::ConvBias::Format::NCHW44 ||
opr->param().format == param::ConvBias::Format::NCHW44_DOT) {
//! currently NCHW44 im2col only supports DEFAULT-mode matmul
if(m_matmul_algo->packmode() != Pack_Mode::DEFAULT) {
if (m_matmul_algo->packmode() != Pack_Mode::DEFAULT) {
return false;
//! nchw44 hybrid mode and channel wise are not supported
} else if (param.filter_meta.icpg < 4_z ||
......@@ -668,29 +677,27 @@ bool ConvBiasImpl::AlgoIm2col::usable(
size_t oc_tile_size = 0, ohw_tile_size = 0;
Pack_Mode packmode = m_matmul_algo->packmode();
bool default_pack = packmode == Pack_Mode::DEFAULT;
bool no_pack = packmode == Pack_Mode::NO_PACK;
bool only_packA = packmode == Pack_Mode::ONLY_PACKA;
if (default_pack || only_packA) {
auto inner_block = m_matmul_algo->get_inner_block_size();
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
inner_block.m, inner_block.n, default_pack);
inner_block.m, inner_block.n,
m_matmul_algo->packmode());
} else { //! pack not supported, so no packing is needed
size_t nopack_default_blockm = 8;
size_t nopack_default_blockn = 16;
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
nopack_default_blockm, nopack_default_blockn,
no_pack);
m_matmul_algo->packmode());
}
fallback::MatrixMulImpl::KernSizeParam matmul_param =
get_matmul_kern_param(param, ohw_tile_size, oc_tile_size);
bool matmulusable = m_matmul_algo->usable(matmul_param);
return matmulusable &&
(opr->param().format == param::ConvBias::Format::NCHW ||
opr->param().format == param::ConvBias::Format::NCHW44) &&
(!(param.filter_meta.spatial[0] ==
param.filter_meta.spatial[1] &&
(param.filter_meta.spatial[0] == 1) &&
param.filter_meta.spatial[0] == 1 &&
param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
param.filter_meta.stride[0] == 1)) &&
(param.filter_meta.dilation[0] ==
......
......@@ -36,10 +36,10 @@ class ConvBiasImpl::AlgoIm2col final : public AlgoBase {
const NCBKernSizeParam& param, size_t ohw_tile_size,
size_t oc_tile_size) const;
WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const;
void choice_ohw_oc_block(const NCBKernSizeParam& param,
size_t& oc_tile_size, size_t& ohw_tile_size,
size_t block_m, size_t block_n,
bool pack_default) const;
void choice_ohw_oc_block(
const NCBKernSizeParam& param, size_t& oc_tile_size,
size_t& ohw_tile_size, size_t block_m, size_t block_n,
fallback::MatrixMulImpl::AlgoBase::PackMode pack_mode) const;
public:
AlgoIm2col(MatrixMulImpl::AlgoBase* matmul_algo, size_t ohw_tile_size)
......
......@@ -230,7 +230,11 @@ public:
PostprocessMode::FLOAT,
"DefaultStrategyTypeNCHW44::FLOAT"_hash);
} else {
megdnn_throw("not support format except nchw44 and nchw\n");
megdnn_throw(
ssprintf("Currently only support layout "
"NCHW44/NCHW for im2col "
"algo, but got %d\n",
uint32_t(format)));
}
break;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
......@@ -252,12 +256,17 @@ public:
cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"DefaultStrategyType::INT8x8x32"_hash);
} else if (format == param::ConvBias::Format::NCHW44) {
} else if (format == param::ConvBias::Format::NCHW44 ||
format == param::ConvBias::Format::NCHW44_DOT) {
cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"DefaultStrategyType::INT8x8x32"_hash);
} else {
megdnn_throw("not support format except nchw44 and nchw\n");
megdnn_throw(
ssprintf("Currently only support layout "
         "NCHW44/NCHW/NCHW44_DOT for im2col "
"algo, but got %d\n",
uint32_t(format)));
}
break;
......@@ -288,13 +297,18 @@ public:
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS,
"DefaultStrategyTypeNCHW::QINT8x8x32"_hash);
} else if (format == param::ConvBias::Format::NCHW44) {
} else if (format == param::ConvBias::Format::NCHW44 ||
format == param::ConvBias::Format::NCHW44_DOT) {
cb2(NCHW44, DEFAULT, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"DefaultStrategyTypeHCHW44::QINT8x8x32"_hash);
} else {
megdnn_throw("not support format except nchw44 and nchw\n");
megdnn_throw(
ssprintf("Currently only support layout "
         "NCHW44/NCHW/NCHW44_DOT for im2col "
"algo, but got %d\n",
uint32_t(format)));
}
break;
......@@ -304,17 +318,22 @@ public:
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8,
PostprocessMode::QUANTIZED,
"DefaultStrategyType::QINT8x8x32x8"_hash);
} else if (format == param::ConvBias::Format::NCHW44) {
} else if (format == param::ConvBias::Format::NCHW44 ||
format == param::ConvBias::Format::NCHW44_DOT) {
cb2(NCHW44, DEFAULT, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8,
dt_int32, dt_int8, PostprocessMode::QUANTIZED,
"DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash);
} else {
megdnn_throw("not support format except nchw44 and nchw\n");
megdnn_throw(ssprintf("Currently only support layout "
                      "NCHW44/NCHW/NCHW44_DOT for im2col "
"algo, but got %d\n",
uint32_t(format)));
}
break;
}
megdnn_throw("error not support strategy type ");
megdnn_throw(ssprintf("Unsupported strategy type %u in default mode",
uint32_t(strategytype)));
}
static std::unique_ptr<StrategyBase> make_nopack_strategy(
......@@ -328,10 +347,6 @@ public:
PostprocessMode::FLOAT, "NoPackStrategyType::FLOAT"_hash);
break;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case StrategyType::FLOAT_FP16:
cb1(NCHW, NO_PACK, dt_float16, __fp16, PostprocessMode::FLOAT,
"NoPackStrategyType::FLOAT_FP16"_hash);
break;
#else
#if !MEGDNN_DISABLE_FLOAT16
case StrategyType::FLOAT16_FLOAT16:
......@@ -341,48 +356,24 @@ public:
break;
#endif
#endif
case StrategyType::INT8x8x32:
cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"NoPackStrategyType::INT8x8x32"_hash);
break;
case StrategyType::INT8x8x16:
cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8,
dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
"NoPackStrategyType::INT8x8x16"_hash);
break;
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
case StrategyType::QUINT8x8x32:
cb2(NCHW, NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32,
dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS,
"NoPackStrategyType::QUINT8x8x32"_hash);
break;
case StrategyType::QUINT8x8x32x8:
cb2(NCHW, NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32,
dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8,
PostprocessMode::QUANTIZED,
"NoPackStrategyType::QUINT8x8x32x8"_hash);
break;
#endif
case StrategyType::QINT8x8x32:
cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32,
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS,
"NoPackStrategyType::QINT8x8x32"_hash);
case StrategyType::INT8x8x32:
cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"NoPackStrategyType::INT8x8x32"_hash);
break;
case StrategyType::QINT8x8x32x8:
cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32,
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8,
PostprocessMode::QUANTIZED,
"NoPackStrategyType::QINT8x8x32x8"_hash);
default:
megdnn_throw(
ssprintf("Unsupported strategy type %u in no_pack mode",
uint32_t(strategytype)));
break;
}
megdnn_throw("error not support strategy type ");
megdnn_throw(ssprintf("Unsupported strategy type %u in no_pack mode",
uint32_t(strategytype)));
}
static std::unique_ptr<StrategyBase> make_onlypacka_strategy(
......@@ -396,63 +387,14 @@ public:
PostprocessMode::FLOAT,
"OnlyPackaStrategyType::FLOAT"_hash);
break;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case StrategyType::FLOAT_FP16:
cb1(NCHW, ONLY_PACKA, dt_float16, __fp16,
PostprocessMode::FLOAT,
"OnlyPackaStrategyType::FLOAT_FP16"_hash);
break;
#else
#if !MEGDNN_DISABLE_FLOAT16
case StrategyType::FLOAT16_FLOAT16:
cb1(NCHW, ONLY_PACKA, dt_float16, dt_float16,
PostprocessMode::NO_PROCESS,
"OnlyPackaStrategyType::FLOAT16_FLOAT16"_hash);
break;
#endif
#endif
case StrategyType::INT8x8x32:
cb2(NCHW, ONLY_PACKA, dt_int8, dt_int32, dt_int32, dt_int8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"OnlyPackaStrategyType::INT8x8x32"_hash);
break;
case StrategyType::INT8x8x16:
cb2(NCHW, ONLY_PACKA, dt_int8, dt_int16, dt_int16, dt_int8,
dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
"OnlyPackaStrategyType::INT8x8x16"_hash);
break;
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
case StrategyType::QUINT8x8x32:
cb2(NCHW, ONLY_PACKA, dtype::Quantized8Asymm,
dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"OnlyPackaStrategyType::QUINT8x8x32"_hash);
break;
case StrategyType::QUINT8x8x32x8:
cb2(NCHW, ONLY_PACKA, dtype::Quantized8Asymm,
dtype::QuantizedS32, dtype::Quantized8Asymm, dt_uint8,
dt_int32, dt_uint8, PostprocessMode::QUANTIZED,
"OnlyPackaStrategyType::QUINT8x8x32x8"_hash);
break;
#endif
case StrategyType::QINT8x8x32:
cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32,
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS,
"OnlyPackaStrategyType::QINT8x8x32"_hash);
break;
case StrategyType::QINT8x8x32x8:
cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32,
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8,
PostprocessMode::QUANTIZED,
"OnlyPackaStrategyType::QINT8x8x32x8"_hash);
default:
megdnn_throw(ssprintf(
"Unsupported strategy type %u in onlypacka mode",
uint32_t(strategytype)));
break;
}
megdnn_throw("error not support strategy type ");
megdnn_throw(ssprintf("Unsupported strategy type %u in onlypacka mode",
uint32_t(strategytype)));
}
#undef cb1
......
......@@ -11,6 +11,16 @@
#pragma once
#include "src/fallback/conv_bias/opr_impl.h"
#if MEGDNN_X86
#include "src/x86/conv_bias/postprocess_helper.h"
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
#include "src/arm_common/conv_bias/postprocess_helper.h"
#endif
using namespace megdnn;
#if MEGDNN_X86
using namespace x86;
#endif
namespace megdnn {
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode;
......@@ -72,6 +82,185 @@ public:
const StrategyParam& sparam, WorkspaceBundle bundle_thread) = 0;
};
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode, PackMode packmode,
FormatMode format>
//! This class is a new base class for StrategyDefault, StrategyNoPack and
//! so on, so that the copy-padding code is shared in one place
class StrategyBridge : public StrategyBase {
public:
constexpr static size_t BUNDLE_PADDING_INDEX = 0;
StrategyBridge() = default;
virtual void copy_padding_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_oc_size) override {
UNPACK_CONV_F32_NCB_KERN_SIZES(param);
MEGDNN_MARK_USED_VAR(N);
MEGDNN_MARK_USED_VAR(OC);
MEGDNN_MARK_USED_VAR(OH);
MEGDNN_MARK_USED_VAR(OW);
MEGDNN_MARK_USED_VAR(FH);
MEGDNN_MARK_USED_VAR(FW);
MEGDNN_MARK_USED_VAR(SH);
MEGDNN_MARK_USED_VAR(SW);
size_t IW2 = IW + 2 * PW;
size_t IH2 = IH + 2 * PH;
size_t batch_id = ncb_index.ndrange_id[0];
size_t group_id = ncb_index.ndrange_id[1];
size_t channel_id = ncb_index.ndrange_id[2];
size_t PH_SIZE = PH * IW2 * pack_oc_size;
PW = PW * pack_oc_size;
IW = IW * pack_oc_size;
size_t padding_group_size = IH2 * IW2 * IC;
size_t workspace_channel_offset = pack_oc_size * IH2 * IW2 * channel_id;
size_t workspace_group_offset = group_id * padding_group_size;
size_t workspace_batch_offset =
param.filter_meta.group * batch_id * padding_group_size;
bundle.set(param.workspace_ptr);
src_ctype src_zp = static_cast<src_ctype>(0);
if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point;
}
src_ctype* src = const_cast<src_ctype*>(param.src<src_ctype>(
batch_id, group_id, channel_id, 1, pack_oc_size));
src_ctype* src2;
src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) +
workspace_group_offset + workspace_batch_offset +
workspace_channel_offset;
src_ctype* src2_ptr = src2;
const src_ctype* src_ptr = src;
if (PH != 0) {
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE);
src2_ptr += PH_SIZE;
}
rep(ih, IH) {
if (PW != 0)
rep(pw, PW) * (src2_ptr++) = src_zp;
std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW);
src2_ptr += IW;
src_ptr += IW;
if (PW != 0)
rep(pw, PW) * (src2_ptr++) = src_zp;
}
if (PH != 0) {
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE);
src2_ptr += PH_SIZE;
}
}
};
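To sanity-check the offset arithmetic in copy_padding_kern, a worked example with hypothetical sizes (not taken from the commit):

    // IH = IW = 4, PH = PW = 1, pack_oc_size = 4.
    size_t IH2 = 4 + 2 * 1, IW2 = 4 + 2 * 1;  // 6 x 6 padded plane
    size_t PH_SIZE = 1 * IW2 * 4;             // 24 scalars per top/bottom strip
    size_t row = (1 + 4 + 1) * 4;             // 24 scalars per padded row
    // One padded channel group: 2 * PH_SIZE + 4 * row = 48 + 96 = 144,
    // which matches IH2 * IW2 * pack_oc_size = 6 * 6 * 4.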
namespace {
template <typename bias_ctype>
inline void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam,
size_t matmul_bundle_index) {
if (sparam.is_dst_8bit || !sparam.is_ohw_size_bigger) {
return static_cast<void*>(bundle_thread.get(matmul_bundle_index));
} else {
bias_ctype* dst =
param.dst<bias_ctype>(sparam.batch_id, sparam.group_id) +
sparam.oc_cur_index * sparam.ohw;
return static_cast<void*>(dst);
}
}
template <typename bias_ctype>
inline void* get_bias_temp_ptr(
const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread, size_t bias_bundle_index) {
bias_ctype* bias_tmp_ptr =
param.bias_mode == megdnn::BiasMode::BIAS
? static_cast<bias_ctype*>(
bundle_thread.get(bias_bundle_index))
: nullptr;
return bias_tmp_ptr;
}
template <typename dst_ctype>
void copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
const void* matmul_dst, const StrategyParam& sparam) {
if (!sparam.skip_copy_dst) {
size_t pack_oc_size = sparam.pack_oc_size;
dst_ctype* dst_tmp_ptr =
reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst));
dst_ctype* dst =
param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) +
sparam.oc_cur_index * sparam.ohw +
sparam.ohw_cur_index * pack_oc_size;
size_t oc_loop = sparam.output_block_oc_size / pack_oc_size;
for (size_t oc = 0; oc < oc_loop; oc++) {
std::memcpy(dst, dst_tmp_ptr,
sizeof(dst_ctype) * sparam.output_block_size *
pack_oc_size);
dst_tmp_ptr += sparam.output_block_size * pack_oc_size;
dst += sparam.ohw * pack_oc_size;
}
}
}
template <typename bias_ctype>
void copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param,
WorkspaceBundle bundle_thread, const StrategyParam& sparam,
size_t bias_index) {
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>(
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id));
bias_ctype* bias_temp_ptr = static_cast<bias_ctype*>(
get_bias_temp_ptr<bias_ctype>(param, bundle_thread, bias_index));
if (param.bias_mode == megdnn::BiasMode::BIAS) {
bias_ctype* copy_dst = bias_temp_ptr;
size_t pack_oc_size = sparam.pack_oc_size;
const bias_ctype* copy_src = bias_ptr +
sparam.oc_cur_index * sparam.ohw +
sparam.ohw_cur_index * pack_oc_size;
for (size_t oc = sparam.oc_cur_index / pack_oc_size;
oc < sparam.oc_end_index / pack_oc_size; oc++) {
std::memcpy(copy_dst, copy_src,
sizeof(bias_ctype) * sparam.output_block_size *
pack_oc_size);
copy_dst += sparam.output_block_size * pack_oc_size;
copy_src += sparam.ohw * pack_oc_size;
}
}
}
template <typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void do_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, WorkspaceBundle bundle_thread,
size_t matmul_bundle_index, size_t bias_bundle_index) {
copy_bias<bias_ctype>(param, bundle_thread, sparam, bias_bundle_index);
void* matmul_dst = get_matmul_dst_ptr<bias_ctype>(
param, bundle_thread, sparam, matmul_bundle_index);
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>(
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id));
void* bias_temp_ptr = get_bias_temp_ptr<bias_ctype>(param, bundle_thread,
bias_bundle_index);
void* bias_preprocess_ptr = const_cast<void*>(
param.bias_mode == megdnn::BiasMode::BIAS
? bias_temp_ptr
: static_cast<void*>(const_cast<bias_ctype*>(
bias_ptr + sparam.oc_cur_index)));
size_t pack_oc_size = sparam.pack_oc_size;
PostProcess<op_ctype, op_dtype, postprocess_mode>::run(
matmul_dst, bias_preprocess_ptr, matmul_dst, param.bias_mode,
param.nonlineMode, param.bias_type, param.dst_type, 1_z,
sparam.output_block_oc_size / pack_oc_size, 1_z,
sparam.output_block_size, pack_oc_size);
copy_dst<dst_ctype>(param, matmul_dst, sparam);
}
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode, PackMode packmode,
......@@ -82,7 +271,10 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::DEFAULT> : public StrategyBase {
postprocess_mode, PackMode::DEFAULT>
: public StrategyBridge<src_ctype, bias_ctype, dst_ctype, op_ctype,
op_dtype, postprocess_mode, PackMode::DEFAULT,
FormatMode::NCHW> {
public:
constexpr static size_t BUNDLE_PADDING_INDEX = 0;
constexpr static size_t BUNDLE_PACKA_INDEX = 1;
......@@ -92,13 +284,7 @@ public:
Strategy() = default;
void copy_padding_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;
void packA_kern(WorkspaceBundle bundle,
virtual void packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
......@@ -120,16 +306,13 @@ public:
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override;
void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam,
WorkspaceBundle bundle_thread) override;
void copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
const void* matmul_dst, const StrategyParam& sparam);
void copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param,
WorkspaceBundle bundle_thread, const StrategyParam& sparam);
WorkspaceBundle bundle_thread) override {
do_postprocess<bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode>(param, sparam, bundle_thread,
THREAD_BUNDLE_IM2COL_INDEX,
THREAD_BUNDLE_BIAS_INDEX);
}
void* get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread);
void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam);
......@@ -162,7 +345,10 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::NO_PACK> : public StrategyBase {
postprocess_mode, PackMode::NO_PACK>
: public StrategyBridge<src_ctype, bias_ctype, dst_ctype, op_ctype,
op_dtype, postprocess_mode, PackMode::NO_PACK,
FormatMode::NCHW> {
public:
constexpr static size_t BUNDLE_PADDING_INDEX = 0;
constexpr static size_t BUNDLE_PACKA_INDEX = 1;
......@@ -173,12 +359,6 @@ public:
Strategy() = default;
void copy_padding_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;
void packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
......@@ -198,17 +378,6 @@ public:
const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam);
inline void* get_bias_temp_ptr(
const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread) {
bias_ctype* bias_tmp_ptr =
param.bias_mode == megdnn::BiasMode::BIAS
? static_cast<bias_ctype*>(
bundle_thread.get(THREAD_BUNDLE_BIAS_INDEX))
: nullptr;
return bias_tmp_ptr;
}
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param,
......@@ -216,19 +385,22 @@ public:
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam,
WorkspaceBundle bundle_thread) override;
void copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
const void* matmul_dst, const StrategyParam& sparam);
void copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param,
WorkspaceBundle bundle_thread, const StrategyParam& sparam);
WorkspaceBundle bundle_thread) override {
do_postprocess<bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode>(param, sparam, bundle_thread,
THREAD_BUNDLE_MATMULDST_INDEX,
THREAD_BUNDLE_BIAS_INDEX);
}
};
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::ONLY_PACKA> : public StrategyBase {
postprocess_mode, PackMode::ONLY_PACKA>
: public StrategyBridge<src_ctype, bias_ctype, dst_ctype, op_ctype,
op_dtype, postprocess_mode,
PackMode::ONLY_PACKA, FormatMode::NCHW> {
public:
constexpr static size_t BUNDLE_PADDING_INDEX = 0;
constexpr static size_t BUNDLE_PACKA_INDEX = 1;
......@@ -239,12 +411,6 @@ public:
Strategy() = default;
void copy_padding_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;
void packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam,
......@@ -269,24 +435,15 @@ public:
void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam);
inline void* get_bias_temp_ptr(
const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread) {
bias_ctype* bias_tmp_ptr =
param.bias_mode == megdnn::BiasMode::BIAS
? static_cast<bias_ctype*>(
bundle_thread.get(THREAD_BUNDLE_BIAS_INDEX))
: nullptr;
return bias_tmp_ptr;
}
void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam,
WorkspaceBundle bundle_thread) override;
void copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
const void* matmul_dst, const StrategyParam& sparam);
void copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param,
WorkspaceBundle bundle_thread, const StrategyParam& sparam);
WorkspaceBundle bundle_thread) override {
do_postprocess<bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode>(param, sparam, bundle_thread,
THREAD_BUNDLE_MATMULDST_INDEX,
THREAD_BUNDLE_BIAS_INDEX);
}
};
} // namespace megdnn
......
......@@ -10,85 +10,9 @@
*/
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/convolution/img2col_helper.h"
#if MEGDNN_X86
#include "src/x86/conv_bias/postprocess_helper.h"
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
#include "src/arm_common/conv_bias/postprocess_helper.h"
#endif
using namespace megdnn;
#if MEGDNN_X86
using namespace x86;
#endif
namespace megdnn {
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::DEFAULT>::
copy_padding_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_oc_size) {
UNPACK_CONV_F32_NCB_KERN_SIZES(param);
MEGDNN_MARK_USED_VAR(N);
MEGDNN_MARK_USED_VAR(OC);
MEGDNN_MARK_USED_VAR(OH);
MEGDNN_MARK_USED_VAR(OW);
MEGDNN_MARK_USED_VAR(FH);
MEGDNN_MARK_USED_VAR(FW);
MEGDNN_MARK_USED_VAR(SH);
MEGDNN_MARK_USED_VAR(SW);
size_t IW2 = IW + 2 * PW;
size_t IH2 = IH + 2 * PH;
size_t batch_id = ncb_index.ndrange_id[0];
size_t group_id = ncb_index.ndrange_id[1];
size_t channel_id = ncb_index.ndrange_id[2];
size_t PH_SIZE = PH * IW2 * pack_oc_size;
PW = PW * pack_oc_size;
IW = IW * pack_oc_size;
size_t padding_group_size = IH2 * IW2 * IC;
size_t workspace_channel_offset = pack_oc_size * IH2 * IW2 * channel_id;
size_t workspace_group_offset = group_id * padding_group_size;
size_t workspace_batch_offset =
param.filter_meta.group * batch_id * padding_group_size;
bundle.set(param.workspace_ptr);
src_ctype src_zp = static_cast<src_ctype>(0);
if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point;
}
src_ctype* src = const_cast<src_ctype*>(param.src<src_ctype>(
batch_id, group_id, channel_id, 1, pack_oc_size));
src_ctype* src2;
src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) +
workspace_group_offset + workspace_batch_offset +
workspace_channel_offset;
src_ctype* src2_ptr = src2;
const src_ctype* src_ptr = src;
if (PH != 0) {
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE);
src2_ptr += PH_SIZE;
}
rep(ih, IH) {
if (PW != 0)
rep(pw, PW) * (src2_ptr++) = src_zp;
std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW);
src2_ptr += IW;
src_ptr += IW;
if (PW != 0)
rep(pw, PW) * (src2_ptr++) = src_zp;
}
if (PH != 0) {
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE);
src2_ptr += PH_SIZE;
}
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
......@@ -244,100 +168,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
matmul_kern_naked(matmul_param, a_panel, b_panel);
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::DEFAULT>::
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam,
WorkspaceBundle bundle_thread) {
copy_bias(param, bundle_thread, sparam);
void* matmul_dst = get_matmul_dst_ptr(param, bundle_thread, sparam);
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>(
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id));
void* bias_temp_ptr = get_bias_temp_ptr(param, bundle_thread);
void* bias_preprocess_ptr = const_cast<void*>(
param.bias_mode == megdnn::BiasMode::BIAS
? bias_temp_ptr
: static_cast<void*>(const_cast<bias_ctype*>(
bias_ptr + sparam.oc_cur_index)));
size_t pack_oc_size = sparam.pack_oc_size;
PostProcess<op_ctype, op_dtype, postprocess_mode>::run(
matmul_dst, bias_preprocess_ptr, matmul_dst, param.bias_mode,
param.nonlineMode, param.bias_type, param.dst_type, 1_z,
sparam.output_block_oc_size / pack_oc_size, 1_z,
sparam.output_block_size, pack_oc_size);
copy_dst(param, matmul_dst, sparam);
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::DEFAULT>::
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
const void* matmul_dst, const StrategyParam& sparam) {
if (!sparam.skip_copy_dst) {
size_t pack_oc_size = sparam.pack_oc_size;
dst_ctype* dst_tmp_ptr =
reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst));
dst_ctype* dst =
param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) +
sparam.oc_cur_index * sparam.ohw +
sparam.ohw_cur_index * pack_oc_size;
size_t oc_loop = sparam.output_block_oc_size / pack_oc_size;
for (size_t oc = 0; oc < oc_loop; oc++) {
std::memcpy(dst, dst_tmp_ptr,
sizeof(dst_ctype) * sparam.output_block_size *
pack_oc_size);
dst_tmp_ptr += sparam.output_block_size * pack_oc_size;
dst += sparam.ohw * pack_oc_size;
}
}
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::DEFAULT>::
get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread) {
bias_ctype* bias_tmp_ptr =
param.bias_mode == megdnn::BiasMode::BIAS
? static_cast<bias_ctype*>(
bundle_thread.get(THREAD_BUNDLE_BIAS_INDEX))
: nullptr;
return bias_tmp_ptr;
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::DEFAULT>::
copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param,
WorkspaceBundle bundle_thread, const StrategyParam& sparam) {
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>(
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id));
bias_ctype* bias_temp_ptr =
static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread));
if (param.bias_mode == megdnn::BiasMode::BIAS) {
bias_ctype* copy_dst = bias_temp_ptr;
const bias_ctype* copy_src = bias_ptr +
sparam.oc_cur_index * sparam.ohw +
sparam.ohw_cur_index;
for (size_t oc = sparam.oc_cur_index; oc < sparam.oc_end_index; oc++) {
std::memcpy(copy_dst, copy_src,
sizeof(bias_ctype) * sparam.output_block_size);
copy_dst += sparam.output_block_size;
copy_src += sparam.ohw;
}
}
}
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode) \
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
......
......@@ -11,81 +11,9 @@
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/convolution/img2col_helper.h"
#if MEGDNN_X86
#include "src/x86/conv_bias/postprocess_helper.h"
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
#include "src/arm_common/conv_bias/postprocess_helper.h"
#endif
using namespace megdnn;
#if MEGDNN_X86
using namespace x86;
#endif
namespace megdnn {
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::NO_PACK>::
copy_padding_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t) {
UNPACK_CONV_F32_NCB_KERN_SIZES(param);
MEGDNN_MARK_USED_VAR(N);
MEGDNN_MARK_USED_VAR(OC);
MEGDNN_MARK_USED_VAR(OH);
MEGDNN_MARK_USED_VAR(OW);
MEGDNN_MARK_USED_VAR(FH);
MEGDNN_MARK_USED_VAR(FW);
MEGDNN_MARK_USED_VAR(SH);
MEGDNN_MARK_USED_VAR(SW);
size_t IW2 = IW + 2 * PW;
size_t IH2 = IH + 2 * PH;
size_t batch_id = ncb_index.ndrange_id[0];
size_t group_id = ncb_index.ndrange_id[1];
size_t channel_id = ncb_index.ndrange_id[2];
size_t padding_group_size = IH2 * IW2 * IC;
size_t workspace_channel_offset = IH2 * IW2 * channel_id;
size_t workspace_group_offset = group_id * padding_group_size;
size_t workspace_batch_offset =
param.filter_meta.group * batch_id * padding_group_size;
bundle.set(param.workspace_ptr);
src_ctype src_zp = static_cast<src_ctype>(0);
if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point;
}
src_ctype* src = const_cast<src_ctype*>(
param.src<src_ctype>(batch_id, group_id, channel_id));
src_ctype* src2;
src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) +
workspace_group_offset + workspace_batch_offset +
workspace_channel_offset;
src_ctype* src2_ptr = src2;
const src_ctype* src_ptr = src;
if (PH != 0) {
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2);
src2_ptr += PH * IW2;
}
rep(ih, IH) {
if (PW != 0)
rep(pw, PW) * (src2_ptr++) = src_zp;
std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW);
src2_ptr += IW;
src_ptr += IW;
if (PW != 0)
rep(pw, PW) * (src2_ptr++) = src_zp;
}
if (PH != 0) {
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2);
src2_ptr += PH * IW2;
}
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
......@@ -220,81 +148,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
}
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::NO_PACK>::
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam,
WorkspaceBundle bundle_thread) {
copy_bias(param, bundle_thread, sparam);
void* matmul_dst = get_matmul_dst_ptr(param, bundle_thread, sparam);
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>(
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id));
bias_ctype* bias_temp_ptr =
static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread));
PostProcess<op_ctype, op_dtype, postprocess_mode>::run(
matmul_dst,
const_cast<void*>(
param.bias_mode == megdnn::BiasMode::BIAS
? bias_temp_ptr
: static_cast<void*>(const_cast<bias_ctype*>(
bias_ptr + sparam.oc_cur_index))),
matmul_dst, param.bias_mode, param.nonlineMode, param.bias_type,
param.dst_type, 1_z, sparam.output_block_oc_size, 1_z,
sparam.output_block_size);
copy_dst(param, matmul_dst, sparam);
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::NO_PACK>::
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
const void* matmul_dst, const StrategyParam& sparam) {
if (!sparam.skip_copy_dst) {
dst_ctype* dst_tmp_ptr =
reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst));
dst_ctype* dst =
param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) +
sparam.oc_cur_index * sparam.ohw + sparam.ohw_cur_index;
for (size_t oc = 0; oc < sparam.output_block_oc_size; oc++) {
std::memcpy(dst, dst_tmp_ptr,
sizeof(dst_ctype) * sparam.output_block_size);
dst_tmp_ptr += sparam.output_block_size;
dst += sparam.ohw;
}
}
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::NO_PACK>::
copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param,
WorkspaceBundle bundle_thread, const StrategyParam& sparam) {
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>(
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id));
bias_ctype* bias_temp_ptr =
static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread));
if (param.bias_mode == megdnn::BiasMode::BIAS) {
bias_ctype* copy_dst = bias_temp_ptr;
const bias_ctype* copy_src = bias_ptr +
sparam.oc_cur_index * sparam.ohw +
sparam.ohw_cur_index;
for (size_t oc = sparam.oc_cur_index; oc < sparam.oc_end_index; oc++) {
std::memcpy(copy_dst, copy_src,
sizeof(bias_ctype) * sparam.output_block_size);
copy_dst += sparam.output_block_size;
copy_src += sparam.ohw;
}
}
}
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode) \
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
......@@ -302,34 +155,18 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32,
megdnn::PostprocessMode::FLOAT)
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16,
megdnn::PostprocessMode::NO_PROCESS)
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32,
megdnn::PostprocessMode::NO_PROCESS)
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16,
megdnn::PostprocessMode::FLOAT)
#else
#if !MEGDNN_DISABLE_FLOAT16
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
megdnn::PostprocessMode::NO_PROCESS)
#endif
#endif
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! x86 does not have uint8 matmul, so only armv7/armv8 support uint8
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8,
megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32,
megdnn::PostprocessMode::NO_PROCESS)
#endif
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8,
megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32,
megdnn::PostprocessMode::NO_PROCESS)
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16,
megdnn::PostprocessMode::NO_PROCESS)
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32,
megdnn::PostprocessMode::NO_PROCESS)
#undef INSTANTIAL_CLASS
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -11,81 +11,9 @@
#include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/convolution/img2col_helper.h"
#if MEGDNN_X86
#include "src/x86/conv_bias/postprocess_helper.h"
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
#include "src/arm_common/conv_bias/postprocess_helper.h"
#endif
using namespace megdnn;
#if MEGDNN_X86
using namespace x86;
#endif
namespace megdnn {
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::ONLY_PACKA>::
copy_padding_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t) {
UNPACK_CONV_F32_NCB_KERN_SIZES(param);
MEGDNN_MARK_USED_VAR(N);
MEGDNN_MARK_USED_VAR(OC);
MEGDNN_MARK_USED_VAR(OH);
MEGDNN_MARK_USED_VAR(OW);
MEGDNN_MARK_USED_VAR(FH);
MEGDNN_MARK_USED_VAR(FW);
MEGDNN_MARK_USED_VAR(SH);
MEGDNN_MARK_USED_VAR(SW);
size_t IW2 = IW + 2 * PW;
size_t IH2 = IH + 2 * PH;
size_t batch_id = ncb_index.ndrange_id[0];
size_t group_id = ncb_index.ndrange_id[1];
size_t channel_id = ncb_index.ndrange_id[2];
size_t padding_group_size = IH2 * IW2 * IC;
size_t workspace_channel_offset = IH2 * IW2 * channel_id;
size_t workspace_group_offset = group_id * padding_group_size;
size_t workspace_batch_offset =
param.filter_meta.group * batch_id * padding_group_size;
bundle.set(param.workspace_ptr);
src_ctype src_zp = static_cast<src_ctype>(0);
if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point;
}
src_ctype* src = const_cast<src_ctype*>(
param.src<src_ctype>(batch_id, group_id, channel_id));
src_ctype* src2;
src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) +
workspace_group_offset + workspace_batch_offset +
workspace_channel_offset;
src_ctype* src2_ptr = src2;
const src_ctype* src_ptr = src;
if (PH != 0) {
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2);
src2_ptr += PH * IW2;
}
rep(ih, IH) {
if (PW != 0)
rep(pw, PW) * (src2_ptr++) = src_zp;
std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW);
src2_ptr += IW;
src_ptr += IW;
if (PW != 0)
rep(pw, PW) * (src2_ptr++) = src_zp;
}
if (PH != 0) {
std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2);
src2_ptr += PH * IW2;
}
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
......@@ -120,25 +48,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
matmul_algo->pack_A(matmul_param, a_panel, 0_z, 0_z);
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::ONLY_PACKA>::
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam) {
if (sparam.is_dst_8bit || !sparam.is_ohw_size_bigger) {
return static_cast<void*>(
bundle_thread.get(THREAD_BUNDLE_MATMULDST_INDEX));
} else {
bias_ctype* dst =
param.dst<bias_ctype>(sparam.batch_id, sparam.group_id) +
sparam.oc_cur_index * sparam.ohw;
return static_cast<void*>(dst);
}
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
......@@ -241,63 +150,19 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::ONLY_PACKA>::
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam,
WorkspaceBundle bundle_thread) {
void* matmul_dst = get_matmul_dst_ptr(param, bundle_thread, sparam);
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>(
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id));
bias_ctype* bias_temp_ptr =
static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread));
if (param.bias_mode == megdnn::BiasMode::BIAS) {
bias_ctype* copy_dst = bias_temp_ptr;
const bias_ctype* copy_src = bias_ptr +
sparam.oc_cur_index * sparam.ohw +
sparam.ohw_cur_index;
for (size_t oc = sparam.oc_cur_index; oc < sparam.oc_end_index; oc++) {
std::memcpy(copy_dst, copy_src,
sizeof(bias_ctype) * sparam.output_block_size);
copy_dst += sparam.output_block_size;
copy_src += sparam.ohw;
}
}
PostProcess<op_ctype, op_dtype, postprocess_mode>::run(
matmul_dst,
const_cast<void*>(
param.bias_mode == megdnn::BiasMode::BIAS
? bias_temp_ptr
: static_cast<void*>(const_cast<bias_ctype*>(
bias_ptr + sparam.oc_cur_index))),
matmul_dst, param.bias_mode, param.nonlineMode, param.bias_type,
param.dst_type, 1_z, sparam.output_block_oc_size, 1_z,
sparam.output_block_size);
copy_dst(param, matmul_dst, sparam);
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::ONLY_PACKA>::
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
const void* matmul_dst, const StrategyParam& sparam) {
if (!sparam.skip_copy_dst) {
dst_ctype* dst_tmp_ptr =
reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst));
dst_ctype* dst =
param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) +
sparam.oc_cur_index * sparam.ohw + sparam.ohw_cur_index;
for (size_t oc = 0; oc < sparam.output_block_oc_size; oc++) {
std::memcpy(dst, dst_tmp_ptr,
sizeof(dst_ctype) * sparam.output_block_size);
dst_tmp_ptr += sparam.output_block_size;
dst += sparam.ohw;
}
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::ONLY_PACKA>::
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam) {
if (sparam.is_dst_8bit || !sparam.is_ohw_size_bigger) {
return static_cast<bias_ctype*>(
bundle_thread.get(THREAD_BUNDLE_MATMULDST_INDEX));
} else {
bias_ctype* dst =
param.dst<bias_ctype>(sparam.batch_id, sparam.group_id) +
sparam.oc_cur_index * sparam.ohw;
return static_cast<void*>(dst);
}
}
......@@ -310,33 +175,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32,
megdnn::PostprocessMode::FLOAT)
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16,
megdnn::PostprocessMode::FLOAT)
#else
#if !MEGDNN_DISABLE_FLOAT16
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
megdnn::PostprocessMode::NO_PROCESS)
#endif
#endif
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! x86 does not have uint8 matmul, so only armv7/armv8 support uint8
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8,
megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32,
megdnn::PostprocessMode::NO_PROCESS)
#endif
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8,
megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32,
megdnn::PostprocessMode::NO_PROCESS)
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16,
megdnn::PostprocessMode::NO_PROCESS)
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32,
megdnn::PostprocessMode::NO_PROCESS)
#undef INSTANTIAL_CLASS
} // namespace megdnn
......
......@@ -26,7 +26,7 @@
using namespace megdnn;
using namespace fallback;
size_t megdnn::fallback::get_format_pack_size(param::ConvBias::Format format) {
size_t megdnn::fallback::pack_size(param::ConvBias::Format format) {
switch (format) {
case param::ConvBias::Format::NCHW44:
case param::ConvBias::Format::NCHW44_DOT:
......
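The hunk is truncated here; given the rename and the pack_c_size = 4 usage elsewhere in this commit, the whole helper is presumably just:

    size_t megdnn::fallback::pack_size(param::ConvBias::Format format) {
        switch (format) {
            case param::ConvBias::Format::NCHW44:
            case param::ConvBias::Format::NCHW44_DOT:
                return 4;  // four channels packed per position
            default:
                return 1;  // plain NCHW
        }
    }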
......@@ -23,8 +23,10 @@ namespace fallback {
/*!
* \brief get the pack_size according to the format
* Note TODO: once format is removed from the param, this may be
* specified via "opr::param::format" instead
* */
size_t get_format_pack_size(param::ConvBias::Format format);
size_t pack_size(param::ConvBias::Format format);
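//! e.g. NCHW44 / NCHW44_DOT pack 4 channels, so pack_size returns 4 for
//! them and 1 for plain formats such as NCHW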
/*!
* \brief fallback conv bias forward impl
......
......@@ -52,9 +52,21 @@ class GemmInterleaved<Strategy, true> {
}
size_t get_b_workspace_size() const {
#if __ARM_FEATURE_DOTPROD
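        //! the mk4-dot 8x6x4 kernel packs N in blocks of 6 and zero-pads
        //! the remainder up to a multiple of 4, so reserve workspace for
        //! the padded tail as well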
size_t new_blockn = m_strategy.block_n;
if (m_strategy.KERNEL_W == 6 && m_strategy.UNROLL_K == 4 &&
m_strategy.KERNEL_H == 8) {
            new_blockn = round_up<size_t>((m_strategy.block_n - 1) % 6, 4) +
                         m_strategy.block_n / 6 * 6;
}
size_t N = round_up(new_blockn, m_strategy.KERNEL_W);
size_t K = round_up(m_strategy.block_k, m_strategy.UNROLL_K);
return round_up(sizeof(stype) * N * K, CACHELINE_SIZE) + m_align_size;
#else
size_t N = round_up(m_strategy.block_n, m_strategy.KERNEL_W);
size_t K = round_up(m_strategy.block_k, m_strategy.UNROLL_K);
return round_up(sizeof(stype) * N * K, CACHELINE_SIZE) + m_align_size;
#endif
}
//! temporary storage for output, post process such as add bias or relu will
......
......@@ -268,7 +268,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) {
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", false);
#else
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384",
"IM2COLMATMUL:ARMV7_F32:192", true);
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:ARMV7_F32:192", false);
......
......@@ -72,10 +72,12 @@ std::vector<conv_bias::TestArg> get_int8_quint8_conv_bias_args(
std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
std::vector<size_t> kernel_vec, size_t stride, bool no_pad = false,
bool no_bias = false, bool no_nonlinemode = false,
bool is_input_nchw = false, bool support_full_bias = false,
bool support_sigmoid = false) {
bool is_input_nchw = false, bool is_nchw44_dot = false,
bool support_full_bias = false, bool support_sigmoid = false,
bool only_no_bias = false) {
using namespace conv_bias;
using NLMode = param::ConvBias::NonlineMode;
std::vector<TestArg> args;
auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w,
......@@ -102,7 +104,11 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
size_t kernel_h = kernel;
size_t kernel_w = kernel;
param::ConvBias param;
param.format = param::ConvBias::Format::NCHW44;
        if (is_nchw44_dot) {
            param.format = param::ConvBias::Format::NCHW44_DOT;
        } else {
            param.format = param::ConvBias::Format::NCHW44;
        }
param.stride_h = stride;
param.stride_w = stride;
param.pad_h = pad;
......@@ -155,18 +161,22 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
if (support_sigmoid) {
nonlinemode.emplace_back(NLMode::SIGMOID);
}
std::vector<megdnn::BiasMode> bias_mode = {
megdnn::BiasMode::BROADCAST_CHANNEL_BIAS};
if (no_bias) {
std::vector<megdnn::BiasMode> bias_mode;
if (!only_no_bias) {
bias_mode.emplace_back(megdnn::BiasMode::BROADCAST_CHANNEL_BIAS);
if (no_bias) {
bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS);
}
} else {
bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS);
}
if (support_full_bias) {
bias_mode.emplace_back(megdnn::BiasMode::BIAS);
bias_mode.emplace_back(megdnn::BiasMode::BIAS);
}
for (auto bias : bias_mode)
for (auto nlmode : nonlinemode)
for (size_t n : {1, 2})
    for (size_t n : {1, 2})
for (size_t kernel : kernel_vec)
for (size_t oc : {4, 12})
for (size_t ic : {1, 3, 4, 12})
......@@ -361,19 +371,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) {
check_conv_bias(get_nchw44_conv_bias_args({2, 3}, 1, false, false, false,
false, true, true),
false, false, true, true),
handle(), "F32_CONV_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) {
check_conv_bias(get_nchw44_conv_bias_args({5}, 1, false, false, false,
false, true, true),
false, false, true, true),
handle(), "F32_CONV_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) {
check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false,
false, false, true, true),
false, false, false, true, true),
handle(), "F32_CONV_NCHW44_DIRECT");
}
......@@ -1420,6 +1430,111 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) {
#endif
#undef cb
}
#if __ARM_FEATURE_DOTPROD
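//! the tests below exercise the newly added im2col / conv1x1 mk4_dot
//! (NCHW44_DOT) kernels; they only run on CPUs with the dot-product
//! extension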
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT) {
UniformIntRNG rng{-50, 50};
#define cb(name) \
checker_conv_bias(get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, \
false, false, false, true), \
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
dtype::QuantizedS8(60.25f), name); \
checker_conv_bias( \
get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true), \
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
dtype::QuantizedS8(60.25f), name);
float epsilon = 0.001;
#if MEGDNN_AARCH64
cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
#elif MEGDNN_ARMV7
cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X6X4_DOTPROD:96");
#endif
#undef cb
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_S8x8x32_MK4_DOT) {
UniformIntRNG rng{-50, 50};
#define cb(name) \
checker_conv_bias( \
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
true, false, true, false, false, true), \
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \
checker_conv_bias( \
get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \
false, false, true), \
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name);
float epsilon = 0.001;
#if MEGDNN_AARCH64
cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
#elif MEGDNN_ARMV7
cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X6X4_DOTPROD:96");
#endif
#undef cb
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32_MK4_DOT) {
UniformIntRNG rng{-50, 50};
#define cb(name) \
checker_conv_bias( \
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
true, false, true, false, false, true), \
handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \
dtype::Int32(), {}, name); \
checker_conv_bias( \
get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \
false, false, true), \
handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \
dtype::Int32(), {}, name);
float epsilon = 0.001;
#if MEGDNN_AARCH64
cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
#elif MEGDNN_ARMV7
cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X6X4_DOTPROD:96");
#endif
#undef cb
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CONV1x1_QUANTIZEDSYM_MK4_DOT) {
UniformIntRNG rng{-50, 50};
#define cb(name) \
checker_conv_bias( \
get_nchw44_conv_bias_args({1}, 1, true, true, false, false, true), \
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
dtype::QuantizedS8(60.25f), name); \
checker_conv_bias( \
get_nchw44_conv_bias_args({1}, 1, true, true, true, false, true, \
false, false, true), \
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \
checker_conv_bias( \
get_nchw44_conv_bias_args({1}, 1, true, true, true, false, true, \
false, false, true), \
handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \
dtype::Int32(), {}, name);
float epsilon = 0.001;
#if MEGDNN_AARCH64
cb("CONV1x1:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD");
#elif MEGDNN_ARMV7
cb("CONV1x1:AARCH32_INT8_MK4_8X6X4_DOTPROD");
#endif
#undef cb
}
#endif
// clang-format on
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) {
......@@ -1685,8 +1800,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({2, 4, 7}, 1);
std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
            {2, 4, 7}, 1, false, false, false, false, false, true, true);
#if MEGDNN_AARCH64
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
#elif MEGDNN_ARMV7
......@@ -1696,8 +1811,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({3, 5, 6}, 2);
std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
{3, 5, 6}, 2, false, false, false, false, false, true, true);
#if MEGDNN_AARCH64
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
#elif MEGDNN_ARMV7
......
......@@ -897,6 +897,62 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32) {
#undef cb
}
#if MEGDNN_X86_WITH_MKL || MEGDNN_X86_WITH_OPENBLAS
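//! single-threaded fp32 im2col + BLAS test, covering the no-bias,
//! channel-broadcast-bias and full-bias cases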
TEST_F(X86, CONV_BIAS_IM2COLMATMUL_FP32) {
using namespace conv_bias;
std::vector<TestArg> args;
auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel,
size_t p, NonlineMode nonline_mode) {
if (w + 2 * p < kernel || h + 2 * p < kernel)
return;
param::ConvBias param;
param.stride_h = 1;
param.stride_w = 1;
param.pad_h = p;
param.pad_w = p;
param.nonlineMode = nonline_mode;
//! no bias
args.emplace_back(param, TensorShape{1, ic, h, w},
TensorShape{oc, ic, kernel, kernel}, TensorShape{});
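        //! broadcast channel bias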
args.emplace_back(param, TensorShape{1, ic, h, w},
TensorShape{oc, ic, kernel, kernel},
TensorShape{1, oc, 1, 1});
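        //! full bias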
args.emplace_back(
param, TensorShape{1, ic, h, w},
TensorShape{oc, ic, kernel, kernel},
TensorShape{1, oc, (h + 2 * p - kernel) / param.stride_h + 1,
(w + 2 * p - kernel) / param.stride_w + 1});
};
for (size_t kernel : {2, 3, 4, 5, 6, 7})
for (size_t ic : {1, 4, 8, 16})
for (size_t oc : {1, 4, 8, 16, 300})
for (size_t p : {0, 2})
for (size_t size : {8, 24})
for (NonlineMode nonline_mode :
{NonlineMode::IDENTITY, NonlineMode::RELU}) {
run(oc, ic, size, size, kernel, p, nonline_mode);
}
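    //! a single large-OC case, presumably to stress blocking along OC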
run(2046, 8, 20, 20, 3, 1, NonlineMode::IDENTITY);
Checker<ConvBias> checker(handle());
#define cb(algo_name) \
checker.set_before_exec_callback( \
conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name)); \
for (auto&& arg : args) { \
checker.set_param(arg.param).execs( \
{arg.src, arg.filter, arg.bias, {}, {}}); \
}
cb("IM2COLMATMUL:X86_F32_BLAS");
#undef cb
}
#endif
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32_PACKA) {
using namespace conv_bias;
......