Commit cdefe90e authored by Megvii Engine Team

feat(dnn/fallback): support mk4 fp32 conv1x1

GitOrigin-RevId: 301ef0137f61d07c3f6d0a6ada189ca0274921dc
Parent 980ebf2c
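
For readers new to the formats involved: NCHW44 stores fp32 feature maps with the channel dimension split into groups of four, so a single NEON float32x4_t load fetches one channel group; MK4 is the matching MatrixMul format that keeps the same packing on the GEMM operands. A minimal sketch of the index mapping, with illustrative names that are not part of the codebase:

    #include <cstddef>

    // Offset of logical element (n, c, h, w) in an NCHW44 tensor,
    // i.e. layout [N, C/4, H, W, 4] (C assumed divisible by 4).
    size_t nchw44_offset(size_t n, size_t c, size_t h, size_t w,
                         size_t C, size_t H, size_t W) {
        size_t c_outer = c / 4, c_inner = c % 4;  // channel group, lane
        return (((n * (C / 4) + c_outer) * H + h) * W + w) * 4 + c_inner;
    }
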
@@ -49,6 +49,14 @@ namespace {
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
dst_type, N, OC, OH* OW);
+#define FOR_NONLINEAR_BINARY_BROADCAST_NCHW44(_op)                          \
+    megdnn::arm_common::OpCallerBinary<_op<ctype>,                          \
+                                       megdnn::arm_common::VEC_BCAST101x4>::\
+            run(static_cast<ctype*>(conv_dst_ptr),                          \
+                reinterpret_cast<const ctype*>(bias_ptr),                   \
+                reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type,    \
+                dst_type, N, OC, OH* OW, pack_oc_size);
#define FOR_NONLINEAR_BINARY(_op) \
megdnn::arm_common:: \
OpCallerBinary<_op<ctype>, megdnn::arm_common::VEC_VEC>::run( \
@@ -63,7 +71,13 @@ namespace {
FOR_NONLINEAR_NOBIAS(FOR_NONLINEAR_UNARY) \
break; \
case megdnn::BiasMode::BROADCAST_CHANNEL_BIAS: \
-            FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST)             \
+            if (pack_oc_size == 1) {                                  \
+                FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST);        \
+            } else {                                                  \
+                megdnn_assert(pack_oc_size == 4,                      \
+                              "Only support nchw44 in ARM");          \
+                FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \
+            }                                                         \
break; \
case megdnn::BiasMode::BIAS: \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY) \
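
The new FOR_NONLINEAR_BINARY_BROADCAST_NCHW44 branch above broadcasts a per-channel bias whose channels are packed in the same groups of four (the BCAST101x4 shape). A scalar model of what the vectorized OpCallerBinary computes, written only to make the indexing explicit; the real arm_common kernel processes each 4-lane group with NEON:

    #include <cstddef>

    // dst and conv use layout [N, OC/4, OH*OW, 4]; bias is [OC/4, 1, 1, 4].
    template <typename Op>
    void broadcast_bias_101x4(const float* conv, const float* bias, float* dst,
                              size_t N, size_t OC, size_t OHW, Op op) {
        for (size_t n = 0; n < N; ++n)
            for (size_t g = 0; g < OC / 4; ++g)
                for (size_t s = 0; s < OHW; ++s)
                    for (size_t lane = 0; lane < 4; ++lane) {
                        size_t i = ((n * (OC / 4) + g) * OHW + s) * 4 + lane;
                        dst[i] = op(conv[i], bias[g * 4 + lane]);
                    }
    }
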
@@ -129,6 +143,7 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
#undef FOR_NONLINEAR_UNARY
#undef FOR_NONLINEAR_BINARY_BROADCAST
+#undef FOR_NONLINEAR_BINARY_BROADCAST_NCHW44
#undef FOR_NONLINEAR_BINARY
#undef FOR_NONLINEAR_NOBIAS
#undef FOR_NONLINEAR
@@ -187,6 +202,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
if (pack_oc_size == 1) { \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST); \
} else { \
+                megdnn_assert(pack_oc_size == 4,                      \
+                              "Only support nchw44 in ARM");          \
FOR_NONLINEAR(FOR_NONLINEAR_BINARY_BROADCAST_NCHW44); \
} \
break; \
......
@@ -216,14 +216,18 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr,
param.nonlineMode != megdnn::NonlineMode::IDENTITY)
return false;
if (opr->param().format == param::ConvBias::Format::NCHW44) {
            //! nchw44 hybrid mode and channel wise are not supported
if (param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 ||
param.filter_meta.ocpg == 1) {
return false;
}
}
size_t OH = param.osz[0];
size_t OW = param.osz[1];
-        MatrixMulImpl::KernSizeParam matmul_param = get_matmul_kern_param(param, OH * OW, get_oc_tile_size_heuristic(param));
+        MatrixMulImpl::KernSizeParam matmul_param = get_matmul_kern_param(
+                param, OH * OW, get_oc_tile_size_heuristic(param));
+        if (opr->param().format == param::ConvBias::Format::NCHW44)
+            matmul_param.format = param::MatrixMul::Format::MK4;
bool matmul_usable = m_matmul_algo->usable(matmul_param);
return matmul_usable &&
......
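
For orientation: with FH = FW = 1, stride 1 and no padding, a convolution needs no im2col at all; per group, the filter is an OC x IC matrix and the input an IC x (OH*OW) matrix, which is why conv1x1 can forward straight into a matmul kernel. Under MK4 both operands stay channel-packed in fours. A scalar reference of the MK4 product, assuming the documented shapes A = (M/4, K/4, 4, 4), B = (K/4, N, 4), C = (M/4, N, 4); the (k, m) inner-block order of A is an assumption of this sketch:

    #include <cstddef>

    // C must be zeroed by the caller; M and K divisible by 4.
    void mk4_gemm_ref(const float* A, const float* B, float* C,
                      size_t M, size_t N, size_t K) {
        for (size_t mo = 0; mo < M / 4; ++mo)
            for (size_t n = 0; n < N; ++n)
                for (size_t ko = 0; ko < K / 4; ++ko)
                    for (size_t mi = 0; mi < 4; ++mi)
                        for (size_t ki = 0; ki < 4; ++ki)
                            C[(mo * N + n) * 4 + mi] +=
                                    A[((mo * (K / 4) + ko) * 4 + ki) * 4 + mi] *
                                    B[(ko * N + n) * 4 + ki];
    }
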
@@ -22,6 +22,20 @@ namespace conv1x1 {
namespace {
+size_t get_format_pack_size(param::ConvBias::Format format) {
+    switch (format) {
+        case param::ConvBias::Format::NCHW44:
+        case param::ConvBias::Format::NCHW4:
+            return 4_z;
+        case param::ConvBias::Format::NCHW88:
+            return 8_z;
+        case param::ConvBias::Format::NCHW:
+            return 1_z;
+        default:
+            megdnn_throw("unknown pack size of the format");
+    }
+}
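
Usage is a plain lookup; the point of the helper is that create_conv1x1_strategy below no longer guesses the pack size from a NCHW/non-NCHW dichotomy, which would have returned 4 for NCHW88. Illustrative values, matching the switch above:

    size_t p44 = get_format_pack_size(param::ConvBias::Format::NCHW44);  // 4
    size_t p88 = get_format_pack_size(param::ConvBias::Format::NCHW88);  // 8
    size_t p1  = get_format_pack_size(param::ConvBias::Format::NCHW);    // 1
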
struct StrategyHashParam {
ConvBiasImpl::NCBKernSizeParam param;
param::ConvBias::Format format;
@@ -71,7 +85,7 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy(
const ConvBiasImpl::NCBKernSizeParam& param,
MatrixMulImpl::AlgoBase::PackMode pack_mode,
param::ConvBias::Format format) {
-    size_t pack_size = format == param::ConvBias::Format::NCHW ? 1 : 4;
+    size_t pack_size = get_format_pack_size(format);
#define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \
midout_iv(_midout_tag)) { \
......
@@ -41,19 +41,25 @@ MatrixMulImpl::KernSizeParam get_matmul_kern_param(
param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
+    size_t pack_c_size = 1_z;
+    auto format = param::MatrixMul::Format::DEFAULT;
+    if (param.filter_meta.format == param::ConvBias::Format::NCHW44) {
+        pack_c_size = 4_z;
+        format = param::MatrixMul::Format::MK4;
+    }
return {param.filter_type,
param.src_type,
is_dst_8bit ? param.bias_type : param.dst_type,
M,
N,
K,
-            LDA,
-            LDB,
-            LDC,
+            LDA * pack_c_size,
+            LDB * pack_c_size,
+            LDC * pack_c_size,
false,
false,
param::MatrixMul::ComputeMode::DEFAULT,
-            param::MatrixMul::Format::DEFAULT};
+            format};
}
} // namespace
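
The leading dimensions are scaled because an MK4 "row" is a packed row of four: stepping one logical block row of B = (K/4, N, 4) advances N * 4 floats, so the LDB handed to the kernel must be multiplied by pack_c_size, and likewise LDA and LDC. A small sketch of the implied pointer arithmetic, assuming contiguous packed storage:

    #include <cstddef>

    // Start of block row k4 (channels [4*k4, 4*k4 + 4)) of an MK4 matrix
    // stored as (K/4, N, 4); ld is the logical row stride (here N).
    const float* mk4_block_row(const float* B, size_t k4, size_t ld,
                               size_t pack_c_size /* 4 for MK4 */) {
        return B + k4 * ld * pack_c_size;
    }
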
@@ -138,8 +144,6 @@ public:
reinterpret_cast<int8_t*>(whole_bundle.get(0)) +
bytes_offset_of_a_panel);
-        matmul_kern_param.LDA *= m_pack_size;
matmul_kern_param.A_ptr = const_cast<src_ctype*>(
ncb_param.filter<src_ctype>(group_id) +
numbers_offset_of_filter);
@@ -172,7 +176,6 @@ public:
static_cast<MatrixMulImpl::KernSizeParam&>(matmul_kern_param) =
get_matmul_kern_param(param, OH * OW, OC);
-        matmul_kern_param.LDB *= m_pack_size;
rep(batch, BATCH) {
rep(g, GROUP) {
@@ -282,8 +285,6 @@ public:
matmul_kern_param.C_ptr = matmul_dst;
-        matmul_kern_param.LDC *= m_pack_size;
if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK) {
auto matmul_kern = matmul_algo->get_kern(matmul_kern_param);
matmul_kern(matmul_kern_param);
@@ -295,13 +296,14 @@ public:
//! do postprocess
void* bias_ptr = nullptr;
-            if (param.bias_mode == megdnn::BiasMode::BIAS)
+            if (param.bias_mode == megdnn::BiasMode::BIAS) {
                 bias_ptr = static_cast<void*>(const_cast<bias_ctype*>(
                         ncb_param.bias<bias_ctype>(batch_id, group_id) +
                         numbers_of_ncb_dst_offset));
-            else
+            } else {
                 bias_ptr = static_cast<void*>(const_cast<bias_ctype*>(
                         ncb_param.bias<bias_ctype>(batch_id, group_id) + oc_start));
+            }
PostProcess<op_ctype, op_dtype, postprocess_mode>::run(
matmul_dst, bias_ptr, conv_bias_dst, param.bias_mode,
......
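
The two offsets differ because a full BIAS tensor is shaped like the output, so its pointer advances by the same element offset as the output block (numbers_of_ncb_dst_offset), while a channel-broadcast bias holds one value per output channel and only needs to advance to the first channel of the block (oc_start).
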
@@ -137,8 +137,8 @@ class ConvBias {
sizeof(output_compute_type) *
std::max(Strategy::IC_BLOCK_SIZE, Strategy::OC_BLOCK_SIZE);
-        size_t matmul_workspace_size =
-                matmul_algo->get_workspace(get_matmul_kern_param(param));
+        size_t matmul_workspace_size = matmul_algo->get_workspace(
+                get_matmul_kern_param(param, m_unit_oc_size));
//! compute workspace is independent and separated as far as possible
//! in case of false cache line sharing
@@ -384,7 +384,7 @@ public:
get_wbundle_compute(param, matmul_algo);
fallback::MatrixMulImpl::KernParam matmul_param;
static_cast<fallback::MatrixMulImpl::KernSizeParam&>(matmul_param) =
-                get_matmul_kern_param(param);
+                get_matmul_kern_param(param, m_unit_oc_size);
Strategy strategy = m_strategy;
size_t unit_tile_size = m_unit_tile_size;
@@ -450,21 +450,24 @@ public:
}
fallback::MatrixMulImpl::KernSizeParam get_matmul_kern_param(
-            const NCBKernSizeParam& param) const {
+            const NCBKernSizeParam& param, size_t nr_oc_in_unit = 0) const {
size_t M = 0;
size_t N = 0;
size_t K = 0;
size_t LDA = 0, LDB = 0, LDC = 0;
+        if (nr_oc_in_unit == 0) {
+            nr_oc_in_unit = param.filter_meta.ocpg;
+        }
if (format == param::MatrixMul::Format::DEFAULT) {
M = m_unit_tile_size;
-            N = param.filter_meta.ocpg;
+            N = nr_oc_in_unit;
K = param.filter_meta.icpg;
LDA = K;
LDB = N;
LDC = N;
} else {
-            M = param.filter_meta.ocpg;
+            M = nr_oc_in_unit;
N = m_unit_tile_size;
K = param.filter_meta.icpg;
megdnn_assert(K % Strategy::IC_BLOCK_SIZE == 0, "invalid K: %zu",
......
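
A worked example of the two branches, with illustrative numbers only: take m_unit_tile_size = 24, icpg = 8 and an OC block of nr_oc_in_unit = 32. The DEFAULT branch computes C(24 x 32) = A(24 x 8) * B(8 x 32), i.e. M = 24, N = 32, K = 8 with LDA = K and LDB = LDC = N; the MK4 branch swaps the operand roles, M = 32, N = 24, K = 8, and the assert requires K % Strategy::IC_BLOCK_SIZE == 0 (4 for fp32 MK4), since input channels travel in packs of four. Passing nr_oc_in_unit = 0 keeps the old behaviour of sizing against the full param.filter_meta.ocpg.
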
@@ -126,6 +126,8 @@ struct PostProcess {
DType bias_type, DType dst_type, size_t N, size_t OC,
size_t OH, size_t OW, size_t pack_oc_size = 1) {
MEGDNN_MARK_USED_VAR(pack_oc_size);
+        megdnn_assert(pack_oc_size == 1,
+                      "PostProcess only support nchw in x86");
megdnn::param::Elemwise::Mode elem_mode =
megdnn::param::Elemwise::Mode::ADD;
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) {
@@ -149,38 +151,6 @@ struct PostProcess {
}
};
-template <typename ctype, typename dtype>
-struct PostProcess<ctype, dtype, megdnn::PostprocessMode::FLOAT> {
-    static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr,
-                    megdnn::ConvBiasForward::BiasMode bias_mode,
-                    megdnn::param::ConvBias::NonlineMode nonlineMode,
-                    DType bias_type, DType dst_type, size_t N, size_t OC,
-                    size_t OH, size_t OW, size_t pack_oc_size = 1) {
-        MEGDNN_MARK_USED_VAR(pack_oc_size);
-        megdnn::param::Elemwise::Mode elem_mode =
-                megdnn::param::Elemwise::Mode::ADD;
-        if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) {
-            switch (nonlineMode) {
-                BIAS_CASE(RELU);
-                BIAS_CASE(SIGMOID);
-                BIAS_CASE(H_SWISH);
-                IDENTITY_CASE(IDENTITY);
-                DEFAULT_CASE;
-            }
-        } else {
-            switch (nonlineMode) {
-                NOBIAS_CASE(RELU);
-                NOBIAS_CASE(SIGMOID);
-                NOBIAS_CASE(H_SWISH);
-                IDENTITY_CASE(IDENTITY);
-                DEFAULT_CASE;
-            }
-        }
-        FOR_BIAS(bias_mode);
-    }
-};
template <typename ctype, typename dtype>
struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr,
@@ -297,6 +267,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::QUANTIZED> {
DType bias_type, DType dst_type, size_t N, size_t OC,
size_t OH, size_t OW, size_t pack_oc_size = 1) {
MEGDNN_MARK_USED_VAR(pack_oc_size);
+        megdnn_assert(pack_oc_size == 1,
+                      "PostProcess only support nchw in x86");
megdnn::param::Elemwise::Mode elem_mode =
megdnn::param::Elemwise::Mode::ADD;
if (bias_mode != megdnn::ConvBiasForward::BiasMode::NO_BIAS) {
......
@@ -1297,6 +1297,32 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F32) {
#endif
}
+#if MEGDNN_AARCH64
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_PACK_F32) {
+    using namespace conv_bias;
+    std::vector<conv_bias::TestArg> args =
+            get_nchw44_conv_bias_args({1}, 1, true, false, false);
+    check_conv_bias(args, handle(), "CONV1x1:AARCH64_F32_MK4_K8X12X1:24");
+}
+#endif
+
+TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_MK4_NO_PACK_F32) {
+    using namespace conv_bias;
+    std::vector<conv_bias::TestArg> args =
+            get_nchw44_conv_bias_args({1}, 1, true, false, false);
+    std::vector<conv_bias::TestArg> args_of_4;
+    for (auto&& arg : args) {
+        if (arg.src.shape[2] * arg.src.shape[3] % 4 == 0) {
+            args_of_4.push_back(arg);
+        }
+    }
+#if MEGDNN_AARCH64
+    check_conv_bias(args_of_4, handle(), "CONV1x1:AARCH64_F32_MK4_4x16:24");
+#elif MEGDNN_ARMV7
+    check_conv_bias(args_of_4, handle(), "CONV1x1:ARMV7_F32_MK4_4x8:48");
+#endif
+}
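
On the checker strings: judging from these tests, the algorithm name has the form CONV1x1:<matmul-algo>:<oc-block>, selecting the conv1x1 wrapper, the underlying MK4 matmul kernel and the OC tile size. The args_of_4 filter keeps only shapes whose OH * OW is divisible by 4, presumably because the no-pack MK4 kernels (4x16 on AArch64, 4x8 on ARMv7) consume the N dimension in whole packs.
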
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_F16) {
using namespace conv_bias;
......