提交 b778d225 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(mgb/fallback): add conv1x1_gemv, conv1x1 and im2col 8x8x16/8x8x32 support bias

GitOrigin-RevId: 3d97fedc8f33d0b41f94680d6710c56bc32b62e7
上级 c357db01
......@@ -100,7 +100,6 @@ namespace {
MIDOUT_END(); \
break; \
default: \
megdnn_throw("no quantized unsupported biasmode"); \
break; \
}
......@@ -258,6 +257,66 @@ struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
#undef FOR_NONLINEAR_NOBIAS
#undef FOR_NONLINEAR
#undef FOR_BIAS
// ---------------------------------------------------------------------------
// Helper macros for PostProcess<..., PostprocessMode::ADD_BIAS> (ARM path).
// Each one expands to a call into an arm_common OpCallerBinary kernel that
// adds the bias tensor onto the convolution result.  They are #undef'd right
// after the struct below, so they are local to that specialization.
// ---------------------------------------------------------------------------

// Channel-broadcast bias for plain NCHW layout: the bias value of each output
// channel is broadcast over its N x (OH*OW) positions (VEC_BCAST101).
#define FOR_BINARY_BROADCAST(_op)                                              \
    megdnn::arm_common::                                                       \
            OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_BCAST101>::run( \
                    static_cast<ctype*>(conv_dst_ptr),                         \
                    reinterpret_cast<const ctype*>(bias_ptr),                  \
                    reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type,   \
                    dst_type, N, OC, OH* OW);

// Channel-broadcast bias for the NCHW44 layout (channels packed by 4,
// VEC_BCAST101x4); pack_oc_size is forwarded so the kernel can address the
// packed channel groups.
#define FOR_BINARY_BROADCAST_NCHW44(_op)                                      \
    megdnn::arm_common::OpCallerBinary<_op<ctype>,                            \
                                       megdnn::arm_common::VEC_BCAST101x4>::  \
            run(static_cast<ctype*>(conv_dst_ptr),                            \
                reinterpret_cast<const ctype*>(bias_ptr),                     \
                reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type,      \
                dst_type, N, OC, OH* OW, pack_oc_size);

// Elementwise bias (BiasMode::BIAS): bias has the full output shape, so the
// add runs over all N*OC*OH*OW*pack_oc_size elements with no broadcasting.
#define FOR_BINARY(_op)                                                  \
    megdnn::arm_common::                                                 \
            OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_VEC>::run( \
                    static_cast<ctype*>(conv_dst_ptr),                   \
                    reinterpret_cast<const ctype*>(bias_ptr),            \
                    reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
                    dst_type, N* OC* OH* OW* pack_oc_size);

// Dispatch on the bias mode.  NO_BIAS (and any unhandled mode) is a no-op;
// BROADCAST_CHANNEL_BIAS selects the NCHW or NCHW44 kernel from pack_oc_size
// (only 1 and 4 are supported on ARM).
#define FOR_BIAS(_bias_mode, OH, OW)                           \
    switch (_bias_mode) {                                      \
        case megdnn::BiasMode::NO_BIAS:                        \
            break;                                             \
        case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS:         \
            if (pack_oc_size == 1) {                           \
                FOR_BINARY_BROADCAST(CONCAT_OP(AddOp));        \
            } else {                                           \
                megdnn_assert(pack_oc_size == 4,               \
                              "Only support nchw44 in ARM");   \
                FOR_BINARY_BROADCAST_NCHW44(CONCAT_OP(AddOp)); \
            }                                                  \
            break;                                             \
        case megdnn::BiasMode::BIAS:                           \
            FOR_BINARY(CONCAT_OP(AddOp));                      \
            break;                                             \
        default:                                               \
            break;                                             \
    }
// Post-process specialization that ONLY adds bias: no nonlinearity and no
// requantization.  Intended for kernels whose output keeps the accumulator
// dtype (e.g. int8x8x16 / int8x8x32), which previously had to run with
// NO_BIAS.
//
// conv_dst_ptr: raw convolution output.
// bias_ptr:     bias data; layout depends on bias_mode (full tensor for
//               BiasMode::BIAS, per-channel for BROADCAST_CHANNEL_BIAS).
// dst_ptr:      final output buffer.
// pack_oc_size: 1 for NCHW, 4 for NCHW44 (enforced inside FOR_BIAS).
template <typename ctype, typename dtype>
struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> {
    static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr,
                    megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode,
                    megdnn::DType bias_type, megdnn::DType dst_type, size_t N,
                    size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) {
        // ADD_BIAS supports no activation: callers must pass IDENTITY.
        megdnn_assert(nonlineMode == megdnn::NonlineMode::IDENTITY);
        // Defined immediately above and #undef'd right after this struct.
        FOR_BIAS(bias_mode, OH, OW);
    }
};
#undef FOR_BINARY_BROADCAST
#undef FOR_BINARY_BROADCAST_NCHW44
#undef FOR_BINARY
#undef FOR_BIAS
#undef CB
#undef CONCAT_OP
#undef CONCAT_NL
......
......@@ -158,9 +158,11 @@ private: \
uint32_t m_tile_size;
enum class PostprocessMode : uint8_t {
FLOAT = 0, ///< support all biasmode and no_nonlinemode
NO_PROCESS, ///<support non bias and identity
QUANTIZED,///<support NOBIAS ,BROADCAST_CHANNEL_BIAS and relu hswish identify nonline mode
FLOAT = 0, ///< support all biasmode and no_nonlinemode
NO_PROCESS, ///< support non bias and identity
QUANTIZED, ///< support NOBIAS ,BROADCAST_CHANNEL_BIAS and relu hswish
///< identify nonline mode
ADD_BIAS, ///< only add bias
};
} // namespace megdnn
......
......@@ -227,8 +227,7 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param,
param.dst_type.enumv() == DTypeEnum::QuantizedS16 ||
param.dst_type.enumv() == DTypeEnum::Int32 ||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) {
if (param.bias_mode != megdnn::BiasMode::NO_BIAS ||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
if (param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
return false;
}
}
......
......@@ -310,6 +310,19 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
} \
} \
MIDOUT_END()
// Dispatch case for the ADD_BIAS post-process kernels: matches on
// filter==src dtype and on the dst dtype only.  The _i_bias_type argument is
// accepted for call-site symmetry with cb2 but is NOT checked here; instead
// _bias_ctype and _dst_ctype are reused as the op ctypes of the worker
// (template args 4 and 5), so bias/dst stay in the raw accumulator type.
#define cb3(_format, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype,   \
            _bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag)       \
    MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv, midout_iv(_midout_tag)) {   \
        if (param.filter_type.enumv() == param.src_type.enumv() &&         \
            param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv &&    \
            param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) {    \
            conv1x1_gemv_worker =                                          \
                    Conv1x1GemvWorker<_src_ctype, _bias_ctype, _dst_ctype, \
                                      _bias_ctype, _dst_ctype,             \
                                      _postprocess_mode, _format>::exec;   \
        }                                                                  \
    }                                                                      \
    MIDOUT_END()
switch (param.filter_meta.format) {
case param::ConvBias::Format::NCHW:
......@@ -324,23 +337,23 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
PostprocessMode::NO_PROCESS, "NCHW::GEMV::FLOAT16_FLOAT16"_hash);
#endif
#endif
cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32,
dt_int8, dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
cb3(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32,
dt_int8, dt_int32, dt_int32, PostprocessMode::ADD_BIAS,
"NCHW::GEMV::INT8x8x32_INT32"_hash);
cb2(param::ConvBias::Format::NCHW, dt_int8, dt_int16, dt_int16,
dt_int8, dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
cb3(param::ConvBias::Format::NCHW, dt_int8, dt_int16, dt_int16,
dt_int8, dt_int16, dt_int16, PostprocessMode::ADD_BIAS,
"NCHW::GEMV::INT8x8x16_INT16"_hash);
cb2(param::ConvBias::Format::NCHW, dtype::QuantizedS8,
cb3(param::ConvBias::Format::NCHW, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32,
dt_int32, PostprocessMode::NO_PROCESS,
dt_int32, PostprocessMode::ADD_BIAS,
"NCHW::GEMV::QINT8x8x32_QINT32"_hash);
cb2(param::ConvBias::Format::NCHW, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32,
dt_int8, PostprocessMode::QUANTIZED,
"NCHW::GEMV::QINT8x8x32_QINT8"_hash);
cb2(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm,
cb3(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm,
dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8, dt_int32,
dt_int32, PostprocessMode::NO_PROCESS,
dt_int32, PostprocessMode::ADD_BIAS,
"NCHW::GEMV::QUINT8x8x32_QINT32"_hash);
cb2(param::ConvBias::Format::NCHW, dtype::Quantized8Asymm,
dtype::QuantizedS32, dtype::Quantized8Asymm, dt_uint8, dt_int32,
......@@ -365,13 +378,13 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
break;
case param::ConvBias::Format::NCHW44_DOT:
cb2(param::ConvBias::Format::NCHW44_DOT, dt_int8, dt_int32,
cb3(param::ConvBias::Format::NCHW44_DOT, dt_int8, dt_int32,
dt_int32, dt_int8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS,
PostprocessMode::ADD_BIAS,
"NCHW44_DOT::GEMV::INT8x8x32_INT32"_hash);
cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8,
cb3(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32,
dt_int32, PostprocessMode::NO_PROCESS,
dt_int32, PostprocessMode::ADD_BIAS,
"NCHW44_DOT::GEMV::QINT8x8x32_QINT32"_hash);
cb2(param::ConvBias::Format::NCHW44_DOT, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32,
......@@ -385,6 +398,7 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
}
#undef cb1
#undef cb2
#undef cb3
megdnn_assert(conv1x1_gemv_worker, "No suitable gemv worker");
......@@ -448,8 +462,7 @@ bool ConvBiasImpl::AlgoConv1x1Gemv::usable(const NCBKernSizeParam& param,
if (param.dst_type.enumv() == DTypeEnum::Int16 ||
param.dst_type.enumv() == DTypeEnum::Int32 ||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) {
if (param.bias_mode != megdnn::BiasMode::NO_BIAS ||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
if (param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
return false;
}
}
......
......@@ -56,6 +56,19 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy(
} \
} \
MIDOUT_END()
// Strategy-factory counterpart of the gemv cb3: matches filter==src dtype and
// the dst dtype (the _i_bias_type argument is accepted for symmetry but not
// checked), then returns a Conv1x1Strategy whose op ctypes reuse
// _bias_ctype/_dst_ctype — i.e. the ADD_BIAS path where bias and dst keep the
// accumulator precision.
#define cb3(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype,   \
            _bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag)         \
    MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy,                   \
                 midout_iv(_midout_tag)) {                                   \
        if (param.filter_type.enumv() == param.src_type.enumv() &&           \
            param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv &&      \
            param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) {      \
            return std::make_unique<Conv1x1Strategy<                         \
                    _src_ctype, _bias_ctype, _dst_ctype, _bias_ctype,        \
                    _dst_ctype, _postprocess_mode, _packmode>>(pack_c_size); \
        }                                                                    \
    }                                                                        \
    MIDOUT_END()
switch (pack_mode) {
case MatrixMulImpl::AlgoBase::PackMode::DEFAULT:
......@@ -71,26 +84,26 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy(
"Default::FLOAT16_FLOAT16"_hash);
#endif
#endif
cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int32,
cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int32,
dt_int32, dt_int8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS, "Default::INT8x8x32_INT32"_hash);
cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int16,
PostprocessMode::ADD_BIAS, "Default::INT8x8x32_INT32"_hash);
cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dt_int8, dt_int16,
dt_int16, dt_int8, dt_int16, dt_int16,
PostprocessMode::NO_PROCESS, "Default::INT8x8x16_INT16"_hash);
PostprocessMode::ADD_BIAS, "Default::INT8x8x16_INT16"_hash);
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT,
cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT,
dtype::Quantized8Asymm, dtype::QuantizedS32,
dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS,
PostprocessMode::ADD_BIAS,
"Default::QUINT8x8x32_QINT32"_hash);
cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT,
dtype::Quantized8Asymm, dtype::QuantizedS32,
dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8,
PostprocessMode::QUANTIZED, "Default::QUINT8x8x32_QUINT8"_hash);
#endif
cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::QuantizedS8,
cb3(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32,
dt_int32, PostprocessMode::NO_PROCESS,
dt_int32, PostprocessMode::ADD_BIAS,
"Default::QINT8x8x32_QINT32"_hash);
cb2(MatrixMulImpl::AlgoBase::PackMode::DEFAULT, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dt_int32,
......@@ -107,17 +120,17 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy(
cb1(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_float32,
dt_float32, PostprocessMode::FLOAT, "NoPack::FLOAT"_hash);
cb2(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int16,
cb3(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int16,
dt_int16, dt_int8, dt_int16, dt_int16,
PostprocessMode::NO_PROCESS, "NoPack::INT8x8x16_INT16"_hash);
PostprocessMode::ADD_BIAS, "NoPack::INT8x8x16_INT16"_hash);
cb2(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int32,
cb3(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dt_int8, dt_int32,
dt_int32, dt_int8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS, "NoPack::INT8x8x32_INT32"_hash);
PostprocessMode::ADD_BIAS, "NoPack::INT8x8x32_INT32"_hash);
cb2(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dtype::QuantizedS8,
cb3(MatrixMulImpl::AlgoBase::PackMode::NO_PACK, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dt_int32,
dt_int32, PostprocessMode::NO_PROCESS,
dt_int32, PostprocessMode::ADD_BIAS,
"NoPack::QINT8x8x32_QINT32"_hash);
break;
......@@ -127,6 +140,7 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy(
}
#undef cb1
#undef cb2
#undef cb3
megdnn_throw("Invalid Data Type");
return nullptr;
}
......@@ -207,4 +221,4 @@ bool Conv1x1Factory::can_make_conv1x1_strategy(
} // namespace fallback
} // namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
// vim: syntax=cpp.doxygen
......@@ -746,8 +746,7 @@ bool ConvBiasImpl::AlgoIm2col::usable(
if (param.dst_type.enumv() == DTypeEnum::Int16 ||
param.dst_type.enumv() == DTypeEnum::Int32 ||
param.dst_type.enumv() == DTypeEnum::QuantizedS32) {
if (param.bias_mode != megdnn::BiasMode::NO_BIAS ||
param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
if (param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
return false;
}
}
......
......@@ -213,6 +213,22 @@ public:
} \
MIDOUT_END(); \
return {};
// im2col strategy-factory case for ADD_BIAS kernels: matches filter==src
// dtype and the dst dtype (the _i_bias_type argument is unchecked, kept for
// symmetry with cb2), returning a Strategy whose op ctypes reuse
// _bias_ctype/_dst_ctype.  NOTE: like the other cbN macros here, this ends
// with `return {};`, so when the dtypes do not match the enclosing factory
// function returns an empty strategy pointer.
#define cb3(_format, _packmode, _i_src_type, _i_bias_type, _i_dst_type,        \
            _src_ctype, _bias_ctype, _dst_ctype, _postprocess_mode,            \
            _midout_tag)                                                       \
    MIDOUT_BEGIN(megdnn_fallback_im2col_factory_make_strategy,                 \
                 midout_iv(_midout_tag)) {                                     \
        if (param.filter_type.enumv() == param.src_type.enumv() &&             \
            param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv &&        \
            param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) {        \
            return std::make_unique<                                           \
                    Strategy<_src_ctype, _bias_ctype, _dst_ctype, _bias_ctype, \
                             _dst_ctype, _postprocess_mode,                    \
                             PackMode::_packmode, FormatMode::_format>>();     \
        }                                                                      \
    }                                                                          \
    MIDOUT_END();                                                              \
    return {};
static std::unique_ptr<StrategyBase> make_default_strategy(
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
......@@ -279,13 +295,13 @@ public:
#endif
case StrategyType::INT8x8x32:
if (format == param::ConvBias::Format::NCHW) {
cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
cb3(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8,
dt_int32, dt_int32, PostprocessMode::ADD_BIAS,
"DefaultStrategyType::INT8x8x32"_hash);
} else if (format == param::ConvBias::Format::NCHW44 ||
format == param::ConvBias::Format::NCHW44_DOT) {
cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
cb3(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8,
dt_int32, dt_int32, PostprocessMode::ADD_BIAS,
"DefaultStrategyType::INT8x8x32"_hash);
} else {
megdnn_throw(
......@@ -299,12 +315,12 @@ public:
case StrategyType::INT8x8x16:
if (format == param::ConvBias::Format::NCHW) {
cb2(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8,
dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
cb3(NCHW, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8,
dt_int16, dt_int16, PostprocessMode::ADD_BIAS,
"DefaultStrategyType::INT8x8x16"_hash);
} else if (format == param::ConvBias::Format::NCHW44) {
cb2(NCHW44, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8,
dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
cb3(NCHW44, DEFAULT, dt_int8, dt_int16, dt_int16, dt_int8,
dt_int16, dt_int16, PostprocessMode::ADD_BIAS,
"DefaultStrategyType::INT8x8x16"_hash);
} else {
megdnn_throw(
......@@ -316,9 +332,9 @@ public:
break;
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
case StrategyType::QUINT8x8x32:
cb2(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32,
cb3(NCHW, DEFAULT, dtype::Quantized8Asymm, dtype::QuantizedS32,
dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS,
PostprocessMode::ADD_BIAS,
"DefaultStrategyType::QUINT8x8x32"_hash);
break;
......@@ -331,15 +347,15 @@ public:
#endif
case StrategyType::QINT8x8x32:
if (format == param::ConvBias::Format::NCHW) {
cb2(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32,
cb3(NCHW, DEFAULT, dtype::QuantizedS8, dtype::QuantizedS32,
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS,
PostprocessMode::ADD_BIAS,
"DefaultStrategyTypeNCHW::QINT8x8x32"_hash);
} else if (format == param::ConvBias::Format::NCHW44 ||
format == param::ConvBias::Format::NCHW44_DOT) {
cb2(NCHW44, DEFAULT, dtype::QuantizedS8,
cb3(NCHW44, DEFAULT, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
dt_int32, dt_int32, PostprocessMode::ADD_BIAS,
"DefaultStrategyTypeHCHW44::QINT8x8x32"_hash);
} else {
megdnn_throw(
......@@ -467,13 +483,13 @@ public:
#endif
#endif
case StrategyType::INT8x8x16:
cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8,
dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
cb3(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8,
dt_int16, dt_int16, PostprocessMode::ADD_BIAS,
"NoPackStrategyType::INT8x8x16"_hash);
break;
case StrategyType::INT8x8x32:
cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
cb3(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8,
dt_int32, dt_int32, PostprocessMode::ADD_BIAS,
"NoPackStrategyType::INT8x8x32"_hash);
break;
default:
......@@ -509,6 +525,7 @@ public:
#undef cb1
#undef cb2
#undef cb3
static std::unique_ptr<StrategyBase> make_strategy(
fallback::MatrixMulImpl::AlgoBase* matmul_algo,
......
......@@ -203,18 +203,16 @@ INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8,
megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32,
megdnn::PostprocessMode::NO_PROCESS)
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_int32, dt_int32,
megdnn::PostprocessMode::ADD_BIAS)
#endif
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8,
megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32,
megdnn::PostprocessMode::NO_PROCESS)
megdnn::PostprocessMode::ADD_BIAS)
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16,
megdnn::PostprocessMode::NO_PROCESS)
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32,
megdnn::PostprocessMode::NO_PROCESS)
megdnn::PostprocessMode::ADD_BIAS)
#undef INSTANTIAL_CLASS
} // namespace megdnn
......
......@@ -119,19 +119,16 @@ INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8,
megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32,
megdnn::PostprocessMode::NO_PROCESS)
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_int32, dt_int32,
megdnn::PostprocessMode::ADD_BIAS)
#endif
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8,
megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32,
megdnn::PostprocessMode::NO_PROCESS)
megdnn::PostprocessMode::ADD_BIAS)
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16,
megdnn::PostprocessMode::NO_PROCESS)
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32,
megdnn::PostprocessMode::NO_PROCESS)
megdnn::PostprocessMode::ADD_BIAS)
#undef INSTANTIAL_CLASS
} // namespace megdnn
......
......@@ -162,9 +162,9 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32,
megdnn::PostprocessMode::FLOAT)
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16,
megdnn::PostprocessMode::NO_PROCESS)
megdnn::PostprocessMode::ADD_BIAS)
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32,
megdnn::PostprocessMode::NO_PROCESS)
megdnn::PostprocessMode::ADD_BIAS)
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#else
#if !MEGDNN_DISABLE_FLOAT16
......
......@@ -294,6 +294,73 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::QUANTIZED> {
#undef FOR_BIAS
}
};
// Earlier specializations in this file define their own CALL_BINARY /
// CALL_BINARY_BROADCAST; drop any previous definitions before redefining
// them for the ADD_BIAS path below.
#undef CALL_BINARY
#undef CALL_BINARY_BROADCAST
// Elementwise bias add (BiasMode::BIAS): bias covers the full output, so the
// kernel runs over all N*OC*OH*OW elements with no broadcasting (VEC_VEC).
// The callable is bound through thin_function so all SIMD variants share one
// signature.
#define CALL_BINARY(_op, _simd_type)                                        \
    thin_function<void(const ctype*, const ctype*, dtype*, DType, DType,    \
                       DType, size_t)>                                      \
            run = OpCallerBinary<_op<_simd_type, ctype, dtype>, _simd_type, \
                                 megdnn::x86::BcastType::VEC_VEC>::run;     \
    run(static_cast<ctype*>(conv_dst_ptr), static_cast<ctype*>(bias_ptr),   \
        reinterpret_cast<dtype*>(dst_ptr), bias_type, bias_type, dst_type,  \
        N* OC* OH* OW);
// Per-channel bias add (BROADCAST_CHANNEL_BIAS): one bias value per output
// channel, broadcast over N x (OH*OW) positions (VEC_BCAST101).
#define CALL_BINARY_BROADCAST(_op, _simd_type)                                \
    thin_function<void(const ctype*, const ctype*, dtype*, DType, DType,      \
                       DType, size_t, size_t, size_t)>                        \
            run = OpCallerBinary<_op<_simd_type, ctype, dtype>, _simd_type,   \
                                 megdnn::x86::BcastType::VEC_BCAST101>::run;  \
    run(static_cast<ctype*>(conv_dst_ptr), static_cast<ctype*>(bias_ptr),     \
        reinterpret_cast<dtype*>(dst_ptr), bias_type, bias_type, dst_type, N, \
        OC, OH* OW);
// Runtime SIMD dispatch: prefer AVX2, then SSE4.2, else the scalar fallback.
// CALLER is one of the CALL_BINARY* macros above, always applied to AddOp.
#define FOR_SIMD(CALLER)                         \
    if (is_supported(SIMDType::AVX2)) {          \
        CALLER(AddOp, SIMDType::AVX2)            \
    } else if (is_supported(SIMDType::SSE4_2)) { \
        CALLER(AddOp, SIMDType::SSE4_2)          \
    } else {                                     \
        CALLER(AddOp, SIMDType::NONE)            \
    }
// Bias-mode dispatch; any unhandled mode is a silent no-op (NO_BIAS is
// filtered out by the caller before this macro is reached).
#define FOR_BIAS(bias_mode)                        \
    switch (bias_mode) {                           \
        case BiasMode::BIAS:                       \
            FOR_SIMD(CALL_BINARY);                 \
            break;                                 \
        case BiasMode::BROADCAST_CHANNEL_BIAS:     \
            FOR_SIMD(CALL_BINARY_BROADCAST);       \
            break;                                 \
        default:                                   \
            break;                                 \
    }
// x86 post-process that ONLY adds bias — identity nonlinearity, no
// requantization — for kernels whose dst keeps the accumulator dtype
// (int8x8x16 / int8x8x32).  Only plain NCHW (pack_oc_size == 1) is handled
// on x86.
template <typename ctype, typename dtype>
struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> {
    static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr,
                    megdnn::ConvBiasForward::BiasMode bias_mode,
                    megdnn::param::ConvBiasV0::NonlineMode nonlineMode,
                    DType bias_type, DType dst_type, size_t N, size_t OC,
                    size_t OH, size_t OW, size_t pack_oc_size = 1) {
        MEGDNN_MARK_USED_VAR(pack_oc_size);
        megdnn_assert(pack_oc_size == 1,
                      "PostProcess only support nchw in x86");
        megdnn_assert(
                nonlineMode == megdnn::param::ConvBiasV0::NonlineMode::IDENTITY,
                "Add bias PostProcess only support IDENTITY");
        // NO_BIAS: nothing to add.  NOTE(review): no copy from conv_dst_ptr
        // to dst_ptr happens here — presumably the two alias in this mode;
        // confirm against the calling kernels.
        if (bias_mode == megdnn::ConvBiasForward::BiasMode::NO_BIAS) {
            return;
        }
        FOR_BIAS(bias_mode);
// The helper macros above are only meaningful inside this specialization.
#undef CALL_BINARY
#undef CALL_BINARY_BROADCAST
#undef FOR_SIMD
#undef FOR_BIAS
    }
};
#undef cb_unary
#undef cb_binary
#undef BIAS_CASE
......
......@@ -92,6 +92,8 @@ OP(dt_int8, SIMDType::AVX2, "avx2", __m256i, __m256ix2, __m256i, mm256, epi8,
using AddOpBase::operator(); \
};
OP(dt_int32, SIMDType::NONE);
OP(dt_int16, SIMDType::NONE);
OP(dt_float32, SIMDType::NONE);
#undef OP
} // namespace x86
......
......@@ -1992,13 +1992,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_S8x8x32_MK4_DOT) {
#define cb(name) \
checker_conv_bias( \
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
true, false, true, false, false, true), \
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, \
true, false, true, true, false, false), \
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name); \
checker_conv_bias( \
get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \
false, false, true), \
true, false, false), \
handle(), &rng, epsilon, dtype::QuantizedS8(2.5f), \
dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name);
......@@ -2041,13 +2041,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32_MK4_DOT) {
#define cb(name) \
checker_conv_bias( \
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
true, false, true, false, false, true), \
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, \
true, false, true, true, false, false), \
handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \
dtype::Int32(), {}, name); \
checker_conv_bias( \
get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \
false, false, true), \
true, false, false), \
handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(), \
dtype::Int32(), {}, name);
......@@ -2118,7 +2118,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CONV1x1_QUANTIZEDSYM_MK4_DOT) {
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) {
NormalRNG rng(128.f);
#define cb(name) \
checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
false, true, true), \
......@@ -2188,18 +2187,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS,
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) {
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
#define cb(name) \
checker_conv_bias( \
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \
handle(), &rng, epsilon, \
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
dtype::QuantizedS32(1.2 * 1.3), {}, name); \
checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \
&rng, epsilon, \
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
dtype::QuantizedS32(1.2 * 1.3), {}, name);
#define cb(name) \
checker_conv_bias(get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, \
true, true, false), \
handle(), &rng, epsilon, \
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
dtype::QuantizedS32(1.2 * 1.3), {}, name); \
checker_conv_bias( \
get_conv_bias_args({1}, 2, false, false, true, true, false), \
handle(), &rng, epsilon, \
dtype::Quantized8Asymm(1.2f, (uint8_t)125), \
dtype::Quantized8Asymm(1.3f, (uint8_t)129), \
dtype::QuantizedS32(1.2 * 1.3), {}, name);
#if MEGDNN_AARCH64
#if __ARM_FEATURE_DOTPROD
......@@ -2252,18 +2252,18 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16) {
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
std::vector<conv_bias::TestArg> args_nchw44 =
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, true, true, true,
false, false, false, false, true);
get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, true, false, true,
false, false, true, false, false);
std::vector<conv_bias::TestArg> args_nchw44_1x1s2 =
get_nchw44_conv_bias_args({1}, 2, true, true, true, false, false,
false, false, true);
#define cb(name) \
checker_conv_bias( \
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \
handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \
dtype::Int16{}, dtype::Int16{}, name); \
checker_conv_bias(get_conv_bias_args({1}, 2, false, true, true), handle(), \
&rng, epsilon, dtype::Int8{}, dtype::Int8{}, \
get_nchw44_conv_bias_args({1}, 2, true, false, true, false, false,
true, false, false);
#define cb(name) \
checker_conv_bias( \
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, true), \
handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \
dtype::Int16{}, dtype::Int16{}, name); \
checker_conv_bias(get_conv_bias_args({1}, 2, false, false, true), \
handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \
dtype::Int16{}, dtype::Int16{}, name);
#define cb_nchw44(name) \
......@@ -2314,14 +2314,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_FILTERPREPROCES
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_IM2COLMATMUL_INT8x8x16_NOPACK_FILTERPREPROCESS) {
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
#define cb(name) \
check_conv_bias_preprocess( \
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, true, true), \
handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \
dtype::Int16{}, dtype::Int16{}, name); \
check_conv_bias_preprocess(get_conv_bias_args({1}, 2, false, true, true), \
handle(), &rng, epsilon, dtype::Int8{}, \
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, \
#define cb(name) \
check_conv_bias_preprocess( \
get_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false, true), \
handle(), &rng, epsilon, dtype::Int8{}, dtype::Int8{}, \
dtype::Int16{}, dtype::Int16{}, name); \
check_conv_bias_preprocess(get_conv_bias_args({1}, 2, false, false, true), \
handle(), &rng, epsilon, dtype::Int8{}, \
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, \
name);
#if MEGDNN_AARCH64
......@@ -2406,7 +2406,7 @@ void checker_conv_bias_mul_int8x8x32(std::vector<conv_bias::TestArg> args,
checker.set_dtype(0, dtype::QuantizedS8(2.5f))
.set_dtype(1, dtype::QuantizedS8(2.5f))
.set_dtype(2, dtype::QuantizedS32(6.25f))
.set_dtype(4, {})
.set_dtype(4, dtype::QuantizedS32(6.25f))
.set_rng(0, &rng)
.set_rng(1, &rng)
.set_rng(2, &rng)
......@@ -2436,7 +2436,7 @@ void checker_conv_bias_int8x8x32_preprocess(std::vector<conv_bias::TestArg> args
checker.set_dtype(0, dtype::QuantizedS8(2.5f))
.set_dtype(1, dtype::QuantizedS8(2.5f))
.set_dtype(2, dtype::QuantizedS32(6.25f))
.set_dtype(4, {})
.set_dtype(4, dtype::QuantizedS32(6.25f))
.set_rng(0, &rng)
.set_rng(1, &rng)
.set_rng(2, &rng)
......@@ -2450,7 +2450,7 @@ void checker_conv_bias_int8x8x32_preprocess(std::vector<conv_bias::TestArg> args
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({2, 5, 7}, 2, false, true, true);
get_nchw44_conv_bias_args({2, 5, 7}, 2, false, false, true);
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
......@@ -2464,7 +2464,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPROCESS) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({2, 5, 7}, 2, false, true, true);
get_nchw44_conv_bias_args({2, 5, 7}, 2, false, false, true);
#define cb(name) checker_conv_bias_int8x8x32_preprocess(args, handle(), name);
#if MEGDNN_AARCH64
......@@ -2478,7 +2478,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPR
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S1) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args =
get_nchw44_conv_bias_args({3, 4, 6}, 1, false, true, true);
get_nchw44_conv_bias_args({3, 4, 6}, 1, false, false, true);
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
#if MEGDNN_AARCH64
......@@ -3080,9 +3080,10 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32_PREPROCESS) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16) {
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
std::vector<conv_bias::TestArg> args =
get_conv_bias_1x1_args(false, true, false, false);
std::vector<conv_bias::TestArg> args_nchw44 = get_nchw44_conv_bias_args(
{1}, 1, true, true, true, false, false, false, false, true);
{1}, 1, true, true, true, false, false, true, false, false);
#define cb(name) \
checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{}, \
dtype::Int8{}, dtype::Int16{}, dtype::Int16{}, name);
......@@ -3140,7 +3141,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_1X1_S1_INT8x8x16_PREPROCESS) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) {
using namespace conv_bias;
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
std::vector<conv_bias::TestArg> args =
get_conv_bias_1x1_args(false, true, false, false);
#define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
......
......@@ -834,6 +834,13 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8X8X32) {
//! no bias
args.emplace_back(param, TensorShape{1, ic, h, w},
TensorShape{oc, ic, kernel, kernel}, TensorShape{});
args.emplace_back(param, TensorShape{1, ic, h, w},
TensorShape{oc, ic, kernel, kernel},
TensorShape{1, oc, 1, 1});
args.emplace_back(param, TensorShape{1, ic, h, w},
TensorShape{oc, ic, kernel, kernel},
TensorShape{1, oc, (h + 2 * p - kernel) + 1,
(h + 2 * p - kernel) + 1});
};
for (size_t kernel : {2, 3, 4, 5, 6, 7})
......@@ -1384,7 +1391,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32) {
using namespace conv_bias;
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, true);
#if MEGDNN_X86_WITH_MKL_DNN
if (x86::is_supported(x86::SIMDType::VNNI)) {
checker_conv_bias(args, handle(), &rng, epsilon, dtype::Int8{},
......@@ -1422,7 +1429,7 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_CONV1X1_S1_INT8X8X32_PREPROCESS) {
using namespace conv_bias;
UniformIntRNG rng{-50, 50};
float epsilon = 0.001;
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(true, true);
std::vector<conv_bias::TestArg> args = get_conv_bias_1x1_args(false, true);
#if MEGDNN_X86_WITH_VNNI
if (x86::is_supported(x86::SIMDType::VNNI)) {
checker_conv_bias_preprocess(args, handle(), &rng, epsilon, dtype::Int8{},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册