提交 e1e56988 编写于 作者: M Megvii Engine Team

feat(dnn/fallback): add conv1x1 filter preprocess funciton

GitOrigin-RevId: 4bd109f2daad00b2aa56eb8e1eaaded9c3a571cb
上级 0f9dec68
......@@ -171,7 +171,10 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoConv1x1::dispatch_kerns(
if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::DEFAULT ||
pack_mode == MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA) {
ret_kern.push_back({kern_packA, {GROUP, oc_blocks_per_group}});
//! if enable filter preprocess kern_packA should not dispatch
if (!is_enable_filter_preprocess(param)) {
ret_kern.push_back({kern_packA, {GROUP, oc_blocks_per_group}});
}
if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::DEFAULT) {
ret_kern.push_back({kern_packB, {1}});
}
......@@ -200,6 +203,13 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param,
return false;
}
//! only matmul's packmode is packa or default support weight preprocess
if (is_enable_filter_preprocess(param) &&
(m_matmul_algo->packmode() ==
fallback::MatrixMulImpl::AlgoBase::PackMode::NO_PACK)) {
return false;
}
if (param.src_type.enumv() != DTypeEnum::Int8 &&
param.src_type.enumv() != DTypeEnum::QuantizedS8 &&
param.src_type.enumv() != DTypeEnum::Quantized8Asymm &&
......@@ -253,6 +263,122 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param,
return false;
}
SmallVector<TensorLayout>
ConvBiasImpl::AlgoConv1x1::deduce_preprocessed_filter_layout(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(
megdnn_fallback_conv1x1,
midout_iv(
"ConvBiasImpl::AlgoConv1x1::deduce_preprocessed_filter_layout"_hash)) {
fallback::MatrixMulImpl::AlgoBase::MatmulDescription matmul_desc =
m_matmul_algo->matmul_description();
bool default_pack = matmul_desc.packmode ==
MatrixMulImpl::AlgoBase::PackMode::DEFAULT;
bool only_packA = matmul_desc.packmode ==
MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA;
//! only support default_pack and only_packa mode
if (matmul_desc.packmode ==
MatrixMulImpl::AlgoBase::PackMode::NO_PACK) {
return {};
}
size_t OH = param.osz[0];
size_t OW = param.osz[1];
size_t compt_oc_block_size = get_oc_tile_size_heuristic(param);
auto matmul_param = utils::get_matmul_kern_param(param, OH * OW,
compt_oc_block_size);
WorkspaceBundle wb(nullptr, {});
if (default_pack) {
Conv1x1Kerns<MatrixMulImpl::AlgoBase::PackMode::DEFAULT> dispatcher;
wb = dispatcher.get_bundle(param, matmul_param, m_matmul_algo,
compt_oc_block_size);
} else if (only_packA) {
Conv1x1Kerns<MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA>
dispatcher;
wb = dispatcher.get_bundle(param, matmul_param, m_matmul_algo,
compt_oc_block_size);
}
size_t GROUP = param.filter_meta.group;
SmallVector<TensorLayout> preprocessed_layouts;
preprocessed_layouts.push_back(
{{GROUP, wb.get_size(0)}, dtype::Int8()});
return preprocessed_layouts;
}
MIDOUT_END();
return {};
}
SmallVector<ConvBiasImpl::NCBKern>
ConvBiasImpl::AlgoConv1x1::dispatch_preprocess_kerns(
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(
megdnn_fallback_conv1x1,
midout_iv(
"ConvBiasImpl::AlgoConv1x1::dispatch_preprocess_kerns"_hash)) {
SmallVector<ConvBiasImpl::NCBKern> ret_kern;
size_t OH = param.osz[0];
size_t OW = param.osz[1];
size_t OC = param.filter_meta.ocpg;
size_t compt_oc_block_size = get_oc_tile_size_heuristic(param);
size_t GROUP = param.filter_meta.group;
size_t oc_blocks_per_group = div_ceil(OC, compt_oc_block_size);
auto matmul_param = utils::get_matmul_kern_param(param, OH * OW,
compt_oc_block_size);
WorkspaceBundle whole_bundle = {nullptr, {}};
WorkspaceBundle matmul_bundle = {nullptr, {}};
auto pack_mode = m_matmul_algo->packmode();
if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::DEFAULT) {
MIDOUT_BEGIN(megdnn_fallback_conv1x1,
midout_iv("get_defaul_matmul_packmode_bundle"_hash)) {
Conv1x1Kerns<MatrixMulImpl::AlgoBase::PackMode::DEFAULT>
dispatcher;
whole_bundle = dispatcher.get_bundle(param, matmul_param,
m_matmul_algo,
compt_oc_block_size);
matmul_bundle = m_matmul_algo->get_bundle(matmul_param);
}
MIDOUT_END();
} else if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA) {
MIDOUT_BEGIN(
megdnn_fallback_conv1x1,
midout_iv("get_onlypacka_matmul_packmode_bundle"_hash)) {
Conv1x1Kerns<MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA>
dispatcher;
whole_bundle = dispatcher.get_bundle(param, matmul_param,
m_matmul_algo,
compt_oc_block_size);
matmul_bundle = m_matmul_algo->get_bundle(matmul_param);
}
MIDOUT_END();
} else {
//! if nopack return null so that OprWeightPreprocessProxy can run
//! with nopack mode
return {};
}
Conv1x1StrategyBase* conv1x1_strategy =
Conv1x1Factory::make_conv1x1_strategy(param, pack_mode,
param.filter_meta.format);
auto kern_packA = [this, whole_bundle, matmul_bundle, param,
compt_oc_block_size, conv1x1_strategy](
const NCBKernParam& ncb_param,
const NCBKernIndex& ncb_index) mutable {
conv1x1_strategy->packA(whole_bundle, matmul_bundle,
compt_oc_block_size, this->m_matmul_algo,
param, ncb_param, std::move(ncb_index));
};
ret_kern.push_back({kern_packA, {GROUP, oc_blocks_per_group}});
return ret_kern;
}
MIDOUT_END();
return {};
}
bool ConvBiasImpl::AlgoConv1x1::is_preferred(
const NCBKernSizeParam& param) const {
size_t OH = param.osz[0];
......
......@@ -41,6 +41,15 @@ public:
const NCBKernSizeParam& param) const override;
bool is_preferred(const NCBKernSizeParam&) const override;
SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
const NCBKernSizeParam& param) const override;
size_t get_preprocess_workspace(
const NCBKernSizeParam& /*param*/) const override {
return 0;
}
SmallVector<NCBKern> dispatch_preprocess_kerns(
const NCBKernSizeParam& param) const override;
protected:
size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const;
......
......@@ -35,7 +35,6 @@ public:
auto matmul_bundle = matmul_algo->get_bundle(matmul_param);
auto thread_bundle = utils::get_thread_bundle(param, matmul_bundle.get_size(2),
oc_tile_size);
//! size per thread
size_t all_threads_bytes =
thread_bundle.total_size_in_bytes() * param.nr_threads;
......@@ -44,7 +43,9 @@ public:
size_t packa_bytes_per_oc_tile = matmul_bundle.get_size(0);
size_t oc_tiles_per_group = div_ceil(OC, oc_tile_size);
size_t all_packa_bytes =
packa_bytes_per_oc_tile * oc_tiles_per_group * GROUP;
is_enable_filter_preprocess(param)
? 0
: packa_bytes_per_oc_tile * oc_tiles_per_group * GROUP;
if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::ONLY_PACKA)
return WorkspaceBundle{nullptr,
......
......@@ -106,9 +106,14 @@ public:
size_t numbers_offset_of_filter =
oc_tile_size * IC * oc_tile_id_in_group;
src_ctype* a_panel = reinterpret_cast<src_ctype*>(
reinterpret_cast<int8_t*>(whole_bundle.get(0)) +
bytes_offset_of_a_panel);
int8_t* tmp_ptr =
is_enable_filter_preprocess(param)
? static_cast<int8_t*>(
param.preprocessed_filter->tensors[0].raw_ptr)
: static_cast<int8_t*>(whole_bundle.get(0));
src_ctype* a_panel =
reinterpret_cast<src_ctype*>(tmp_ptr + bytes_offset_of_a_panel);
matmul_kern_param.A_ptr = const_cast<src_ctype*>(
ncb_param.filter<src_ctype>(group_id) +
......@@ -206,8 +211,14 @@ public:
size_t bytes_offset_of_a_panel =
group_id * packa_bytes_per_group +
oc_tile_id_in_group * packa_bytes_per_oc_tile;
int8_t* a_panel = reinterpret_cast<int8_t*>(whole_bundle.get(0)) +
bytes_offset_of_a_panel;
int8_t* tmp_ptr =
is_enable_filter_preprocess(param)
? static_cast<int8_t*>(
param.preprocessed_filter->tensors[0].raw_ptr)
: static_cast<int8_t*>(whole_bundle.get(0));
int8_t* a_panel = tmp_ptr + bytes_offset_of_a_panel;
size_t bytes_offset_of_b_panel =
batch_id * packb_bytes_per_group * GROUP +
......
......@@ -2724,7 +2724,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
}
check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32_PREPROCESS) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
#define cb(name) \
check_conv_bias_preprocess(args, handle(), nullptr, 0.001, \
dtype::Float32(), dtype::Float32(), \
dtype::Float32(), dtype::Float32(), name);
#if MEGDNN_AARCH64
cb("CONV1x1:AARCH64_F32K8X12X1:24");
#elif MEGDNN_ARMV7
cb("CONV1x1:ARMV7_F32:48");
#endif
#undef cb
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
......@@ -2741,7 +2756,21 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) {
}
check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32_PREPROCESS) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({1}, 1, true, false, false);
#define cb(name) \
check_conv_bias_preprocess(args, handle(), nullptr, 0.001, \
dtype::Float32(), dtype::Float32(), \
dtype::Float32(), dtype::Float32(), name);
#if MEGDNN_AARCH64
cb("CONV1x1:AARCH64_F32_MK4_K8X12X1:24");
#elif MEGDNN_ARMV7
cb("CONV1x1:ARMV7_F32_MK4_PACK_4X12:24");
#endif
#undef cb
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_NO_PACK_F32) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
......@@ -2780,6 +2809,22 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) {
}
check_conv_bias(gemv_args, handle(), "CONV1x1_GEMV");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16_PREPROCESS) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
NormalRNG rng(1);
#if MEGDNN_AARCH64
check_conv_bias_preprocess(args, handle(), &rng, 0.03, dtype::Float16{},
dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
"CONV1x1:AARCH64_F16_K8X24X1:48");
#elif MEGDNN_ARMV7
check_conv_bias_preprocess(args, handle(), &rng, 0.03, dtype::Float16{},
dtype::Float16{}, dtype::Float16{}, dtype::Float16{},
"CONV1x1:AARCH32_F16_K4X16X1:24");
#endif
}
#endif
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM) {
......@@ -2814,6 +2859,31 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM) {
"CONV1x1_GEMV");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM_PREPROCESS) {
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
std::vector<conv_bias::TestArg> args =
get_conv_bias_1x1_args(false, false, true, true);
#define cb(name) \
check_conv_bias_preprocess( \
args, handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
dtype::QuantizedS8(60.25f), name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:24");
#else
cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:48");
#endif
#elif MEGDNN_ARMV7
epsilon = 1;
cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:48");
#endif
#undef cb
}
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) {
UniformIntRNG rng{-50, 50};
......@@ -2849,6 +2919,32 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) {
dtype::Quantized8Asymm(50.3f, (uint8_t)120),
"CONV1x1_GEMV");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM_PREPROCESS) {
UniformIntRNG rng{-50, 50};
std::vector<conv_bias::TestArg> args =
get_conv_bias_1x1_args(false, false, true, true);
#define cb(name) \
check_conv_bias_preprocess(args, handle(), &rng, epsilon, \
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
dtype::QuantizedS32(1.2 * 1.3), \
dtype::Quantized8Asymm(50.3f, (uint8_t)120), \
name);
float epsilon = 0.001;
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:48");
#else
cb("CONV1x1:AARCH64_QUINT8_K8X8X8:24");
#endif
#elif MEGDNN_ARMV7
epsilon = 1;
cb("CONV1x1:ARMV7_QUINT8_K4X8X8:48");
#endif
#undef cb
}
#endif
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
......@@ -2887,6 +2983,32 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32) {
dtype::QuantizedS32(1.2 * 1.3), {}, "CONV1x1_GEMV");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32_PREPROCESS) {
NormalRNG rng(128.f);
float epsilon = 0.001;
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
#define cb(name) \
check_conv_bias_preprocess(args, handle(), &rng, epsilon, \
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
dtype::QuantizedS32(1.2 * 1.3), {}, name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:24");
#else
cb("CONV1x1:AARCH64_QUINT8_K8X8X8:48");
#endif
#elif MEGDNN_ARMV7
#if __ARM_FEATURE_DOTPROD
cb("CONV1x1:AARCH32_QUINT8_K4X8X4:48");
#endif
cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24");
#endif
#undef cb
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) {
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
......@@ -2924,6 +3046,28 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) {
dtype::Int8{}, dtype::Int16{}, dtype::Int16{},
"CONV1x1_GEMV");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16_PREPROCESS) {
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
#define cb(name) \
check_conv_bias_preprocess(args, handle(), &rng, epsilon, dtype::Int8{}, \
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, \
name);
#if MEGDNN_AARCH64
cb("CONV1x1:AARCH64_INT8X8X16_K8X8X8:24");
cb("CONV1x1:AARCH64_INT8X8X16_K4X4X16:24");
cb("CONV1x1:ARM_COMMON_INT8X8X16:24");//!add nopack test
#elif MEGDNN_ARMV7
cb("CONV1x1:ARMV7_INT8X8X16_K4X8X8:24");
cb("CONV1x1:ARMV7_INT8X8X16_K4X2X16:48");
cb("CONV1x1:ARM_COMMON_INT8X8X16:24");//!add nopack test
#endif
#undef cb
}
#endif
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) {
......@@ -2959,6 +3103,32 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) {
checker_conv_bias_mul_int8x8x32(gemv_args, handle(), "CONV1x1_GEMV");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_PREPROCESS) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
#define cb(name) checker_conv_bias_int8x8x32_preprocess(args, handle(), name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:48");
#else
cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:24");
#endif
#elif MEGDNN_ARMV7
#if __ARM_FEATURE_DOTPROD
cb("CONV1x1:AARCH32_INT8_K6X8X4:48");
#endif
cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:24");
#endif
#if MEGDNN_ARMV7
cb("CONV1x1:ARMV7_INT8X8X32_K4X2X16:48");
#endif
#undef cb
}
#ifndef __ARM_FEATURE_DOTPROD
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) {
using namespace conv_bias;
......@@ -2988,6 +3158,36 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) {
#endif
#undef cb
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4_PREPROCESS) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({1}, 1, true, true, true);
#define cb(name) checker_conv_bias_int8x8x32_preprocess(args, handle(), name);
#if MEGDNN_AARCH64
cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24");
#elif MEGDNN_ARMV7
cb("CONV1x1:ARMV7_INT8X8X32_MK4_4X2X16:24");
#endif
#undef cb
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
#define cb(name) \
check_conv_bias_preprocess(get_nchw44_conv_bias_args({1}, 1, true, false, false), \
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), \
dtype::QuantizedS8(60.25f), name);
#if MEGDNN_AARCH64
cb("CONV1x1:AARCH64_INT8X8X32_MK4_4X4X16:24");
#elif MEGDNN_ARMV7
cb("CONV1x1:ARMV7_INT8X8X32_MK4_4X2X16:24");
#endif
#undef cb
}
#endif
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44) {
......
......@@ -1321,6 +1321,31 @@ void checker_conv_bias(std::vector<conv_bias::TestArg> args, Handle* handle,
{arg.src, arg.filter, arg.bias, {}, {}});
}
}
void checker_conv_bias_preprocess(std::vector<conv_bias::TestArg> args, Handle* handle,
RNG* rng, float epsilon, DType type0, DType type1,
DType type2, DType type3, const char* algo_name) {
using namespace conv_bias;
Checker<ConvBiasForward, OprWeightPreprocessProxy<ConvBiasForward>> checker(
handle);
checker.set_before_exec_callback(
conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));
checker.set_dtype(0, type0);
checker.set_dtype(1, type1);
checker.set_dtype(2, type2);
checker.set_dtype(4, type3);
checker.set_epsilon(epsilon);
if (NULL != rng) {
checker.set_rng(0, rng).set_rng(1, rng).set_rng(2, rng).set_rng(3, rng);
}
for (auto&& arg : args) {
checker.set_param(arg.param).execs(
{arg.src, arg.filter, arg.bias, {}, {}});
}
}
} // namespace
#if MEGDNN_X86_WITH_MKL
......@@ -1330,14 +1355,32 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_PACKA) {
check_conv_bias(args, handle(), "CONV1x1:X86_F32_MKL_PACKA:24");
}
TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_PACKA_PREPROCESS) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
checker_conv_bias_preprocess(args, handle(), nullptr, 0.001,
dtype::Float32{}, dtype::Float32{},
dtype::Float32{}, dtype::Float32{},
"CONV1x1:X86_F32_MKL_PACKA:24");
}
TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_BLAS) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
check_conv_bias(args, handle(), "CONV1x1:X86_F32_BLAS:48");
}
TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_FP32_BLAS_NOPACK_REPROCESS) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, false);
checker_conv_bias_preprocess(args, handle(), nullptr, 0.001,
dtype::Float32{}, dtype::Float32{},
dtype::Float32{}, dtype::Float32{},
"CONV1x1:X86_F32_BLAS:24");
}
#endif
TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X) {
TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) {
using namespace conv_bias;
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
......@@ -1374,6 +1417,38 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X) {
dtype::Int8{}, dtype::Int16{}, dtype::Int16{},
"CONV1x1:X86_INT8X8X16_SSE");
}
TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32_PREPROCESS) {
using namespace conv_bias;
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
#if MEGDNN_X86_WITH_VNNI
if (x86::is_supported(x86::SIMDType::VNNI)) {
checker_conv_bias_preprocess(args, handle(), &rng, epsilon, dtype::Int8{},
dtype::Int8{}, dtype::Int32{}, dtype::Int32{},
"CONV1x1:X86_INT8X8X32_VNNI:24");
}
#endif
if (x86::is_supported(x86::SIMDType::AVX2)) {
checker_conv_bias_preprocess(args, handle(), &rng, epsilon, dtype::Int8{},
dtype::Int8{}, dtype::Int32{}, dtype::Int32{},
"CONV1x1:X86_INT8X8X32_AVX2_4X16X2:24");
checker_conv_bias_preprocess(args, handle(), &rng, epsilon, dtype::Int8{},
dtype::Int8{}, dtype::Int32{}, dtype::Int32{},
"CONV1x1:X86_INT8X8X32_AVX2_2X4X16:24");
checker_conv_bias_preprocess(args, handle(), &rng, epsilon, dtype::Int8{},
dtype::Int8{}, dtype::Int16{}, dtype::Int16{},
"CONV1x1:X86_INT8X8X16_AVX2");
}
checker_conv_bias_preprocess(args, handle(), &rng, epsilon, dtype::Int8{},
dtype::Int8{}, dtype::Int32{}, dtype::Int32{},
"CONV1x1:X86_INT8X8X32_SSE_4X8X2:48");
checker_conv_bias_preprocess(args, handle(), &rng, epsilon, dtype::Int8{},
dtype::Int8{}, dtype::Int16{}, dtype::Int16{},
"CONV1x1:X86_INT8X8X16_SSE");
}
/************************* End Conv1x1 PackA ************************/
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册