提交 48ac1e1a 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

feat(dnn/fallback): delete nopack onlypacka noneed datatype,and add

im2co and conv1x1 mk4_dot support

GitOrigin-RevId: 096b16a3abb8fa259db77e792bd286dacf3fd8c3
上级 3117bfb7
...@@ -913,10 +913,10 @@ static void gemm_mk4_s8_8x12_pack_B(dt_int8* out, const dt_int8* in, int ldin, ...@@ -913,10 +913,10 @@ static void gemm_mk4_s8_8x12_pack_B(dt_int8* out, const dt_int8* in, int ldin,
*outptr++ = *inptr++; *outptr++ = *inptr++;
} }
for (; i < 4; i++) { for (; i < 4; i++) {
*outptr++ = *inptr++; *outptr++ = 0;
*outptr++ = *inptr++; *outptr++ = 0;
*outptr++ = *inptr++; *outptr++ = 0;
*outptr++ = *inptr++; *outptr++ = 0;
} }
} }
......
...@@ -39,7 +39,7 @@ namespace { ...@@ -39,7 +39,7 @@ namespace {
megdnn::arm_common::OpCallerUnary<_op<ctype>, megdnn::arm_common::VEC>:: \ megdnn::arm_common::OpCallerUnary<_op<ctype>, megdnn::arm_common::VEC>:: \
run(static_cast<ctype*>(conv_dst_ptr), \ run(static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, dst_type, \ reinterpret_cast<ctype*>(dst_ptr), bias_type, dst_type, \
N* OC* OH* OW); N* OC* OH* OW* pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
megdnn::arm_common:: \ megdnn::arm_common:: \
...@@ -63,7 +63,7 @@ namespace { ...@@ -63,7 +63,7 @@ namespace {
static_cast<ctype*>(conv_dst_ptr), \ static_cast<ctype*>(conv_dst_ptr), \
reinterpret_cast<const ctype*>(bias_ptr), \ reinterpret_cast<const ctype*>(bias_ptr), \
reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \ reinterpret_cast<ctype*>(dst_ptr), bias_type, bias_type, \
dst_type, N* OC* OH* OW); dst_type, N* OC* OH* OW* pack_oc_size);
#define FOR_BIAS(_mode) \ #define FOR_BIAS(_mode) \
switch (_mode) { \ switch (_mode) { \
...@@ -113,7 +113,6 @@ struct PostProcess { ...@@ -113,7 +113,6 @@ struct PostProcess {
megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode,
megdnn::DType bias_type, megdnn::DType dst_type, size_t N, megdnn::DType bias_type, megdnn::DType dst_type, size_t N,
size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) {
MEGDNN_MARK_USED_VAR(pack_oc_size);
FOR_BIAS(bias_mode) FOR_BIAS(bias_mode)
} }
}; };
...@@ -155,7 +154,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { ...@@ -155,7 +154,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
_op<opctype, opdtype>, \ _op<opctype, opdtype>, \
megdnn::arm_common::VEC>::run(static_cast<opctype*>(conv_dst_ptr), \ megdnn::arm_common::VEC>::run(static_cast<opctype*>(conv_dst_ptr), \
reinterpret_cast<opdtype*>(dst_ptr), \ reinterpret_cast<opdtype*>(dst_ptr), \
bias_type, dst_type, N* OC* OH* OW); bias_type, dst_type, \
N* OC* OH* OW* pack_oc_size);
#define FOR_NONLINEAR_BINARY_BROADCAST(_op) \ #define FOR_NONLINEAR_BINARY_BROADCAST(_op) \
megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, \ megdnn::arm_common::OpCallerBinary<_op<opctype, opdtype>, \
...@@ -173,8 +173,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { ...@@ -173,8 +173,8 @@ struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \ reinterpret_cast<opdtype*>(dst_ptr), bias_type, bias_type, \
dst_type, N, OC, OH* OW, pack_oc_size); dst_type, N, OC, OH* OW, pack_oc_size);
#define HANDLE_IDENTITY(_caller, _op) \ #define HANDLE_IDENTITY(_caller, _op) \
case megdnn::NonlineMode::IDENTITY: \ case megdnn::NonlineMode::IDENTITY: \
_caller(_op) break; _caller(_op) break;
#define FOR_NONLINEAR(_caller) \ #define FOR_NONLINEAR(_caller) \
......
...@@ -729,10 +729,10 @@ static void gemm_dots8_8x6_pack_B(dt_int8* out, const dt_int8* in, int ldin, ...@@ -729,10 +729,10 @@ static void gemm_dots8_8x6_pack_B(dt_int8* out, const dt_int8* in, int ldin,
*outptr++ = *inptr++; *outptr++ = *inptr++;
} }
for (; i < 4; i++) { for (; i < 4; i++) {
*outptr++ = *inptr++; *outptr++ = 0;
*outptr++ = *inptr++; *outptr++ = 0;
*outptr++ = *inptr++; *outptr++ = 0;
*outptr++ = *inptr++; *outptr++ = 0;
} }
} }
outptr_base += 24; outptr_base += 24;
......
...@@ -187,7 +187,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, ...@@ -187,7 +187,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr,
AlgoSelectionStrategy) const { AlgoSelectionStrategy) const {
MIDOUT_BEGIN(megdnn_fallback_conv1x1, 0, 2) { MIDOUT_BEGIN(megdnn_fallback_conv1x1, 0, 2) {
if (opr->param().format != param::ConvBias::Format::NCHW && if (opr->param().format != param::ConvBias::Format::NCHW &&
opr->param().format != param::ConvBias::Format::NCHW44) opr->param().format != param::ConvBias::Format::NCHW44 &&
opr->param().format != param::ConvBias::Format::NCHW44_DOT)
return false; return false;
size_t FH = param.filter_meta.spatial[0], size_t FH = param.filter_meta.spatial[0],
...@@ -219,8 +220,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr, ...@@ -219,8 +220,8 @@ bool ConvBiasImpl::AlgoConv1x1::usable(ConvBiasImpl* opr,
param.nonlineMode != megdnn::NonlineMode::IDENTITY) param.nonlineMode != megdnn::NonlineMode::IDENTITY)
return false; return false;
if (opr->param().format == param::ConvBias::Format::NCHW44) { if (opr->param().format == param::ConvBias::Format::NCHW44 ||
//! nchw44 hybird mode and channel wise is not support opr->param().format == param::ConvBias::Format::NCHW44_DOT) {
if (param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 || if (param.filter_meta.icpg < 4_z || param.filter_meta.icpg == 1 ||
param.filter_meta.ocpg == 1) { param.filter_meta.ocpg == 1) {
return false; return false;
......
...@@ -73,32 +73,34 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy( ...@@ -73,32 +73,34 @@ std::unique_ptr<Conv1x1StrategyBase> create_conv1x1_strategy(
const ConvBiasImpl::NCBKernSizeParam& param, const ConvBiasImpl::NCBKernSizeParam& param,
MatrixMulImpl::AlgoBase::PackMode pack_mode, MatrixMulImpl::AlgoBase::PackMode pack_mode,
param::ConvBias::Format format) { param::ConvBias::Format format) {
size_t pack_size = get_format_pack_size(format); size_t pack_c_size = pack_size(format);
#define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \ #define cb1(_packmode, _dt, _post_ctype, _postprocess_mode, _midout_tag) \
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \
midout_iv(_midout_tag)) { \ midout_iv(_midout_tag)) { \
if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \
return std::make_unique< \ return std::make_unique< \
Conv1x1Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \ Conv1x1Strategy<_dt, _dt, _dt, _post_ctype, _post_ctype, \
_postprocess_mode, _packmode>>(pack_size); \ _postprocess_mode, _packmode>>( \
} \ pack_c_size); \
} \ } \
} \
MIDOUT_END() MIDOUT_END()
#define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \ #define cb2(_packmode, _i_src_type, _i_bias_type, _i_dst_type, _src_ctype, \
_bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \ _bias_ctype, _dst_ctype, _postprocess_mode, _midout_tag) \
MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \ MIDOUT_BEGIN(megdnn_fallback_conv1x1_factory_strategy, \
midout_iv(_midout_tag)) { \ midout_iv(_midout_tag)) { \
if (param.filter_type.enumv() == param.src_type.enumv() && \ if (param.filter_type.enumv() == param.src_type.enumv() && \
param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \ param.src_type.enumv() == DTypeTrait<_i_src_type>::enumv && \
param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \ param.dst_type.enumv() == DTypeTrait<_i_dst_type>::enumv) { \
return std::make_unique< \ return std::make_unique< \
Conv1x1Strategy<_src_ctype, _bias_ctype, _dst_ctype, \ Conv1x1Strategy<_src_ctype, _bias_ctype, _dst_ctype, \
DTypeTrait<_i_bias_type>::ctype, \ DTypeTrait<_i_bias_type>::ctype, \
DTypeTrait<_i_dst_type>::ctype, \ DTypeTrait<_i_dst_type>::ctype, \
_postprocess_mode, _packmode>>(pack_size); \ _postprocess_mode, _packmode>>( \
} \ pack_c_size); \
} \ } \
} \
MIDOUT_END() MIDOUT_END()
switch (pack_mode) { switch (pack_mode) {
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
#pragma once #pragma once
#include "megdnn/opr_param_defs.h"
#include "src/fallback/conv_bias/opr_impl.h" #include "src/fallback/conv_bias/opr_impl.h"
#if MEGDNN_X86 #if MEGDNN_X86
#include "src/x86/conv_bias/postprocess_helper.h" #include "src/x86/conv_bias/postprocess_helper.h"
...@@ -41,12 +40,15 @@ MatrixMulImpl::KernSizeParam get_matmul_kern_param( ...@@ -41,12 +40,15 @@ MatrixMulImpl::KernSizeParam get_matmul_kern_param(
param.dst_type.enumv() == DTypeEnum::QuantizedS8) || param.dst_type.enumv() == DTypeEnum::QuantizedS8) ||
(param.src_type.enumv() == DTypeEnum::Quantized8Asymm && (param.src_type.enumv() == DTypeEnum::Quantized8Asymm &&
param.dst_type.enumv() == DTypeEnum::Quantized8Asymm); param.dst_type.enumv() == DTypeEnum::Quantized8Asymm);
size_t pack_c_size = 1_z; size_t pack_c_size = pack_size(param.filter_meta.format);
auto format = param::MatrixMul::Format::DEFAULT; auto format = param::MatrixMul::Format::DEFAULT;
if(param.filter_meta.format == param::ConvBias::Format::NCHW44){ if (param.filter_meta.format == param::ConvBias::Format::NCHW44) {
pack_c_size = 4_z;
format = param::MatrixMul::Format::MK4; format = param::MatrixMul::Format::MK4;
} else if (param.filter_meta.format ==
param::ConvBias::Format::NCHW44_DOT) {
format = param::MatrixMul::Format::MK4_DOT;
} }
return {param.filter_type, return {param.filter_type,
param.src_type, param.src_type,
is_dst_8bit ? param.bias_type : param.dst_type, is_dst_8bit ? param.bias_type : param.dst_type,
......
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"
#include "src/fallback/conv_bias/common.h" #include "src/fallback/conv_bias/common.h"
#include "src/fallback/conv_bias/opr_impl.h" #include "src/fallback/conv_bias/opr_impl.h"
#include "src/fallback/conv_bias/winograd/strategy.h"
#include "src/naive/convolution/helper.h" #include "src/naive/convolution/helper.h"
#include "midout.h" #include "midout.h"
...@@ -125,7 +124,7 @@ public: ...@@ -125,7 +124,7 @@ public:
size_t oc_tile_size) { size_t oc_tile_size) {
size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0], size_t IC = param.filter_meta.icpg, FH = param.filter_meta.spatial[0],
FW = param.filter_meta.spatial[1]; FW = param.filter_meta.spatial[1];
size_t pack_oc_size = get_format_pack_size(param.filter_meta.format); size_t pack_oc_size = pack_size(param.filter_meta.format);
size_t im2col = 0, packb = 0, bias_temp = 0; size_t im2col = 0, packb = 0, bias_temp = 0;
bool default_pack = matmul_algo->packmode() == Pack_Mode::DEFAULT; bool default_pack = matmul_algo->packmode() == Pack_Mode::DEFAULT;
megdnn_assert(default_pack, "only support default packa"); megdnn_assert(default_pack, "only support default packa");
...@@ -319,9 +318,11 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param, ...@@ -319,9 +318,11 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param,
size_t ohw_tile_size, size_t ohw_tile_size,
size_t oc_tile_size) const { size_t oc_tile_size) const {
auto format = param::MatrixMul::Format::DEFAULT; auto format = param::MatrixMul::Format::DEFAULT;
size_t pack_oc_size = get_format_pack_size(param.filter_meta.format); size_t pack_oc_size = pack_size(param.filter_meta.format);
if (param.filter_meta.format == param::ConvBias::Format::NCHW44) { if (param.filter_meta.format == param::ConvBias::Format::NCHW44) {
format = param::MatrixMul::Format::MK4; format = param::MatrixMul::Format::MK4;
} else if(param.filter_meta.format == param::ConvBias::Format::NCHW44_DOT){
format = param::MatrixMul::Format::MK4_DOT;
} }
size_t M = oc_tile_size; size_t M = oc_tile_size;
size_t N = ohw_tile_size; size_t N = ohw_tile_size;
...@@ -351,11 +352,10 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param, ...@@ -351,11 +352,10 @@ ConvBiasImpl::AlgoIm2col ::get_matmul_kern_param(const NCBKernSizeParam& param,
void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block( void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block(
const NCBKernSizeParam& param, size_t& oc_tile_size, const NCBKernSizeParam& param, size_t& oc_tile_size,
size_t& ohw_tile_size, size_t block_m, size_t block_n, size_t& ohw_tile_size, size_t block_m, size_t block_n,
bool need_pack) const { fallback::MatrixMulImpl::AlgoBase::PackMode pack_mode) const {
size_t nr_threads = param.nr_threads; size_t nr_threads = param.nr_threads;
size_t OC = param.filter_meta.ocpg; size_t OC = param.filter_meta.ocpg;
size_t ohw = param.osz[0] * param.osz[1]; size_t ohw = param.osz[0] * param.osz[1];
oc_tile_size = DEFAULT_OC_TILE_SIZE; oc_tile_size = DEFAULT_OC_TILE_SIZE;
ohw_tile_size = m_ohw_tile_size; ohw_tile_size = m_ohw_tile_size;
...@@ -376,7 +376,8 @@ void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block( ...@@ -376,7 +376,8 @@ void ConvBiasImpl::AlgoIm2col::choice_ohw_oc_block(
} }
} }
} else { } else {
if (!need_pack) { //! no pack ,usually in x86 save memroy //! in no_pack mode don't do block operation when using single thread
if (pack_mode == fallback::MatrixMulImpl::AlgoBase::PackMode::NO_PACK) {
ohw_tile_size = ohw; ohw_tile_size = ohw;
oc_tile_size = OC; oc_tile_size = OC;
} }
...@@ -406,7 +407,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( ...@@ -406,7 +407,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
if (need_pack || only_packA) { if (need_pack || only_packA) {
auto inner_block = m_matmul_algo->get_inner_block_size(); auto inner_block = m_matmul_algo->get_inner_block_size();
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, inner_block.m, choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, inner_block.m,
inner_block.n, need_pack); inner_block.n, m_matmul_algo->packmode());
auto im2col_kern_param = get_matmul_kern_param( auto im2col_kern_param = get_matmul_kern_param(
param, ohw_tile_size, only_packA ? oc_tile_size : OC); param, ohw_tile_size, only_packA ? oc_tile_size : OC);
size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
...@@ -418,7 +419,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( ...@@ -418,7 +419,7 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
size_t nopack_default_blockn = 16; size_t nopack_default_blockn = 16;
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
nopack_default_blockm, nopack_default_blockn, nopack_default_blockm, nopack_default_blockn,
need_pack); m_matmul_algo->packmode());
packa_group_size = 0; packa_group_size = 0;
} }
...@@ -488,19 +489,20 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( ...@@ -488,19 +489,20 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns(
if (default_pack || only_packA) { if (default_pack || only_packA) {
auto inner_block = m_matmul_algo->get_inner_block_size(); auto inner_block = m_matmul_algo->get_inner_block_size();
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
inner_block.m, inner_block.n, default_pack); inner_block.m, inner_block.n,
} else { //! not support pack,not need pack m_matmul_algo->packmode());
} else { //! nopack_mode
size_t nopack_default_blockm = 8; size_t nopack_default_blockm = 8;
size_t nopack_default_blockn = 16; size_t nopack_default_blockn = 16;
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
nopack_default_blockm, nopack_default_blockn, nopack_default_blockm, nopack_default_blockn,
no_pack); m_matmul_algo->packmode());
} }
size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size); size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size);
size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
size_t packa_parallel_times = 0; size_t packa_parallel_times = 0;
size_t pack_oc_size = get_format_pack_size(param.filter_meta.format); size_t pack_oc_size = pack_size(param.filter_meta.format);
if (only_packA) { if (only_packA) {
packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size); packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size);
...@@ -639,9 +641,15 @@ bool ConvBiasImpl::AlgoIm2col::usable( ...@@ -639,9 +641,15 @@ bool ConvBiasImpl::AlgoIm2col::usable(
ConvBiasImpl* opr, const NCBKernSizeParam& param, ConvBiasImpl* opr, const NCBKernSizeParam& param,
AlgoSelectionStrategy /*algo_selection_strategy*/) const { AlgoSelectionStrategy /*algo_selection_strategy*/) const {
MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 2) { MIDOUT_BEGIN(megdnn_fallback_im2col, 0, 2) {
if (opr->param().format != param::ConvBias::Format::NCHW &&
opr->param().format != param::ConvBias::Format::NCHW44_DOT &&
opr->param().format != param::ConvBias::Format::NCHW44) {
return false;
}
//! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode is //! make sure 8x8x16 and 8x8x32 biasmode is nobias and nonlineMode is
//! identity otherwise return false mean that 8x8x32 and 8x8x16 not support //! identity otherwise return false mean that 8x8x32 and 8x8x16 not
//! PostProcess //! support PostProcess
if (param.src_type.enumv() == param.filter_type.enumv() && if (param.src_type.enumv() == param.filter_type.enumv() &&
((param.src_type.enumv() == DTypeEnum::Int8 && ((param.src_type.enumv() == DTypeEnum::Int8 &&
(param.dst_type.enumv() == DTypeEnum::Int16 || (param.dst_type.enumv() == DTypeEnum::Int16 ||
...@@ -653,9 +661,10 @@ bool ConvBiasImpl::AlgoIm2col::usable( ...@@ -653,9 +661,10 @@ bool ConvBiasImpl::AlgoIm2col::usable(
param.nonlineMode != megdnn::NonlineMode::IDENTITY) { param.nonlineMode != megdnn::NonlineMode::IDENTITY) {
return false; return false;
} }
if (opr->param().format == param::ConvBias::Format::NCHW44) { if (opr->param().format == param::ConvBias::Format::NCHW44 ||
opr->param().format == param::ConvBias::Format::NCHW44_DOT) {
//! current NCHW44 im2col only support DEFAULT mode matmul //! current NCHW44 im2col only support DEFAULT mode matmul
if(m_matmul_algo->packmode() != Pack_Mode::DEFAULT) { if (m_matmul_algo->packmode() != Pack_Mode::DEFAULT) {
return false; return false;
//! nchw44 hybird mode and channel wise is not support //! nchw44 hybird mode and channel wise is not support
} else if (param.filter_meta.icpg < 4_z || } else if (param.filter_meta.icpg < 4_z ||
...@@ -668,29 +677,27 @@ bool ConvBiasImpl::AlgoIm2col::usable( ...@@ -668,29 +677,27 @@ bool ConvBiasImpl::AlgoIm2col::usable(
size_t oc_tile_size = 0, ohw_tile_size = 0; size_t oc_tile_size = 0, ohw_tile_size = 0;
Pack_Mode packmode = m_matmul_algo->packmode(); Pack_Mode packmode = m_matmul_algo->packmode();
bool default_pack = packmode == Pack_Mode::DEFAULT; bool default_pack = packmode == Pack_Mode::DEFAULT;
bool no_pack = packmode == Pack_Mode::NO_PACK;
bool only_packA = packmode == Pack_Mode::ONLY_PACKA; bool only_packA = packmode == Pack_Mode::ONLY_PACKA;
if (default_pack || only_packA) { if (default_pack || only_packA) {
auto inner_block = m_matmul_algo->get_inner_block_size(); auto inner_block = m_matmul_algo->get_inner_block_size();
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
inner_block.m, inner_block.n, default_pack); inner_block.m, inner_block.n,
m_matmul_algo->packmode());
} else { //! not support pack,not need pack } else { //! not support pack,not need pack
size_t nopack_default_blockm = 8; size_t nopack_default_blockm = 8;
size_t nopack_default_blockn = 16; size_t nopack_default_blockn = 16;
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
nopack_default_blockm, nopack_default_blockn, nopack_default_blockm, nopack_default_blockn,
no_pack); m_matmul_algo->packmode());
} }
fallback::MatrixMulImpl::KernSizeParam matmul_param = fallback::MatrixMulImpl::KernSizeParam matmul_param =
get_matmul_kern_param(param, ohw_tile_size, oc_tile_size); get_matmul_kern_param(param, ohw_tile_size, oc_tile_size);
bool matmulusable = m_matmul_algo->usable(matmul_param); bool matmulusable = m_matmul_algo->usable(matmul_param);
return matmulusable && return matmulusable &&
(opr->param().format == param::ConvBias::Format::NCHW ||
opr->param().format == param::ConvBias::Format::NCHW44) &&
(!(param.filter_meta.spatial[0] == (!(param.filter_meta.spatial[0] ==
param.filter_meta.spatial[1] && param.filter_meta.spatial[1] &&
(param.filter_meta.spatial[0] == 1) && param.filter_meta.spatial[0] == 1 &&
param.filter_meta.stride[0] == param.filter_meta.stride[1] && param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
param.filter_meta.stride[0] == 1)) && param.filter_meta.stride[0] == 1)) &&
(param.filter_meta.dilation[0] == (param.filter_meta.dilation[0] ==
......
...@@ -36,10 +36,10 @@ class ConvBiasImpl::AlgoIm2col final : public AlgoBase { ...@@ -36,10 +36,10 @@ class ConvBiasImpl::AlgoIm2col final : public AlgoBase {
const NCBKernSizeParam& param, size_t ohw_tile_size, const NCBKernSizeParam& param, size_t ohw_tile_size,
size_t oc_tile_size) const; size_t oc_tile_size) const;
WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const;
void choice_ohw_oc_block(const NCBKernSizeParam& param, void choice_ohw_oc_block(
size_t& oc_tile_size, size_t& ohw_tile_size, const NCBKernSizeParam& param, size_t& oc_tile_size,
size_t block_m, size_t block_n, size_t& ohw_tile_size, size_t block_m, size_t block_n,
bool pack_default) const; fallback::MatrixMulImpl::AlgoBase::PackMode pack_mode) const;
public: public:
AlgoIm2col(MatrixMulImpl::AlgoBase* matmul_algo, size_t ohw_tile_size) AlgoIm2col(MatrixMulImpl::AlgoBase* matmul_algo, size_t ohw_tile_size)
......
...@@ -230,7 +230,11 @@ public: ...@@ -230,7 +230,11 @@ public:
PostprocessMode::FLOAT, PostprocessMode::FLOAT,
"DefaultStrategyTypeNCHW44::FLOAT"_hash); "DefaultStrategyTypeNCHW44::FLOAT"_hash);
} else { } else {
megdnn_throw("not support format except nchw44 and nchw\n"); megdnn_throw(
ssprintf("Current only support layout "
"NCHW44/NCHW for im2col "
"algo, but got %d\n",
uint32_t(format)));
} }
break; break;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
...@@ -252,12 +256,17 @@ public: ...@@ -252,12 +256,17 @@ public:
cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, cb2(NCHW, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"DefaultStrategyType::INT8x8x32"_hash); "DefaultStrategyType::INT8x8x32"_hash);
} else if (format == param::ConvBias::Format::NCHW44) { } else if (format == param::ConvBias::Format::NCHW44 ||
format == param::ConvBias::Format::NCHW44_DOT) {
cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8, cb2(NCHW44, DEFAULT, dt_int8, dt_int32, dt_int32, dt_int8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"DefaultStrategyType::INT8x8x32"_hash); "DefaultStrategyType::INT8x8x32"_hash);
} else { } else {
megdnn_throw("not support format except nchw44 and nchw\n"); megdnn_throw(
ssprintf("Current only support layout "
"NCHW44/NCHW/NCHW_DOT for im2col "
"algo, but got %d\n",
uint32_t(format)));
} }
break; break;
...@@ -288,13 +297,18 @@ public: ...@@ -288,13 +297,18 @@ public:
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32, dtype::QuantizedS32, dt_int8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS, PostprocessMode::NO_PROCESS,
"DefaultStrategyTypeNCHW::QINT8x8x32"_hash); "DefaultStrategyTypeNCHW::QINT8x8x32"_hash);
} else if (format == param::ConvBias::Format::NCHW44) { } else if (format == param::ConvBias::Format::NCHW44 ||
format == param::ConvBias::Format::NCHW44_DOT) {
cb2(NCHW44, DEFAULT, dtype::QuantizedS8, cb2(NCHW44, DEFAULT, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS32, dt_int8, dtype::QuantizedS32, dtype::QuantizedS32, dt_int8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS, dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"DefaultStrategyTypeHCHW44::QINT8x8x32"_hash); "DefaultStrategyTypeHCHW44::QINT8x8x32"_hash);
} else { } else {
megdnn_throw("not support format except nchw44 and nchw\n"); megdnn_throw(
ssprintf("Current only support layout "
"NCHW44/NCHW/NCHW_DOT for im2col "
"algo, but got %d\n",
uint32_t(format)));
} }
break; break;
...@@ -304,17 +318,22 @@ public: ...@@ -304,17 +318,22 @@ public:
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, dtype::QuantizedS8, dt_int8, dt_int32, dt_int8,
PostprocessMode::QUANTIZED, PostprocessMode::QUANTIZED,
"DefaultStrategyType::QINT8x8x32x8"_hash); "DefaultStrategyType::QINT8x8x32x8"_hash);
} else if (format == param::ConvBias::Format::NCHW44) { } else if (format == param::ConvBias::Format::NCHW44 ||
format == param::ConvBias::Format::NCHW44_DOT) {
cb2(NCHW44, DEFAULT, dtype::QuantizedS8, cb2(NCHW44, DEFAULT, dtype::QuantizedS8,
dtype::QuantizedS32, dtype::QuantizedS8, dt_int8, dtype::QuantizedS32, dtype::QuantizedS8, dt_int8,
dt_int32, dt_int8, PostprocessMode::QUANTIZED, dt_int32, dt_int8, PostprocessMode::QUANTIZED,
"DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash); "DefaultStrategyTypeNCHW44::QINT8x8x32x8"_hash);
} else { } else {
megdnn_throw("not support format except nchw44 and nchw\n"); megdnn_throw(ssprintf("Current only support layout "
"NCHW44/NCHW/NCHW_DOT for im2col "
"algo, but got %d\n",
uint32_t(format)));
} }
break; break;
} }
megdnn_throw("error not support strategy type "); megdnn_throw(ssprintf("Unsupported strategy type %u in default mode",
uint32_t(strategytype)));
} }
static std::unique_ptr<StrategyBase> make_nopack_strategy( static std::unique_ptr<StrategyBase> make_nopack_strategy(
...@@ -328,10 +347,6 @@ public: ...@@ -328,10 +347,6 @@ public:
PostprocessMode::FLOAT, "NoPackStrategyType::FLOAT"_hash); PostprocessMode::FLOAT, "NoPackStrategyType::FLOAT"_hash);
break; break;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case StrategyType::FLOAT_FP16:
cb1(NCHW, NO_PACK, dt_float16, __fp16, PostprocessMode::FLOAT,
"NoPackStrategyType::FLOAT_FP16"_hash);
break;
#else #else
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
case StrategyType::FLOAT16_FLOAT16: case StrategyType::FLOAT16_FLOAT16:
...@@ -341,48 +356,24 @@ public: ...@@ -341,48 +356,24 @@ public:
break; break;
#endif #endif
#endif #endif
case StrategyType::INT8x8x32:
cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"NoPackStrategyType::INT8x8x32"_hash);
break;
case StrategyType::INT8x8x16: case StrategyType::INT8x8x16:
cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8, cb2(NCHW, NO_PACK, dt_int8, dt_int16, dt_int16, dt_int8,
dt_int16, dt_int16, PostprocessMode::NO_PROCESS, dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
"NoPackStrategyType::INT8x8x16"_hash); "NoPackStrategyType::INT8x8x16"_hash);
break; break;
case StrategyType::INT8x8x32:
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 cb2(NCHW, NO_PACK, dt_int8, dt_int32, dt_int32, dt_int8,
case StrategyType::QUINT8x8x32: dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
cb2(NCHW, NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32, "NoPackStrategyType::INT8x8x32"_hash);
dtype::QuantizedS32, dt_uint8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS,
"NoPackStrategyType::QUINT8x8x32"_hash);
break;
case StrategyType::QUINT8x8x32x8:
cb2(NCHW, NO_PACK, dtype::Quantized8Asymm, dtype::QuantizedS32,
dtype::Quantized8Asymm, dt_uint8, dt_int32, dt_uint8,
PostprocessMode::QUANTIZED,
"NoPackStrategyType::QUINT8x8x32x8"_hash);
break;
#endif
case StrategyType::QINT8x8x32:
cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32,
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS,
"NoPackStrategyType::QINT8x8x32"_hash);
break; break;
default:
case StrategyType::QINT8x8x32x8: megdnn_throw(
cb2(NCHW, NO_PACK, dtype::QuantizedS8, dtype::QuantizedS32, ssprintf("Unsupported strategy type %u in no_pack mode",
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8, uint32_t(strategytype)));
PostprocessMode::QUANTIZED,
"NoPackStrategyType::QINT8x8x32x8"_hash);
break; break;
} }
megdnn_throw("error not support strategy type "); megdnn_throw(ssprintf("Unsupported strategy type %u in no_pack mode",
uint32_t(strategytype)));
} }
static std::unique_ptr<StrategyBase> make_onlypacka_strategy( static std::unique_ptr<StrategyBase> make_onlypacka_strategy(
...@@ -396,63 +387,14 @@ public: ...@@ -396,63 +387,14 @@ public:
PostprocessMode::FLOAT, PostprocessMode::FLOAT,
"OnlyPackaStrategyType::FLOAT"_hash); "OnlyPackaStrategyType::FLOAT"_hash);
break; break;
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC default:
case StrategyType::FLOAT_FP16: megdnn_throw(ssprintf(
cb1(NCHW, ONLY_PACKA, dt_float16, __fp16, "Unsupported strategy type %u in onlypacka mode",
PostprocessMode::FLOAT, uint32_t(strategytype)));
"OnlyPackaStrategyType::FLOAT_FP16"_hash);
break;
#else
#if !MEGDNN_DISABLE_FLOAT16
case StrategyType::FLOAT16_FLOAT16:
cb1(NCHW, ONLY_PACKA, dt_float16, dt_float16,
PostprocessMode::NO_PROCESS,
"OnlyPackaStrategyType::FLOAT16_FLOAT16"_hash);
break;
#endif
#endif
case StrategyType::INT8x8x32:
cb2(NCHW, ONLY_PACKA, dt_int8, dt_int32, dt_int32, dt_int8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"OnlyPackaStrategyType::INT8x8x32"_hash);
break;
case StrategyType::INT8x8x16:
cb2(NCHW, ONLY_PACKA, dt_int8, dt_int16, dt_int16, dt_int8,
dt_int16, dt_int16, PostprocessMode::NO_PROCESS,
"OnlyPackaStrategyType::INT8x8x16"_hash);
break;
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
case StrategyType::QUINT8x8x32:
cb2(NCHW, ONLY_PACKA, dtype::Quantized8Asymm,
dtype::QuantizedS32, dtype::QuantizedS32, dt_uint8,
dt_int32, dt_int32, PostprocessMode::NO_PROCESS,
"OnlyPackaStrategyType::QUINT8x8x32"_hash);
break;
case StrategyType::QUINT8x8x32x8:
cb2(NCHW, ONLY_PACKA, dtype::Quantized8Asymm,
dtype::QuantizedS32, dtype::Quantized8Asymm, dt_uint8,
dt_int32, dt_uint8, PostprocessMode::QUANTIZED,
"OnlyPackaStrategyType::QUINT8x8x32x8"_hash);
break;
#endif
case StrategyType::QINT8x8x32:
cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32,
dtype::QuantizedS32, dt_int8, dt_int32, dt_int32,
PostprocessMode::NO_PROCESS,
"OnlyPackaStrategyType::QINT8x8x32"_hash);
break;
case StrategyType::QINT8x8x32x8:
cb2(NCHW, ONLY_PACKA, dtype::QuantizedS8, dtype::QuantizedS32,
dtype::QuantizedS8, dt_int8, dt_int32, dt_int8,
PostprocessMode::QUANTIZED,
"OnlyPackaStrategyType::QINT8x8x32x8"_hash);
break; break;
} }
megdnn_throw("error not support strategy type "); megdnn_throw(ssprintf("Unsupported strategy type %u in onlypacka mode",
uint32_t(strategytype)));
} }
#undef cb1 #undef cb1
......
...@@ -11,6 +11,16 @@ ...@@ -11,6 +11,16 @@
#pragma once #pragma once
#include "src/fallback/conv_bias/opr_impl.h" #include "src/fallback/conv_bias/opr_impl.h"
#if MEGDNN_X86
#include "src/x86/conv_bias/postprocess_helper.h"
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
#include "src/arm_common/conv_bias/postprocess_helper.h"
#endif
using namespace megdnn;
#if MEGDNN_X86
using namespace x86;
#endif
namespace megdnn { namespace megdnn {
using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode; using PackMode = fallback::MatrixMulImpl::AlgoBase::PackMode;
...@@ -72,6 +82,185 @@ public: ...@@ -72,6 +82,185 @@ public:
const StrategyParam& sparam, WorkspaceBundle bundle_thread) = 0; const StrategyParam& sparam, WorkspaceBundle bundle_thread) = 0;
}; };
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
          typename op_ctype, typename op_dtype,
          megdnn::PostprocessMode postprocess_mode, PackMode packmode,
          FormatMode format>
//! Common base class for StrategyDefault, StrategyNoPack etc., so that the
//! copy-padding step is implemented once and shared by every pack mode.
class StrategyBridge : public StrategyBase {
public:
    //! index of the padded-source buffer inside the workspace bundle
    constexpr static size_t BUNDLE_PADDING_INDEX = 0;

    StrategyBridge() = default;

    //! Copy one (batch, group, channel) slice of the source tensor into the
    //! workspace, surrounding it with PH rows / PW columns of zero-point
    //! padding.
    //! \param bundle       workspace bundle; rebound to param.workspace_ptr
    //!                     before use
    //! \param param        kernel parameters (shapes, dtypes, pointers)
    //! \param ncb_index    ndrange_id = {batch, group, channel}
    //! \param pack_oc_size channels packed together (1 for NCHW; >1 for
    //!                     packed layouts) — widths/offsets are scaled by it
    virtual void copy_padding_kern(
            WorkspaceBundle bundle,
            const fallback::ConvBiasImpl::NCBKernParam& param,
            const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
            size_t pack_oc_size) override {
        UNPACK_CONV_F32_NCB_KERN_SIZES(param);
        MEGDNN_MARK_USED_VAR(N);
        MEGDNN_MARK_USED_VAR(OC);
        MEGDNN_MARK_USED_VAR(OH);
        MEGDNN_MARK_USED_VAR(OW);
        MEGDNN_MARK_USED_VAR(FH);
        MEGDNN_MARK_USED_VAR(FW);
        MEGDNN_MARK_USED_VAR(SH);
        MEGDNN_MARK_USED_VAR(SW);
        //! padded spatial sizes
        size_t IW2 = IW + 2 * PW;
        size_t IH2 = IH + 2 * PH;
        size_t batch_id = ncb_index.ndrange_id[0];
        size_t group_id = ncb_index.ndrange_id[1];
        size_t channel_id = ncb_index.ndrange_id[2];
        //! element count of one top/bottom padding band (computed BEFORE PW/IW
        //! are rescaled below — the order of these statements matters)
        size_t PH_SIZE = PH * IW2 * pack_oc_size;
        PW = PW * pack_oc_size;
        IW = IW * pack_oc_size;
        size_t padding_group_size = IH2 * IW2 * IC;
        size_t workspace_channel_offset = pack_oc_size * IH2 * IW2 * channel_id;
        size_t workspace_group_offset = group_id * padding_group_size;
        size_t workspace_batch_offset =
                param.filter_meta.group * batch_id * padding_group_size;
        bundle.set(param.workspace_ptr);

        //! quantized-asymmetric sources pad with the zero point, others with 0
        src_ctype src_zp = static_cast<src_ctype>(0);
        if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
            src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point;
        }
        src_ctype* src = const_cast<src_ctype*>(param.src<src_ctype>(
                batch_id, group_id, channel_id, 1, pack_oc_size));
        src_ctype* src2;
        src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) +
               workspace_group_offset + workspace_batch_offset +
               workspace_channel_offset;
        src_ctype* src2_ptr = src2;
        const src_ctype* src_ptr = src;
        //! NOTE(review): memset fills byte-wise; a non-zero src_zp is only hit
        //! for 1-byte Quantized8Asymm types, where this is correct — confirm
        //! if wider asymmetric types are ever added.
        if (PH != 0) {
            std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE);
            src2_ptr += PH_SIZE;
        }
        //! per row: left padding, payload copy, right padding
        rep(ih, IH) {
            if (PW != 0)
                rep(pw, PW) * (src2_ptr++) = src_zp;
            std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW);
            src2_ptr += IW;
            src_ptr += IW;
            if (PW != 0)
                rep(pw, PW) * (src2_ptr++) = src_zp;
        }
        if (PH != 0) {
            std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE);
            src2_ptr += PH_SIZE;
        }
    }
};
namespace{
//! Return the buffer the matmul result for the current block lives in:
//! either a thread-local temporary (when the dst is 8-bit, or the block is
//! tiled over OHW so the output cannot be written directly), or the final
//! dst tensor itself.
template <typename bias_ctype>
inline void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
                                const WorkspaceBundle& bundle_thread,
                                const StrategyParam& sparam,
                                size_t matmul_bundle_index) {
    if (sparam.is_dst_8bit || !sparam.is_ohw_size_bigger) {
        return static_cast<void*>(bundle_thread.get(matmul_bundle_index));
    } else {
        bias_ctype* dst =
                param.dst<bias_ctype>(sparam.batch_id, sparam.group_id) +
                sparam.oc_cur_index * sparam.ohw;
        return static_cast<void*>(dst);
    }
}

//! Thread-local scratch for a per-element bias slice; only allocated/used in
//! full BiasMode::BIAS mode, nullptr otherwise.
template <typename bias_ctype>
inline void* get_bias_temp_ptr(
        const fallback::ConvBiasImpl::NCBKernParam& param,
        const WorkspaceBundle& bundle_thread, size_t bias_bundle_index) {
    bias_ctype* bias_tmp_ptr =
            param.bias_mode == megdnn::BiasMode::BIAS
                    ? static_cast<bias_ctype*>(
                              bundle_thread.get(bias_bundle_index))
                    : nullptr;
    return bias_tmp_ptr;
}

//! Copy the post-processed block from the temporary matmul buffer into the
//! user dst tensor, one packed output-channel row at a time. No-op when the
//! matmul already wrote straight into dst (skip_copy_dst).
template <typename dst_ctype>
void copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
              const void* matmul_dst, const StrategyParam& sparam) {
    if (!sparam.skip_copy_dst) {
        size_t pack_oc_size = sparam.pack_oc_size;
        dst_ctype* dst_tmp_ptr =
                reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst));
        dst_ctype* dst =
                param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) +
                sparam.oc_cur_index * sparam.ohw +
                sparam.ohw_cur_index * pack_oc_size;
        //! oc loop counts packed channel groups, not raw channels
        size_t oc_loop = sparam.output_block_oc_size / pack_oc_size;
        for (size_t oc = 0; oc < oc_loop; oc++) {
            std::memcpy(dst, dst_tmp_ptr,
                        sizeof(dst_ctype) * sparam.output_block_size *
                                pack_oc_size);
            dst_tmp_ptr += sparam.output_block_size * pack_oc_size;
            dst += sparam.ohw * pack_oc_size;
        }
    }
}

//! In BiasMode::BIAS, gather the bias slice for the current output block
//! into contiguous thread-local scratch so PostProcess can consume it.
//! All strides are scaled by pack_oc_size to match the packed layout.
template <typename bias_ctype>
void copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param,
               WorkspaceBundle bundle_thread, const StrategyParam& sparam,
               size_t bias_index) {
    const bias_ctype* bias_ptr = static_cast<const bias_ctype*>(
            param.bias<bias_ctype>(sparam.batch_id, sparam.group_id));
    bias_ctype* bias_temp_ptr = static_cast<bias_ctype*>(
            get_bias_temp_ptr<bias_ctype>(param, bundle_thread, bias_index));
    if (param.bias_mode == megdnn::BiasMode::BIAS) {
        //! NOTE: local name shadows the copy_dst() helper above; it is only a
        //! pointer here.
        bias_ctype* copy_dst = bias_temp_ptr;
        size_t pack_oc_size = sparam.pack_oc_size;
        const bias_ctype* copy_src = bias_ptr +
                                     sparam.oc_cur_index * sparam.ohw +
                                     sparam.ohw_cur_index * pack_oc_size;
        for (size_t oc = sparam.oc_cur_index / pack_oc_size;
             oc < sparam.oc_end_index / pack_oc_size; oc++) {
            std::memcpy(copy_dst, copy_src,
                        sizeof(bias_ctype) * sparam.output_block_size *
                                pack_oc_size);
            copy_dst += sparam.output_block_size * pack_oc_size;
            copy_src += sparam.ohw * pack_oc_size;
        }
    }
}

//! Shared post-processing pipeline for all pack modes:
//! 1) stage the bias slice (BIAS mode), 2) locate the matmul output,
//! 3) apply bias + nonlinearity in place, 4) copy the block to dst if the
//! matmul wrote into a temporary.
template <typename bias_ctype, typename dst_ctype,
          typename op_ctype, typename op_dtype,
          megdnn::PostprocessMode postprocess_mode>
void do_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
                    const StrategyParam& sparam, WorkspaceBundle bundle_thread,
                    size_t matmul_bundle_index, size_t bias_bundle_index) {
    copy_bias<bias_ctype>(param, bundle_thread, sparam, bias_bundle_index);
    void* matmul_dst = get_matmul_dst_ptr<bias_ctype>(
            param, bundle_thread, sparam, matmul_bundle_index);
    const bias_ctype* bias_ptr = static_cast<const bias_ctype*>(
            param.bias<bias_ctype>(sparam.batch_id, sparam.group_id));
    void* bias_temp_ptr = get_bias_temp_ptr<bias_ctype>(param, bundle_thread,
                                                        bias_bundle_index);
    //! BIAS mode consumes the staged scratch; broadcast-channel modes read
    //! the bias tensor directly at the current oc offset
    void* bias_preprocess_ptr = const_cast<void*>(
            param.bias_mode == megdnn::BiasMode::BIAS
                    ? bias_temp_ptr
                    : static_cast<void*>(const_cast<bias_ctype*>(
                              bias_ptr + sparam.oc_cur_index)));
    size_t pack_oc_size = sparam.pack_oc_size;
    PostProcess<op_ctype, op_dtype, postprocess_mode>::run(
            matmul_dst, bias_preprocess_ptr, matmul_dst, param.bias_mode,
            param.nonlineMode, param.bias_type, param.dst_type, 1_z,
            sparam.output_block_oc_size / pack_oc_size, 1_z,
            sparam.output_block_size, pack_oc_size);
    copy_dst<dst_ctype>(param, matmul_dst, sparam);
}
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype, template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode, PackMode packmode, megdnn::PostprocessMode postprocess_mode, PackMode packmode,
...@@ -82,7 +271,10 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, ...@@ -82,7 +271,10 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::DEFAULT> : public StrategyBase { postprocess_mode, PackMode::DEFAULT>
: public StrategyBridge<src_ctype, bias_ctype, dst_ctype, op_ctype,
op_dtype, postprocess_mode, PackMode::DEFAULT,
FormatMode::NCHW> {
public: public:
constexpr static size_t BUNDLE_PADDING_INDEX = 0; constexpr static size_t BUNDLE_PADDING_INDEX = 0;
constexpr static size_t BUNDLE_PACKA_INDEX = 1; constexpr static size_t BUNDLE_PACKA_INDEX = 1;
...@@ -92,13 +284,7 @@ public: ...@@ -92,13 +284,7 @@ public:
Strategy() = default; Strategy() = default;
void copy_padding_kern( virtual void packA_kern(WorkspaceBundle bundle,
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;
void packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam, fallback::MatrixMulImpl::KernSizeParam matmulparam,
fallback::MatrixMulImpl::AlgoBase* matmul_algo, fallback::MatrixMulImpl::AlgoBase* matmul_algo,
...@@ -120,16 +306,13 @@ public: ...@@ -120,16 +306,13 @@ public:
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override;
void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, const StrategyParam& sparam,
WorkspaceBundle bundle_thread) override; WorkspaceBundle bundle_thread) override {
do_postprocess<bias_ctype, dst_ctype, op_ctype, op_dtype,
void copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, postprocess_mode>(param, sparam, bundle_thread,
const void* matmul_dst, const StrategyParam& sparam); THREAD_BUNDLE_IM2COL_INDEX,
THREAD_BUNDLE_BIAS_INDEX);
void copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, }
WorkspaceBundle bundle_thread, const StrategyParam& sparam);
void* get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread);
void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread, const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam); const StrategyParam& sparam);
...@@ -162,7 +345,10 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype, ...@@ -162,7 +345,10 @@ template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::NO_PACK> : public StrategyBase { postprocess_mode, PackMode::NO_PACK>
: public StrategyBridge<src_ctype, bias_ctype, dst_ctype, op_ctype,
op_dtype, postprocess_mode, PackMode::NO_PACK,
FormatMode::NCHW> {
public: public:
constexpr static size_t BUNDLE_PADDING_INDEX = 0; constexpr static size_t BUNDLE_PADDING_INDEX = 0;
constexpr static size_t BUNDLE_PACKA_INDEX = 1; constexpr static size_t BUNDLE_PACKA_INDEX = 1;
...@@ -173,12 +359,6 @@ public: ...@@ -173,12 +359,6 @@ public:
Strategy() = default; Strategy() = default;
void copy_padding_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;
void packA_kern(WorkspaceBundle bundle, void packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam, fallback::MatrixMulImpl::KernSizeParam matmulparam,
...@@ -198,17 +378,6 @@ public: ...@@ -198,17 +378,6 @@ public:
const WorkspaceBundle& bundle_thread, const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam); const StrategyParam& sparam);
inline void* get_bias_temp_ptr(
const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread) {
bias_ctype* bias_tmp_ptr =
param.bias_mode == megdnn::BiasMode::BIAS
? static_cast<bias_ctype*>(
bundle_thread.get(THREAD_BUNDLE_BIAS_INDEX))
: nullptr;
return bias_tmp_ptr;
}
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
const StrategyParam& sparam, const StrategyParam& sparam,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
...@@ -216,19 +385,22 @@ public: ...@@ -216,19 +385,22 @@ public:
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, const StrategyParam& sparam,
WorkspaceBundle bundle_thread) override; WorkspaceBundle bundle_thread) override {
void copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, do_postprocess<bias_ctype, dst_ctype, op_ctype, op_dtype,
const void* matmul_dst, const StrategyParam& sparam); postprocess_mode>(param, sparam, bundle_thread,
THREAD_BUNDLE_MATMULDST_INDEX,
void copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, THREAD_BUNDLE_BIAS_INDEX);
WorkspaceBundle bundle_thread, const StrategyParam& sparam); }
}; };
template <typename src_ctype, typename bias_ctype, typename dst_ctype, template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, class Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::ONLY_PACKA> : public StrategyBase { postprocess_mode, PackMode::ONLY_PACKA>
: public StrategyBridge<src_ctype, bias_ctype, dst_ctype, op_ctype,
op_dtype, postprocess_mode,
PackMode::ONLY_PACKA,FormatMode::NCHW> {
public: public:
constexpr static size_t BUNDLE_PADDING_INDEX = 0; constexpr static size_t BUNDLE_PADDING_INDEX = 0;
constexpr static size_t BUNDLE_PACKA_INDEX = 1; constexpr static size_t BUNDLE_PACKA_INDEX = 1;
...@@ -239,12 +411,6 @@ public: ...@@ -239,12 +411,6 @@ public:
Strategy() = default; Strategy() = default;
void copy_padding_kern(
WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param,
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
size_t pack_size) override;
void packA_kern(WorkspaceBundle bundle, void packA_kern(WorkspaceBundle bundle,
const fallback::ConvBiasImpl::NCBKernParam& param, const fallback::ConvBiasImpl::NCBKernParam& param,
fallback::MatrixMulImpl::KernSizeParam matmulparam, fallback::MatrixMulImpl::KernSizeParam matmulparam,
...@@ -269,24 +435,15 @@ public: ...@@ -269,24 +435,15 @@ public:
void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread, const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam); const StrategyParam& sparam);
inline void* get_bias_temp_ptr(
const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread) {
bias_ctype* bias_tmp_ptr =
param.bias_mode == megdnn::BiasMode::BIAS
? static_cast<bias_ctype*>(
bundle_thread.get(THREAD_BUNDLE_BIAS_INDEX))
: nullptr;
return bias_tmp_ptr;
}
void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, const StrategyParam& sparam,
WorkspaceBundle bundle_thread) override; WorkspaceBundle bundle_thread) override {
void copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param, do_postprocess<bias_ctype, dst_ctype, op_ctype, op_dtype,
const void* matmul_dst, const StrategyParam& sparam); postprocess_mode>(param, sparam, bundle_thread,
THREAD_BUNDLE_MATMULDST_INDEX,
void copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param, THREAD_BUNDLE_BIAS_INDEX);
WorkspaceBundle bundle_thread, const StrategyParam& sparam); }
}; };
} // namespace megdnn } // namespace megdnn
......
...@@ -10,85 +10,9 @@ ...@@ -10,85 +10,9 @@
*/ */
#include "src/fallback/conv_bias/im2col/strategy_base.h" #include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/convolution/img2col_helper.h" #include "src/fallback/convolution/img2col_helper.h"
#if MEGDNN_X86
#include "src/x86/conv_bias/postprocess_helper.h"
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
#include "src/arm_common/conv_bias/postprocess_helper.h"
#endif
using namespace megdnn;
#if MEGDNN_X86
using namespace x86;
#endif
namespace megdnn { namespace megdnn {
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
          typename op_ctype, typename op_dtype,
          megdnn::PostprocessMode postprocess_mode>
//! Copy one (batch, group, channel) source slice into the workspace with
//! PH/PW bands of zero-point padding; strides are scaled by pack_oc_size to
//! support packed layouts.
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
              postprocess_mode, PackMode::DEFAULT>::
        copy_padding_kern(WorkspaceBundle bundle,
                          const fallback::ConvBiasImpl::NCBKernParam& param,
                          const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
                          size_t pack_oc_size) {
    UNPACK_CONV_F32_NCB_KERN_SIZES(param);
    MEGDNN_MARK_USED_VAR(N);
    MEGDNN_MARK_USED_VAR(OC);
    MEGDNN_MARK_USED_VAR(OH);
    MEGDNN_MARK_USED_VAR(OW);
    MEGDNN_MARK_USED_VAR(FH);
    MEGDNN_MARK_USED_VAR(FW);
    MEGDNN_MARK_USED_VAR(SH);
    MEGDNN_MARK_USED_VAR(SW);
    //! padded spatial sizes
    size_t IW2 = IW + 2 * PW;
    size_t IH2 = IH + 2 * PH;
    size_t batch_id = ncb_index.ndrange_id[0];
    size_t group_id = ncb_index.ndrange_id[1];
    size_t channel_id = ncb_index.ndrange_id[2];
    //! size of one top/bottom padding band, computed BEFORE PW/IW are
    //! rescaled below — statement order matters here
    size_t PH_SIZE = PH * IW2 * pack_oc_size;
    PW = PW * pack_oc_size;
    IW = IW * pack_oc_size;
    size_t padding_group_size = IH2 * IW2 * IC;
    size_t workspace_channel_offset = pack_oc_size * IH2 * IW2 * channel_id;
    size_t workspace_group_offset = group_id * padding_group_size;
    size_t workspace_batch_offset =
            param.filter_meta.group * batch_id * padding_group_size;
    bundle.set(param.workspace_ptr);

    //! quantized-asymmetric sources pad with their zero point
    src_ctype src_zp = static_cast<src_ctype>(0);
    if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
        src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point;
    }
    src_ctype* src = const_cast<src_ctype*>(param.src<src_ctype>(
            batch_id, group_id, channel_id, 1, pack_oc_size));
    src_ctype* src2;
    src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) +
           workspace_group_offset + workspace_batch_offset +
           workspace_channel_offset;
    src_ctype* src2_ptr = src2;
    const src_ctype* src_ptr = src;
    if (PH != 0) {
        std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE);
        src2_ptr += PH_SIZE;
    }
    //! per row: left padding, payload, right padding
    rep(ih, IH) {
        if (PW != 0)
            rep(pw, PW) * (src2_ptr++) = src_zp;
        std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW);
        src2_ptr += IW;
        src_ptr += IW;
        if (PW != 0)
            rep(pw, PW) * (src2_ptr++) = src_zp;
    }
    if (PH != 0) {
        std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH_SIZE);
        src2_ptr += PH_SIZE;
    }
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype, template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
...@@ -244,100 +168,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, ...@@ -244,100 +168,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
matmul_kern_naked(matmul_param, a_panel, b_panel); matmul_kern_naked(matmul_param, a_panel, b_panel);
} }
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
          typename op_ctype, typename op_dtype,
          megdnn::PostprocessMode postprocess_mode>
//! Post-process one output block: stage the bias (BIAS mode), apply
//! bias + nonlinearity in place on the matmul result, then copy the block
//! into the user dst tensor when the matmul wrote into a temporary.
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
              postprocess_mode, PackMode::DEFAULT>::
        exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
                         const StrategyParam& sparam,
                         WorkspaceBundle bundle_thread) {
    //! stage 1: gather the bias slice for this block (no-op unless BIAS mode)
    copy_bias(param, bundle_thread, sparam);

    //! stage 2: locate the matmul output and the bias source to apply
    void* compute_dst = get_matmul_dst_ptr(param, bundle_thread, sparam);
    const bias_ctype* group_bias = static_cast<const bias_ctype*>(
            param.bias<bias_ctype>(sparam.batch_id, sparam.group_id));
    void* staged_bias = get_bias_temp_ptr(param, bundle_thread);
    void* bias_src;
    if (param.bias_mode == megdnn::BiasMode::BIAS) {
        //! full bias was staged into thread-local scratch by copy_bias
        bias_src = staged_bias;
    } else {
        //! broadcast-channel bias is read directly at the current oc offset
        bias_src = static_cast<void*>(
                const_cast<bias_ctype*>(group_bias + sparam.oc_cur_index));
    }

    //! stage 3: bias + nonlinearity, in place on the matmul output
    const size_t pack_oc_size = sparam.pack_oc_size;
    PostProcess<op_ctype, op_dtype, postprocess_mode>::run(
            compute_dst, bias_src, compute_dst, param.bias_mode,
            param.nonlineMode, param.bias_type, param.dst_type, 1_z,
            sparam.output_block_oc_size / pack_oc_size, 1_z,
            sparam.output_block_size, pack_oc_size);

    //! stage 4: copy the finished block out if needed
    copy_dst(param, compute_dst, sparam);
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
          typename op_ctype, typename op_dtype,
          megdnn::PostprocessMode postprocess_mode>
//! Copy the post-processed block from the temporary matmul buffer into the
//! user dst tensor, one packed output-channel row at a time. No-op when the
//! matmul already wrote straight into dst.
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
              postprocess_mode, PackMode::DEFAULT>::
        copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
                 const void* matmul_dst, const StrategyParam& sparam) {
    if (sparam.skip_copy_dst) {
        return;
    }
    const size_t pack_oc_size = sparam.pack_oc_size;
    //! contiguous elements per packed output-channel row
    const size_t row_elems = sparam.output_block_size * pack_oc_size;
    dst_ctype* src_row =
            reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst));
    dst_ctype* dst_row =
            param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) +
            sparam.oc_cur_index * sparam.ohw +
            sparam.ohw_cur_index * pack_oc_size;
    //! iterate over packed channel groups, not raw channels
    const size_t packed_oc = sparam.output_block_oc_size / pack_oc_size;
    for (size_t oc = 0; oc < packed_oc; ++oc) {
        std::memcpy(dst_row, src_row, sizeof(dst_ctype) * row_elems);
        src_row += row_elems;
        dst_row += sparam.ohw * pack_oc_size;
    }
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
          typename op_ctype, typename op_dtype,
          megdnn::PostprocessMode postprocess_mode>
//! Thread-local scratch for the staged bias slice; only meaningful in full
//! BiasMode::BIAS mode, nullptr for every other bias mode.
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
               postprocess_mode, PackMode::DEFAULT>::
        get_bias_temp_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
                          const WorkspaceBundle& bundle_thread) {
    if (param.bias_mode != megdnn::BiasMode::BIAS) {
        return nullptr;
    }
    return static_cast<bias_ctype*>(
            bundle_thread.get(THREAD_BUNDLE_BIAS_INDEX));
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
          typename op_ctype, typename op_dtype,
          megdnn::PostprocessMode postprocess_mode>
//! In BiasMode::BIAS, gather the bias slice for the current output block into
//! contiguous thread-local scratch so PostProcess can consume it; no-op for
//! other bias modes.
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
              postprocess_mode, PackMode::DEFAULT>::
        copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param,
                  WorkspaceBundle bundle_thread, const StrategyParam& sparam) {
    const bias_ctype* bias_ptr = static_cast<const bias_ctype*>(
            param.bias<bias_ctype>(sparam.batch_id, sparam.group_id));
    bias_ctype* bias_temp_ptr =
            static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread));
    if (param.bias_mode == megdnn::BiasMode::BIAS) {
        //! FIX: scale every offset and stride by sparam.pack_oc_size. The
        //! previous code copied plain output_block_size rows with stride ohw,
        //! which mismatches the packed layout used by copy_dst and the
        //! PostProcess call in this same strategy (they both scale by
        //! pack_oc_size), corrupting the staged bias when pack_oc_size > 1.
        size_t pack_oc_size = sparam.pack_oc_size;
        bias_ctype* copy_dst = bias_temp_ptr;
        const bias_ctype* copy_src = bias_ptr +
                                     sparam.oc_cur_index * sparam.ohw +
                                     sparam.ohw_cur_index * pack_oc_size;
        //! loop over packed channel groups, copying one packed row each
        for (size_t oc = sparam.oc_cur_index / pack_oc_size;
             oc < sparam.oc_end_index / pack_oc_size; oc++) {
            std::memcpy(copy_dst, copy_src,
                        sizeof(bias_ctype) * sparam.output_block_size *
                                pack_oc_size);
            copy_dst += sparam.output_block_size * pack_oc_size;
            copy_src += sparam.ohw * pack_oc_size;
        }
    }
}
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ #define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode) \ _op_dtype, _postprocess_mode) \
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
......
...@@ -11,81 +11,9 @@ ...@@ -11,81 +11,9 @@
#include "src/fallback/conv_bias/im2col/strategy_base.h" #include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/convolution/img2col_helper.h" #include "src/fallback/convolution/img2col_helper.h"
#if MEGDNN_X86
#include "src/x86/conv_bias/postprocess_helper.h"
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
#include "src/arm_common/conv_bias/postprocess_helper.h"
#endif
using namespace megdnn;
#if MEGDNN_X86
using namespace x86;
#endif
namespace megdnn { namespace megdnn {
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
          typename op_ctype, typename op_dtype,
          megdnn::PostprocessMode postprocess_mode>
//! Copy one (batch, group, channel) source slice into the workspace with
//! PH/PW bands of zero-point padding. NO_PACK never packs channels, so the
//! pack-size argument is ignored.
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
              postprocess_mode, PackMode::NO_PACK>::
        copy_padding_kern(WorkspaceBundle bundle,
                          const fallback::ConvBiasImpl::NCBKernParam& param,
                          const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
                          size_t) {
    UNPACK_CONV_F32_NCB_KERN_SIZES(param);
    MEGDNN_MARK_USED_VAR(N);
    MEGDNN_MARK_USED_VAR(OC);
    MEGDNN_MARK_USED_VAR(OH);
    MEGDNN_MARK_USED_VAR(OW);
    MEGDNN_MARK_USED_VAR(FH);
    MEGDNN_MARK_USED_VAR(FW);
    MEGDNN_MARK_USED_VAR(SH);
    MEGDNN_MARK_USED_VAR(SW);
    //! padded spatial sizes
    size_t IW2 = IW + 2 * PW;
    size_t IH2 = IH + 2 * PH;
    size_t batch_id = ncb_index.ndrange_id[0];
    size_t group_id = ncb_index.ndrange_id[1];
    size_t channel_id = ncb_index.ndrange_id[2];
    size_t padding_group_size = IH2 * IW2 * IC;
    size_t workspace_channel_offset = IH2 * IW2 * channel_id;
    size_t workspace_group_offset = group_id * padding_group_size;
    size_t workspace_batch_offset =
            param.filter_meta.group * batch_id * padding_group_size;
    bundle.set(param.workspace_ptr);

    //! quantized-asymmetric sources pad with their zero point
    src_ctype src_zp = static_cast<src_ctype>(0);
    if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
        src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point;
    }
    src_ctype* src = const_cast<src_ctype*>(
            param.src<src_ctype>(batch_id, group_id, channel_id));
    src_ctype* src2;
    src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) +
           workspace_group_offset + workspace_batch_offset +
           workspace_channel_offset;
    src_ctype* src2_ptr = src2;
    const src_ctype* src_ptr = src;
    if (PH != 0) {
        std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2);
        src2_ptr += PH * IW2;
    }
    //! per row: left padding, payload, right padding
    rep(ih, IH) {
        if (PW != 0)
            rep(pw, PW) * (src2_ptr++) = src_zp;
        std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW);
        src2_ptr += IW;
        src_ptr += IW;
        if (PW != 0)
            rep(pw, PW) * (src2_ptr++) = src_zp;
    }
    if (PH != 0) {
        std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2);
        src2_ptr += PH * IW2;
    }
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype, template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
...@@ -220,81 +148,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, ...@@ -220,81 +148,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
} }
} }
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
          typename op_ctype, typename op_dtype,
          megdnn::PostprocessMode postprocess_mode>
//! Post-process one output block for the NO_PACK strategy: stage the bias
//! (BIAS mode), apply bias + nonlinearity in place on the matmul result,
//! then copy the block to the user dst tensor when needed. NO_PACK has no
//! channel packing, so block sizes are passed through unscaled.
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
              postprocess_mode, PackMode::NO_PACK>::
        exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
                         const StrategyParam& sparam,
                         WorkspaceBundle bundle_thread) {
    //! stage the bias slice first (no-op unless BIAS mode)
    copy_bias(param, bundle_thread, sparam);
    void* compute_dst = get_matmul_dst_ptr(param, bundle_thread, sparam);
    const bias_ctype* group_bias = static_cast<const bias_ctype*>(
            param.bias<bias_ctype>(sparam.batch_id, sparam.group_id));
    bias_ctype* staged_bias =
            static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread));
    void* bias_src;
    if (param.bias_mode == megdnn::BiasMode::BIAS) {
        //! full bias was staged into thread-local scratch by copy_bias
        bias_src = staged_bias;
    } else {
        //! broadcast-channel bias is read directly at the current oc offset
        bias_src = static_cast<void*>(
                const_cast<bias_ctype*>(group_bias + sparam.oc_cur_index));
    }
    PostProcess<op_ctype, op_dtype, postprocess_mode>::run(
            compute_dst, bias_src, compute_dst, param.bias_mode,
            param.nonlineMode, param.bias_type, param.dst_type, 1_z,
            sparam.output_block_oc_size, 1_z, sparam.output_block_size);
    copy_dst(param, compute_dst, sparam);
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
          typename op_ctype, typename op_dtype,
          megdnn::PostprocessMode postprocess_mode>
//! Copy the post-processed block from the temporary matmul buffer into the
//! user dst tensor, one output channel at a time. No-op when the matmul
//! already wrote straight into dst.
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
              postprocess_mode, PackMode::NO_PACK>::
        copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
                 const void* matmul_dst, const StrategyParam& sparam) {
    if (sparam.skip_copy_dst) {
        return;
    }
    dst_ctype* src_row =
            reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst));
    dst_ctype* dst_row =
            param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) +
            sparam.oc_cur_index * sparam.ohw + sparam.ohw_cur_index;
    for (size_t oc = 0; oc < sparam.output_block_oc_size; ++oc) {
        std::memcpy(dst_row, src_row,
                    sizeof(dst_ctype) * sparam.output_block_size);
        src_row += sparam.output_block_size;
        dst_row += sparam.ohw;
    }
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
          typename op_ctype, typename op_dtype,
          megdnn::PostprocessMode postprocess_mode>
//! In BiasMode::BIAS, gather the bias slice for the current output block into
//! contiguous thread-local scratch so PostProcess can consume it; no-op for
//! other bias modes.
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
              postprocess_mode, PackMode::NO_PACK>::
        copy_bias(const fallback::ConvBiasImpl::NCBKernParam& param,
                  WorkspaceBundle bundle_thread, const StrategyParam& sparam) {
    const bias_ctype* group_bias = static_cast<const bias_ctype*>(
            param.bias<bias_ctype>(sparam.batch_id, sparam.group_id));
    bias_ctype* write_ptr =
            static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread));
    if (param.bias_mode != megdnn::BiasMode::BIAS) {
        return;
    }
    const bias_ctype* read_ptr = group_bias +
                                 sparam.oc_cur_index * sparam.ohw +
                                 sparam.ohw_cur_index;
    //! one contiguous row of output_block_size elements per output channel
    for (size_t oc = sparam.oc_cur_index; oc < sparam.oc_end_index; ++oc) {
        std::memcpy(write_ptr, read_ptr,
                    sizeof(bias_ctype) * sparam.output_block_size);
        write_ptr += sparam.output_block_size;
        read_ptr += sparam.ohw;
    }
}
#define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ #define INSTANTIAL_CLASS(_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
_op_dtype, _postprocess_mode) \ _op_dtype, _postprocess_mode) \
template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \ template class Strategy<_src_ctype, _bias_ctype, _dst_ctype, _op_ctype, \
...@@ -302,34 +155,18 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, ...@@ -302,34 +155,18 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32,
megdnn::PostprocessMode::FLOAT) megdnn::PostprocessMode::FLOAT)
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16,
megdnn::PostprocessMode::NO_PROCESS)
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32,
megdnn::PostprocessMode::NO_PROCESS)
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16,
megdnn::PostprocessMode::FLOAT)
#else #else
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16, INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
megdnn::PostprocessMode::NO_PROCESS) megdnn::PostprocessMode::NO_PROCESS)
#endif #endif
#endif #endif
#undef INSTANTIAL_CLASS
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8,
megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32,
megdnn::PostprocessMode::NO_PROCESS)
#endif
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8,
megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32,
megdnn::PostprocessMode::NO_PROCESS)
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16,
megdnn::PostprocessMode::NO_PROCESS)
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32,
megdnn::PostprocessMode::NO_PROCESS)
} // namespace megdnn } // namespace megdnn
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -11,81 +11,9 @@ ...@@ -11,81 +11,9 @@
#include "src/fallback/conv_bias/im2col/strategy_base.h" #include "src/fallback/conv_bias/im2col/strategy_base.h"
#include "src/fallback/convolution/img2col_helper.h" #include "src/fallback/convolution/img2col_helper.h"
#if MEGDNN_X86
#include "src/x86/conv_bias/postprocess_helper.h"
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
#include "src/arm_common/conv_bias/postprocess_helper.h"
#endif
using namespace megdnn;
#if MEGDNN_X86
using namespace x86;
#endif
namespace megdnn { namespace megdnn {
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
          typename op_ctype, typename op_dtype,
          megdnn::PostprocessMode postprocess_mode>
//! Copy one (batch, group, channel) source slice into the workspace with
//! PH/PW bands of zero-point padding. ONLY_PACKA never packs channels, so
//! the pack-size argument is ignored.
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
              postprocess_mode, PackMode::ONLY_PACKA>::
        copy_padding_kern(WorkspaceBundle bundle,
                          const fallback::ConvBiasImpl::NCBKernParam& param,
                          const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
                          size_t) {
    UNPACK_CONV_F32_NCB_KERN_SIZES(param);
    MEGDNN_MARK_USED_VAR(N);
    MEGDNN_MARK_USED_VAR(OC);
    MEGDNN_MARK_USED_VAR(OH);
    MEGDNN_MARK_USED_VAR(OW);
    MEGDNN_MARK_USED_VAR(FH);
    MEGDNN_MARK_USED_VAR(FW);
    MEGDNN_MARK_USED_VAR(SH);
    MEGDNN_MARK_USED_VAR(SW);
    //! padded spatial sizes
    size_t IW2 = IW + 2 * PW;
    size_t IH2 = IH + 2 * PH;
    size_t batch_id = ncb_index.ndrange_id[0];
    size_t group_id = ncb_index.ndrange_id[1];
    size_t channel_id = ncb_index.ndrange_id[2];
    size_t padding_group_size = IH2 * IW2 * IC;
    size_t workspace_channel_offset = IH2 * IW2 * channel_id;
    size_t workspace_group_offset = group_id * padding_group_size;
    size_t workspace_batch_offset =
            param.filter_meta.group * batch_id * padding_group_size;
    bundle.set(param.workspace_ptr);

    //! quantized-asymmetric sources pad with their zero point
    src_ctype src_zp = static_cast<src_ctype>(0);
    if (param.src_type.enumv() == DTypeEnum::Quantized8Asymm) {
        src_zp = param.src_type.param<dtype::Quantized8Asymm>().zero_point;
    }
    src_ctype* src = const_cast<src_ctype*>(
            param.src<src_ctype>(batch_id, group_id, channel_id));
    src_ctype* src2;
    src2 = static_cast<src_ctype*>(bundle.get(BUNDLE_PADDING_INDEX)) +
           workspace_group_offset + workspace_batch_offset +
           workspace_channel_offset;
    src_ctype* src2_ptr = src2;
    const src_ctype* src_ptr = src;
    if (PH != 0) {
        std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2);
        src2_ptr += PH * IW2;
    }
    //! per row: left padding, payload, right padding
    rep(ih, IH) {
        if (PW != 0)
            rep(pw, PW) * (src2_ptr++) = src_zp;
        std::memcpy(src2_ptr, src_ptr, sizeof(src_ctype) * IW);
        src2_ptr += IW;
        src_ptr += IW;
        if (PW != 0)
            rep(pw, PW) * (src2_ptr++) = src_zp;
    }
    if (PH != 0) {
        std::memset(src2_ptr, src_zp, sizeof(src_ctype) * PH * IW2);
        src2_ptr += PH * IW2;
    }
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype, template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
...@@ -120,25 +48,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, ...@@ -120,25 +48,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
matmul_algo->pack_A(matmul_param, a_panel, 0_z, 0_z); matmul_algo->pack_A(matmul_param, a_panel, 0_z, 0_z);
} }
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::ONLY_PACKA>::
get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const WorkspaceBundle& bundle_thread,
const StrategyParam& sparam) {
if (sparam.is_dst_8bit || !sparam.is_ohw_size_bigger) {
return static_cast<void*>(
bundle_thread.get(THREAD_BUNDLE_MATMULDST_INDEX));
} else {
bias_ctype* dst =
param.dst<bias_ctype>(sparam.batch_id, sparam.group_id) +
sparam.oc_cur_index * sparam.ohw;
return static_cast<void*>(dst);
}
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype, template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
...@@ -241,63 +150,19 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, ...@@ -241,63 +150,19 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
template <typename src_ctype, typename bias_ctype, typename dst_ctype, template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype, typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode> megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, void* Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::ONLY_PACKA>:: postprocess_mode, PackMode::ONLY_PACKA>::
exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param,
const StrategyParam& sparam, const WorkspaceBundle& bundle_thread,
WorkspaceBundle bundle_thread) { const StrategyParam& sparam) {
void* matmul_dst = get_matmul_dst_ptr(param, bundle_thread, sparam); if (sparam.is_dst_8bit || !sparam.is_ohw_size_bigger) {
return static_cast<bias_ctype*>(
const bias_ctype* bias_ptr = static_cast<const bias_ctype*>( bundle_thread.get(THREAD_BUNDLE_MATMULDST_INDEX));
param.bias<bias_ctype>(sparam.batch_id, sparam.group_id)); } else {
bias_ctype* bias_temp_ptr = bias_ctype* dst =
static_cast<bias_ctype*>(get_bias_temp_ptr(param, bundle_thread)); param.dst<bias_ctype>(sparam.batch_id, sparam.group_id) +
sparam.oc_cur_index * sparam.ohw;
if (param.bias_mode == megdnn::BiasMode::BIAS) { return static_cast<void*>(dst);
bias_ctype* copy_dst = bias_temp_ptr;
const bias_ctype* copy_src = bias_ptr +
sparam.oc_cur_index * sparam.ohw +
sparam.ohw_cur_index;
for (size_t oc = sparam.oc_cur_index; oc < sparam.oc_end_index; oc++) {
std::memcpy(copy_dst, copy_src,
sizeof(bias_ctype) * sparam.output_block_size);
copy_dst += sparam.output_block_size;
copy_src += sparam.ohw;
}
}
PostProcess<op_ctype, op_dtype, postprocess_mode>::run(
matmul_dst,
const_cast<void*>(
param.bias_mode == megdnn::BiasMode::BIAS
? bias_temp_ptr
: static_cast<void*>(const_cast<bias_ctype*>(
bias_ptr + sparam.oc_cur_index))),
matmul_dst, param.bias_mode, param.nonlineMode, param.bias_type,
param.dst_type, 1_z, sparam.output_block_oc_size, 1_z,
sparam.output_block_size);
copy_dst(param, matmul_dst, sparam);
}
template <typename src_ctype, typename bias_ctype, typename dst_ctype,
typename op_ctype, typename op_dtype,
megdnn::PostprocessMode postprocess_mode>
void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
postprocess_mode, PackMode::ONLY_PACKA>::
copy_dst(const fallback::ConvBiasImpl::NCBKernParam& param,
const void* matmul_dst, const StrategyParam& sparam) {
if (!sparam.skip_copy_dst) {
dst_ctype* dst_tmp_ptr =
reinterpret_cast<dst_ctype*>(const_cast<void*>(matmul_dst));
dst_ctype* dst =
param.dst<dst_ctype>(sparam.batch_id, sparam.group_id) +
sparam.oc_cur_index * sparam.ohw + sparam.ohw_cur_index;
for (size_t oc = 0; oc < sparam.output_block_oc_size; oc++) {
std::memcpy(dst, dst_tmp_ptr,
sizeof(dst_ctype) * sparam.output_block_size);
dst_tmp_ptr += sparam.output_block_size;
dst += sparam.ohw;
}
} }
} }
...@@ -310,33 +175,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, ...@@ -310,33 +175,6 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype,
INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32, INSTANTIAL_CLASS(dt_float32, dt_float32, dt_float32, dt_float32, dt_float32,
megdnn::PostprocessMode::FLOAT) megdnn::PostprocessMode::FLOAT)
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, __fp16, __fp16,
megdnn::PostprocessMode::FLOAT)
#else
#if !MEGDNN_DISABLE_FLOAT16
INSTANTIAL_CLASS(dt_float16, dt_float16, dt_float16, dt_float16, dt_float16,
megdnn::PostprocessMode::NO_PROCESS)
#endif
#endif
#if MEGDNN_AARCH64 || MEGDNN_ARMV7
//! x86 do not have uint8 matmul so only armv7 armv8 support uint8
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_uint8, dt_qint32, dt_quint8,
megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_uint8, dt_int32, dt_int32, dt_qint32, dt_qint32,
megdnn::PostprocessMode::NO_PROCESS)
#endif
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int8, dt_qint32, dt_qint8,
megdnn::PostprocessMode::QUANTIZED)
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_int32, dt_int32,
megdnn::PostprocessMode::NO_PROCESS)
INSTANTIAL_CLASS(dt_int8, dt_int16, dt_int16, dt_int16, dt_int16,
megdnn::PostprocessMode::NO_PROCESS)
INSTANTIAL_CLASS(dt_int8, dt_int32, dt_int32, dt_qint32, dt_qint32,
megdnn::PostprocessMode::NO_PROCESS)
#undef INSTANTIAL_CLASS #undef INSTANTIAL_CLASS
} // namespace megdnn } // namespace megdnn
......
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
using namespace megdnn; using namespace megdnn;
using namespace fallback; using namespace fallback;
size_t megdnn::fallback::get_format_pack_size(param::ConvBias::Format format) { size_t megdnn::fallback::pack_size(param::ConvBias::Format format) {
switch (format) { switch (format) {
case param::ConvBias::Format::NCHW44: case param::ConvBias::Format::NCHW44:
case param::ConvBias::Format::NCHW44_DOT: case param::ConvBias::Format::NCHW44_DOT:
......
...@@ -23,8 +23,10 @@ namespace fallback { ...@@ -23,8 +23,10 @@ namespace fallback {
/*! /*!
* \brief get the pack_size according to the format * \brief get the pack_size according to the format
* Note TODO: when remove format from param,
* may using like this "opr::param::format specify"
* */ * */
size_t get_format_pack_size(param::ConvBias::Format format); size_t pack_size(param::ConvBias::Format format);
/*! /*!
* \brief fallback conv bias forward impl * \brief fallback conv bias forward impl
......
...@@ -52,9 +52,21 @@ class GemmInterleaved<Strategy, true> { ...@@ -52,9 +52,21 @@ class GemmInterleaved<Strategy, true> {
} }
size_t get_b_workspace_size() const { size_t get_b_workspace_size() const {
#if __ARM_FEATURE_DOTPROD
size_t new_blockn = m_strategy.block_n;
if (m_strategy.KERNEL_W == 6 && m_strategy.UNROLL_K == 4 &&
m_strategy.KERNEL_H == 8) {
new_blockn = round_up<size_t>((m_strategy.block_n-1) % 6, 4) +
m_strategy.block_n / 6 * 6;
}
size_t N = round_up(new_blockn, m_strategy.KERNEL_W);
size_t K = round_up(m_strategy.block_k, m_strategy.UNROLL_K);
return round_up(sizeof(stype) * N * K, CACHELINE_SIZE) + m_align_size;
#else
size_t N = round_up(m_strategy.block_n, m_strategy.KERNEL_W); size_t N = round_up(m_strategy.block_n, m_strategy.KERNEL_W);
size_t K = round_up(m_strategy.block_k, m_strategy.UNROLL_K); size_t K = round_up(m_strategy.block_k, m_strategy.UNROLL_K);
return round_up(sizeof(stype) * N * K, CACHELINE_SIZE) + m_align_size; return round_up(sizeof(stype) * N * K, CACHELINE_SIZE) + m_align_size;
#endif
} }
//! temporary storage for output, post process such as add bias or relu will //! temporary storage for output, post process such as add bias or relu will
......
...@@ -268,7 +268,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) { ...@@ -268,7 +268,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, BENCHMARK_CONVBIAS_NCHW44) {
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:AARCH64_F32K8X12X1:192", false); "IM2COLMATMUL:AARCH64_F32K8X12X1:192", false);
#else #else
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", benchmark_convbias(handle(), "IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8:384",
"IM2COLMATMUL:ARMV7_F32:192", true); "IM2COLMATMUL:ARMV7_F32:192", true);
benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384", benchmark_convbias(handle(), "IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16:384",
"IM2COLMATMUL:ARMV7_F32:192", false); "IM2COLMATMUL:ARMV7_F32:192", false);
......
...@@ -72,10 +72,12 @@ std::vector<conv_bias::TestArg> get_int8_quint8_conv_bias_args( ...@@ -72,10 +72,12 @@ std::vector<conv_bias::TestArg> get_int8_quint8_conv_bias_args(
std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
std::vector<size_t> kernel_vec, size_t stride, bool no_pad = false, std::vector<size_t> kernel_vec, size_t stride, bool no_pad = false,
bool no_bias = false, bool no_nonlinemode = false, bool no_bias = false, bool no_nonlinemode = false,
bool is_input_nchw = false, bool support_full_bias = false, bool is_input_nchw = false, bool is_nchw44_dot = false,
bool support_sigmoid = false) { bool support_full_bias = false, bool support_sigmoid = false,
bool only_no_bias = false) {
using namespace conv_bias; using namespace conv_bias;
using NLMode = param::ConvBias::NonlineMode; using NLMode = param::ConvBias::NonlineMode;
std::vector<TestArg> args; std::vector<TestArg> args;
auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w, auto pack = [&](size_t n, size_t oc, size_t ic, size_t h, size_t w,
...@@ -102,7 +104,11 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( ...@@ -102,7 +104,11 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
size_t kernel_h = kernel; size_t kernel_h = kernel;
size_t kernel_w = kernel; size_t kernel_w = kernel;
param::ConvBias param; param::ConvBias param;
param.format = param::ConvBias::Format::NCHW44; if (!is_nchw44_dot) {
param.format = param::ConvBias::Format::NCHW44;
} else {
param.format = param::ConvBias::Format::NCHW44_DOT;
}
param.stride_h = stride; param.stride_h = stride;
param.stride_w = stride; param.stride_w = stride;
param.pad_h = pad; param.pad_h = pad;
...@@ -155,18 +161,22 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args( ...@@ -155,18 +161,22 @@ std::vector<conv_bias::TestArg> get_nchw44_conv_bias_args(
if (support_sigmoid) { if (support_sigmoid) {
nonlinemode.emplace_back(NLMode::SIGMOID); nonlinemode.emplace_back(NLMode::SIGMOID);
} }
std::vector<megdnn::BiasMode> bias_mode = { std::vector<megdnn::BiasMode> bias_mode;
megdnn::BiasMode::BROADCAST_CHANNEL_BIAS}; if (!only_no_bias) {
if (no_bias) { bias_mode.emplace_back(megdnn::BiasMode::BROADCAST_CHANNEL_BIAS);
if (no_bias) {
bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS);
}
} else {
bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS); bias_mode.emplace_back(megdnn::BiasMode::NO_BIAS);
} }
if (support_full_bias) { if (support_full_bias) {
bias_mode.emplace_back(megdnn::BiasMode::BIAS); bias_mode.emplace_back(megdnn::BiasMode::BIAS);
} }
for (auto bias : bias_mode) for (auto bias : bias_mode)
for (auto nlmode : nonlinemode) for (auto nlmode : nonlinemode)
for (size_t n : {1, 2}) for (size_t n : {1,2})
for (size_t kernel : kernel_vec) for (size_t kernel : kernel_vec)
for (size_t oc : {4, 12}) for (size_t oc : {4, 12})
for (size_t ic : {1, 3, 4, 12}) for (size_t ic : {1, 3, 4, 12})
...@@ -361,19 +371,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) { ...@@ -361,19 +371,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K7) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) { TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K2K3) {
check_conv_bias(get_nchw44_conv_bias_args({2, 3}, 1, false, false, false, check_conv_bias(get_nchw44_conv_bias_args({2, 3}, 1, false, false, false,
false, true, true), false, false, true, true),
handle(), "F32_CONV_NCHW44_DIRECT"); handle(), "F32_CONV_NCHW44_DIRECT");
} }
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) { TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_K5) {
check_conv_bias(get_nchw44_conv_bias_args({5}, 1, false, false, false, check_conv_bias(get_nchw44_conv_bias_args({5}, 1, false, false, false,
false, true, true), false, false, true, true),
handle(), "F32_CONV_NCHW44_DIRECT"); handle(), "F32_CONV_NCHW44_DIRECT");
} }
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) { TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S2) {
check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false,
false, false, true, true), false, false, false, true, true),
handle(), "F32_CONV_NCHW44_DIRECT"); handle(), "F32_CONV_NCHW44_DIRECT");
} }
...@@ -1420,6 +1430,111 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) { ...@@ -1420,6 +1430,111 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) {
#endif #endif
#undef cb #undef cb
} }
#if __ARM_FEATURE_DOTPROD
//! Check symmetric-quantized (QuantizedS8 in, QuantizedS8 out) conv_bias via
//! the im2col algorithm backed by the MK4_DOT int8 matmul kernels.
//! Only compiled when the CPU dot-product extension is available
//! (guarded by __ARM_FEATURE_DOTPROD at file scope).
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT) {
    //! small symmetric range keeps int32 accumulation well inside range
    UniformIntRNG rng{-50, 50};

//! For each algo name: first call covers stride-1 kernels 2..7; second call
//! covers the 1x1 stride-2 corner case (see get_nchw44_conv_bias_args flags).
#define cb(name)                                                          \
    checker_conv_bias(get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, \
                                                false, false, false, true),   \
                      handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),      \
                      dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),   \
                      dtype::QuantizedS8(60.25f), name);                      \
    checker_conv_bias(                                                        \
            get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true), \
            handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),                \
            dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),             \
            dtype::QuantizedS8(60.25f), name);

    float epsilon = 0.001;
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
#elif MEGDNN_ARMV7
    cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X6X4_DOTPROD:96");
#endif
#undef cb
}
//! Check QuantizedS8 input producing raw QuantizedS32 output (empty dst dtype
//! means no requantization / no fused postprocess) via im2col + MK4_DOT
//! int8 matmul kernels.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_S8x8x32_MK4_DOT) {
    UniformIntRNG rng{-50, 50};

//! First call: stride-1 kernels 2..7; second call: 1x1 stride-2.
//! The trailing get_nchw44_conv_bias_args flags request the NCHW44_DOT
//! format / no-bias variants needed for the s8x8x32 path.
#define cb(name)                                                              \
    checker_conv_bias(                                                        \
            get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false,    \
                                      true, false, true, false, false, true), \
            handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),                \
            dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name);  \
    checker_conv_bias(                                                        \
            get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \
                                      false, false, true),                    \
            handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),                \
            dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name);

    float epsilon = 0.001;
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
#elif MEGDNN_ARMV7
    cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X6X4_DOTPROD:96");
#endif
#undef cb
}
//! Same MK4_DOT im2col coverage as the S8x8x32 test above, but with plain
//! (non-quantized) Int8 input and Int32 accumulator/output dtypes.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32_MK4_DOT) {
    UniformIntRNG rng{-50, 50};

//! First call: stride-1 kernels 2..7; second call: 1x1 stride-2.
#define cb(name)                                                              \
    checker_conv_bias(                                                        \
            get_nchw44_conv_bias_args({2, 3, 4, 5, 6, 7}, 1, false, false,    \
                                      true, false, true, false, false, true), \
            handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(),            \
            dtype::Int32(), {}, name);                                        \
    checker_conv_bias(                                                        \
            get_nchw44_conv_bias_args({1}, 2, false, true, true, false, true, \
                                      false, false, true),                    \
            handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(),            \
            dtype::Int32(), {}, name);

    float epsilon = 0.001;
#if MEGDNN_AARCH64
    cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
#elif MEGDNN_ARMV7
    cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X6X4_DOTPROD:96");
#endif
#undef cb
}
//! Check the dedicated CONV1x1 algorithm (1x1 kernel, stride 1, no padding)
//! with the MK4_DOT int8 matmul kernels across three dtype configurations:
//! requantized s8 output, raw qint32 output, and plain int8->int32.
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CONV1x1_QUANTIZEDSYM_MK4_DOT) {
    UniformIntRNG rng{-50, 50};

//! 1) QuantizedS8 -> QuantizedS8 (fused requantize);
//! 2) QuantizedS8 -> QuantizedS32 (empty dst dtype: no postprocess);
//! 3) Int8 -> Int32 (non-quantized path).
#define cb(name)                                                               \
    checker_conv_bias(                                                         \
            get_nchw44_conv_bias_args({1}, 1, true, true, false, false, true), \
            handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),                 \
            dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),              \
            dtype::QuantizedS8(60.25f), name);                                 \
    checker_conv_bias(                                                         \
            get_nchw44_conv_bias_args({1}, 1, true, true, true, false, true,   \
                                      false, false, true),                     \
            handle(), &rng, epsilon, dtype::QuantizedS8(2.5f),                 \
            dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f), {}, name);   \
    checker_conv_bias(                                                         \
            get_nchw44_conv_bias_args({1}, 1, true, true, true, false, true,   \
                                      false, false, true),                     \
            handle(), &rng, epsilon, dtype::Int8(), dtype::Int8(),             \
            dtype::Int32(), {}, name);

    float epsilon = 0.001;
#if MEGDNN_AARCH64
    cb("CONV1x1:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD");
#elif MEGDNN_ARMV7
    cb("CONV1x1:AARCH32_INT8_MK4_8X6X4_DOTPROD");
#endif
#undef cb
}
#endif
// clang-format on // clang-format on
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 #if MEGDNN_AARCH64 || MEGDNN_ARMV7
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) { TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) {
...@@ -1685,8 +1800,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) { ...@@ -1685,8 +1800,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) { TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) {
using namespace conv_bias; using namespace conv_bias;
std::vector<conv_bias::TestArg> args = std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
get_nchw44_conv_bias_args({2, 4, 7}, 1); {2, 4, 7}, 1, false, false, false, false, false, true,true);
#if MEGDNN_AARCH64 #if MEGDNN_AARCH64
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
#elif MEGDNN_ARMV7 #elif MEGDNN_ARMV7
...@@ -1696,8 +1811,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) { ...@@ -1696,8 +1811,8 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S1_MK4_PACK_F32) {
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) { TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COL_S2_MK4_PACK_F32) {
using namespace conv_bias; using namespace conv_bias;
std::vector<conv_bias::TestArg> args = std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
get_nchw44_conv_bias_args({3, 5, 6}, 2); {3, 5, 6}, 2, false, false, false, false, false, true, true);
#if MEGDNN_AARCH64 #if MEGDNN_AARCH64
check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1"); check_conv_bias(args, handle(), "IM2COLMATMUL:AARCH64_F32_MK4_K8X12X1");
#elif MEGDNN_ARMV7 #elif MEGDNN_ARMV7
......
...@@ -897,6 +897,62 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32) { ...@@ -897,6 +897,62 @@ TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32) {
#undef cb #undef cb
} }
#if MEGDNN_X86_WITH_MKL || MEGDNN_X86_WITH_OPENBLAS
//! Single-threaded fp32 im2col conv_bias test for the BLAS-backed x86 matmul.
//! Sweeps kernel size, channel counts, padding, spatial size and nonlinearity,
//! with three bias variants per configuration.
TEST_F(X86, CONV_BIAS_IM2COLMATMUL_FP32) {
    using namespace conv_bias;
    std::vector<TestArg> args;

    //! Append one no-bias, one channel-broadcast-bias and one full-bias case
    //! for the given shape; skip shapes where the kernel does not fit.
    auto run = [&](size_t oc, size_t ic, size_t w, size_t h, size_t kernel,
                   size_t p, NonlineMode nonline_mode) {
        if (w + 2 * p < kernel || h + 2 * p < kernel)
            return;
        param::ConvBias param;
        param.stride_h = 1;
        param.stride_w = 1;
        param.pad_h = p;
        param.pad_w = p;
        param.nonlineMode = nonline_mode;
        //! no bias
        args.emplace_back(param, TensorShape{1, ic, h, w},
                          TensorShape{oc, ic, kernel, kernel}, TensorShape{});
        //! per-channel broadcast bias
        args.emplace_back(param, TensorShape{1, ic, h, w},
                          TensorShape{oc, ic, kernel, kernel},
                          TensorShape{1, oc, 1, 1});
        //! full bias with the exact output spatial shape
        args.emplace_back(
                param, TensorShape{1, ic, h, w},
                TensorShape{oc, ic, kernel, kernel},
                TensorShape{1, oc, (h + 2 * p - kernel) / param.stride_h + 1,
                            (w + 2 * p - kernel) / param.stride_w + 1});
    };

    for (size_t kernel : {2, 3, 4, 5, 6, 7})
        for (size_t ic : {1, 4, 8, 16})
            for (size_t oc : {1, 4, 8, 16, 300})
                for (size_t p : {0, 2})
                    for (size_t size : {8, 24})
                        for (NonlineMode nonline_mode :
                             {NonlineMode::IDENTITY, NonlineMode::RELU}) {
                            run(oc, ic, size, size, kernel, p, nonline_mode);
                        }

    //! one extra-large output-channel case to stress matmul blocking
    run(2046, 8, 20, 20, 3, 1, NonlineMode::IDENTITY);

    Checker<ConvBias> checker(handle());
//! Restrict the checker to the named algorithm, then run every arg with the
//! dst tensor left for the framework to deduce.
#define cb(algo_name)                                                  \
    checker.set_before_exec_callback(                                  \
            conv_bias::ConvBiasAlgoChecker<ConvBias>(algo_name));      \
    for (auto&& arg : args) {                                          \
        checker.set_param(arg.param).execs(                            \
                {arg.src, arg.filter, arg.bias, {}, {}});              \
    }

    cb("IM2COLMATMUL:X86_F32_BLAS");
#undef cb
}
#endif
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32_PACKA) { TEST_F(X86_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP32_PACKA) {
using namespace conv_bias; using namespace conv_bias;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册