Commit a1f8ecc7 authored by Megvii Engine Team, committed by Xu Xinran

fix(dnn/naive): add convolution nchw44-dot format

GitOrigin-RevId: 87a7c9c5753ad2ced5a6914e50d5bf3c40908a5e
Parent 73d84162
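For orientation (not part of this patch): NCHW44_DOT uses the same 4-packed image layout as NCHW44 but swaps the two innermost filter axes, presumably so that the four input-channel values consumed by a dot-product instruction sit contiguously for each output channel. A minimal sketch of the shapes involved, assuming channel counts divisible by 4; the helper names are illustrative, not MegDNN API:

// Illustrative only: shapes used by the NCHW44 / NCHW44_DOT formats in this commit.
#include <array>
#include <cstddef>

// src/dst for both NCHW44 and NCHW44_DOT: {N, C/4, H, W, 4}
std::array<std::size_t, 5> packed_image_shape(std::size_t n, std::size_t c,
                                              std::size_t h, std::size_t w) {
    return {n, c / 4, h, w, 4};
}

// dense filter, same shape for both formats but different inner-axis meaning:
//   NCHW44     -> {OC/4, IC/4, FH, FW, 4(IC), 4(OC)}
//   NCHW44_DOT -> {OC/4, IC/4, FH, FW, 4(OC), 4(IC)}
std::array<std::size_t, 6> packed_dense_filter_shape(std::size_t oc, std::size_t ic,
                                                     std::size_t fh, std::size_t fw) {
    return {oc / 4, ic / 4, fh, fw, 4, 4};
}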
...@@ -35,7 +35,8 @@ pdef('Axis').add_fields('int32', 'axis', 0)
).
add_enum(Doc('Format', 'convolution data/filter/output format; see '
':class:`RelayoutFormat` for more details'),
'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88',
'NCHW44','NCHW44_DOT',
Doc('NCHW_WINOGRAD', 'NCHW layout with weights tranformed by winograd'),
Doc('NCHW88_WINOGRAD', 'NCHW88 layout with weights tranformed by winograd'),
Doc('NCHW44_WINOGRAD', 'NCHW44 layout with weights tranformed by winograd'),
...
...@@ -104,6 +104,7 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
bias.to_string().c_str(), dst.to_string().c_str());
} else if (param().format == param::ConvBias::Format::NCHW4 ||
param().format == param::ConvBias::Format::NCHW44 ||
param().format == param::ConvBias::Format::NCHW44_DOT ||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) {
megdnn_assert(bias.shape[0] == 1);
megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
...
...@@ -280,6 +280,13 @@ void make_canonized_filter_meta_nchwxx(
/**
* input: N IC/pack_size, H, W, pack_size
*
** NCHW44-DOT mode
* filter:
* {OC/pack_size, IC/pack_size, FH, FW, pack_size(OC), pack_size(IC)}
* [dense]
* {GROUP, OC_PER_GROUP/pack_size, IC_PER_GROUP/pack_size, \
* FH, FW, pack_size(OC), pack_size(IC)} [group]
*
* NCHW88 and NCHW44 mode
* filter:
* {OC/pack_size, IC/pack_size, FH, FW, pack_size(IC), pack_size(OC)}
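To make the difference between the two filter orders documented above concrete, here is a small sketch, not taken from this patch, that repacks a dense OIHW filter into either order; the function name and the assumption that OC and IC are multiples of 4 are illustrative:

// Illustrative repacking of a contiguous OIHW filter into the packed orders above.
#include <cstddef>
#include <vector>

std::vector<float> pack_dense_filter(const std::vector<float>& oihw, std::size_t OC,
                                     std::size_t IC, std::size_t FH, std::size_t FW,
                                     bool dot_layout) {
    const std::size_t pack = 4;
    std::vector<float> out(oihw.size());
    std::size_t idx = 0;
    for (std::size_t ocb = 0; ocb < OC / pack; ++ocb)
        for (std::size_t icb = 0; icb < IC / pack; ++icb)
            for (std::size_t fh = 0; fh < FH; ++fh)
                for (std::size_t fw = 0; fw < FW; ++fw)
                    for (std::size_t a = 0; a < pack; ++a)
                        for (std::size_t b = 0; b < pack; ++b) {
                            // NCHW44:     axis a = inner IC, axis b = inner OC
                            // NCHW44_DOT: axis a = inner OC, axis b = inner IC
                            std::size_t oc = ocb * pack + (dot_layout ? a : b);
                            std::size_t ic = icb * pack + (dot_layout ? b : a);
                            out[idx++] = oihw[((oc * IC + ic) * FH + fh) * FW + fw];
                        }
    return out;
}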
...@@ -300,6 +307,7 @@ void make_canonized_filter_meta_nchwxx(
megdnn_assert(param.format == Param::Format::NCHW88 ||
param.format == Param::Format::NCHW44 ||
param.format == Param::Format::NCHW44_WINOGRAD ||
param.format == Param::Format::NCHW44_DOT ||
param.format == Param::Format::NCHW88_WINOGRAD);
size_t img_ndim = 2;
size_t flt_start = 0;
...@@ -554,6 +562,7 @@ ConvolutionBase<Parameter>::make_canonized_filter_meta(
make_canonized_filter_meta_nchwxx<8, Parameter>(src_ndim, filter,
param(), ret);
} else if (param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW44_DOT ||
param().format == Param::Format::NCHW44_WINOGRAD) {
make_canonized_filter_meta_nchwxx<4, Parameter>(src_ndim, filter,
param(), ret);
...@@ -660,6 +669,7 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
megdnn_assert(param().format == Param::Format::NHWCD4 ||
param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW44_DOT ||
param().format == Param::Format::NCHW8 ||
param().format == Param::Format::NCHW32 ||
param().format == Param::Format::NCHW88 ||
...@@ -668,6 +678,7 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
param().format == Param::Format::CHWN4);
img_dim = src.ndim - 3;
if ((param().format == Param::Format::NCHW88 ||
param().format == Param::Format::NCHW44_DOT ||
param().format == Param::Format::NCHW44) &&
filter.ndim == 5) {
img_dim = src.ndim - 2;
...@@ -675,6 +686,7 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
megdnn_assert(filter.ndim == img_dim + 3 ||
(filter.ndim == img_dim + 2 &&
(param().format == Param::Format::NCHW88 ||
param().format == Param::Format::NCHW44_DOT ||
param().format == Param::Format::NCHW44)) ||
filter.ndim == img_dim + 4 ||
filter.ndim == img_dim + 5,
...@@ -727,6 +739,7 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
src.to_string().c_str(), filter.to_string().c_str());
}
if (param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW44_DOT ||
param().format == Param::Format::NCHW44_WINOGRAD) {
megdnn_assert((src.ndim == 4 && filter.ndim == 5 &&
filter[filter.ndim - 1] == 4) ||
...@@ -859,8 +872,9 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
}
} else if (param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW44_DOT ||
param().format == Param::Format::NCHW44_WINOGRAD) {
megdnn_assert(src.ndim == 5 || (src.ndim == 4 && src[1] <= 4),
"invalid src ndim for NCHW44, expected=5 or 4, got=%zu",
src.ndim);
dst.ndim = 5;
...
...@@ -29,6 +29,7 @@ using namespace fallback;
size_t megdnn::fallback::get_format_pack_size(param::ConvBias::Format format) {
switch (format) {
case param::ConvBias::Format::NCHW44:
case param::ConvBias::Format::NCHW44_DOT:
case param::ConvBias::Format::NCHW4:
return 4_z;
case param::ConvBias::Format::NCHW88:
...@@ -188,6 +189,7 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param(
param().format == Param::Format::NCHW8 ||
param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW44_DOT ||
param().format == Param::Format::NCHW ||
param().format == Param::Format::NCHW_WINOGRAD ||
param().format == Param::Format::NCHW88_WINOGRAD ||
...@@ -405,6 +407,7 @@ const T* ConvBiasImpl::NCBKernParam::filter(size_t group_pack_id,
break;
}
case Param::Format::NCHW44_DOT:
case Param::Format::NCHW44: {
size_t group = filter_meta.group;
size_t icpg = filter_meta.icpg;
...
...@@ -147,6 +147,7 @@ ConvolutionImpl::NCBKernSizeParam ConvolutionImpl::make_ncb_kern_size_param(
if (param().format == Param::Format::NCHW88 ||
param().format == Param::Format::NCHW8 ||
param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW44_DOT ||
param().format == Param::Format::NCHW44) {
spatial_pos = 2;
} else if (param().format == Param::Format::NCHW ||
...
...@@ -147,6 +147,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
if (filter_meta.format == Format::NCHW ||
filter_meta.format == Format::NCHW88 ||
filter_meta.format == Format::NCHW44 ||
filter_meta.format == Format::NCHW44_DOT ||
filter_meta.format == Format::NCHW4 ||
filter_meta.format == Format::NCHW8 ||
filter_meta.format == Format::NCHW32) {
...@@ -174,6 +175,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
if (filter_meta.format == Format::NCHW4 ||
filter_meta.format == Format::CHWN4 ||
filter_meta.format == Format::NCHW44_DOT ||
filter_meta.format == Format::NCHW44) {
OC *= 4;
} else if (filter_meta.format == Format::NCHW8 ||
...@@ -219,7 +221,8 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
FS_G = FS_OC * filter_meta.ocpg / 8;
}
}
} else if (filter_meta.format == Format::NCHW44 ||
filter_meta.format == Format::NCHW44_DOT) {
if (filter_meta.group > 1 && filter_meta.icpg == 1 &&
src.layout.ndim == 5 && filter_meta.ocpg == 1) {
FS_SPATIAL = 4;
...@@ -282,7 +285,8 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
h * layout.stride[2] + w * layout.stride[3] +
(c & 0b111) * layout.stride[4];
}
} else if (filter_meta.format == Format::NCHW44 ||
filter_meta.format == Format::NCHW44_DOT) {
if (filter_meta.format == Format::NCHW44 && !is_output &&
src.layout.ndim == 4) {
return n * layout.stride[0] + c * layout.stride[1] +
...@@ -327,30 +331,41 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
(ic - ic0) / 4 * FS_IC + (fh * FW + fw) * FS_SPATIAL +
((ic - ic0) % 4);
} else if (filter_meta.format == Format::NCHW88 ||
filter_meta.format == Format::NCHW44) {
size_t pack_c_size = 4_z;
if(filter_meta.format == Format::NCHW88){
pack_c_size = 8_z;
}
if (src.layout.ndim == 4) {
// ic < 8, input is nchw
return gc_out.cur_grp * FS_G +
gc_out.cur_off / pack_c_size * FS_OC +
(fh * FW + fw) * FS_SPATIAL + (ic - ic0) * FS_IC +
gc_out.cur_off % pack_c_size;
} else if (filter_meta.group > 1 && filter_meta.icpg == 1 &&
filter_meta.ocpg == 1 && src.layout.ndim == 5) {
// dw case
return gc_out.cur_grp / pack_c_size * FS_G +
gc_out.cur_off * FS_OC + (ic - ic0) * FS_IC +
(fh * FW + fw) * FS_SPATIAL +
gc_out.cur_grp % pack_c_size;
} else if (src.layout.ndim == 5) {
// normal case
return gc_out.cur_grp * FS_G +
gc_out.cur_off / pack_c_size * FS_OC +
(ic - ic0) / pack_c_size * FS_IC +
(fh * FW + fw) * FS_SPATIAL +
((ic - ic0) % pack_c_size) * pack_c_size +
gc_out.cur_off % pack_c_size;
} else {
megdnn_throw(
"nchw88/nchw44 naive not support this input and "
"output\n");
}
} else if (filter_meta.format == Format::NCHW44_DOT) {
if (src.layout.ndim == 4) {
// ic < 4, input is nchw
return gc_out.cur_grp * FS_G + gc_out.cur_off / 4 * FS_OC +
(fh * FW + fw) * FS_SPATIAL + (ic - ic0) * FS_IC +
gc_out.cur_off % 4;
...@@ -364,10 +379,10 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
// normal case
return gc_out.cur_grp * FS_G + gc_out.cur_off / 4 * FS_OC +
(ic - ic0) / 4 * FS_IC + (fh * FW + fw) * FS_SPATIAL +
(gc_out.cur_off % 4) * 4 + ((ic - ic0) % 4);
} else {
megdnn_throw(
"nchw44_dot naive not support this input and output\n");
}
} else {
return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
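The hunk above is the heart of the naive-kernel fix: inside one 4x4 filter block, the NCHW44 branch indexes IC-major/OC-minor, while the new NCHW44_DOT branch indexes OC-major/IC-minor. A one-function sketch of just that inner indexing (illustrative, not MegDNN code), where oc_inner corresponds to gc_out.cur_off % 4 and ic_inner to (ic - ic0) % 4:

#include <cstddef>

// Offset of one element inside the innermost 4x4 block of a packed filter.
std::size_t inner_block_offset(std::size_t oc_inner, std::size_t ic_inner,
                               bool dot_layout) {
    return dot_layout ? oc_inner * 4 + ic_inner   // NCHW44_DOT: {..., 4(OC), 4(IC)}
                      : ic_inner * 4 + oc_inner;  // NCHW44:     {..., 4(IC), 4(OC)}
}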
...@@ -559,6 +574,7 @@ void forward(_megdnn_tensor_in src, const ftype* fptr, _megdnn_tensor_out dst,
filter_meta.format == param::Convolution::Format::NHWC ||
filter_meta.format == param::Convolution::Format::NCHW88 ||
filter_meta.format == param::Convolution::Format::NCHW44 ||
filter_meta.format == param::Convolution::Format::NCHW44_DOT ||
filter_meta.format == param::Convolution::Format::NCHW4);
compute2d<stype, ftype, dtype, comp_type, StrategyFwd>(
src, const_cast<ftype*>(fptr), dst, filter_meta);
...@@ -613,6 +629,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter,
case param::Convolution::Format::NCHW:
case param::Convolution::Format::NCHW88:
case param::Convolution::Format::NCHW44:
case param::Convolution::Format::NCHW44_DOT:
case param::Convolution::Format::NHWC:
case param::Convolution::Format::NCHW4:
case param::Convolution::Format::NCHW8:
...@@ -690,6 +707,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter,
} \
} while (0)
case Format::NCHW44:
case Format::NCHW44_DOT:
case Format::NCHW4: {
BIAS_ADD_NCHWx(4);
break;
...
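For reference, what BIAS_ADD_NCHWx(4) amounts to for these 4-packed formats is a per-channel broadcast of bias {1, OC/4, 1, 1, 4} over dst {N, OC/4, OH, OW, 4}; a minimal sketch, assuming contiguous float tensors and illustrative names:

#include <cstddef>

void bias_add_nchw4x(float* dst, const float* bias, std::size_t N, std::size_t OCB,
                     std::size_t OH, std::size_t OW) {
    for (std::size_t n = 0; n < N; ++n)
        for (std::size_t ocb = 0; ocb < OCB; ++ocb)
            for (std::size_t hw = 0; hw < OH * OW; ++hw)
                for (std::size_t v = 0; v < 4; ++v)
                    dst[((n * OCB + ocb) * OH * OW + hw) * 4 + v] +=
                            bias[ocb * 4 + v];
}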
...@@ -350,9 +350,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_SMALL_GROUP) {
get_conv_bias_args({1, 2, 3, 4, 5, 6, 7}, 1, false, false, false),
handle(), "F32DIRECT_SMALL_GROUP");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_1) {
check_conv_bias(get_nchw44_conv_bias_args({2, 7}, 1, false, false, false,
false, true, true),
handle(), "F32_CONV_NCHW44_DIRECT");
}
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_NCHW44_S1_2) {
check_conv_bias(get_nchw44_conv_bias_args({3, 5}, 1, false, false, false,
false, true, true),
handle(), "F32_CONV_NCHW44_DIRECT"); handle(), "F32_CONV_NCHW44_DIRECT");
} }
......
...@@ -516,4 +516,177 @@ TEST_F(NAIVE, CONV_BIAS_NCHW44) { ...@@ -516,4 +516,177 @@ TEST_F(NAIVE, CONV_BIAS_NCHW44) {
224, 268, 311, 218, 288, 311, 346, 277})}); 224, 268, 311, 218, 288, 311, 346, 277})});
} }
} }
TEST_F(NAIVE, CONV_BIAS_NCHW44_DOT) {
Checker<ConvBias> checker(handle(), /* check_dispatch */ false);
ConvBias::Param param;
param.format = ConvBias::Param::Format::NCHW44_DOT;
size_t n = 1;
size_t ic = 4;
size_t oc = 8;
size_t h = 2;
size_t w = 2;
size_t filter_size = 3;
size_t pad = 1;
auto src_tensor_shape = TensorShape{n, ic / 4, h, w, 4};
auto weight_tensor_shape =
TensorShape{oc / 4, ic / 4, filter_size, filter_size, 4, 4};
auto bias_tensor_shape = TensorShape{1, oc / 4, 1, 1, 4};
param.pad_h = pad;
param.pad_w = pad;
UniformIntRNG rng{-127, 127};
checker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(4, dtype::Float32())
.set_rng(0, &rng)
.set_rng(1, &rng)
.set_rng(2, &rng)
.set_epsilon(1e-3)
.set_param(param)
.execs({src_tensor_shape,
weight_tensor_shape,
bias_tensor_shape,
{},
{}});
checker.set_dtype(0, dtype::QuantizedS8(2.f))
.set_dtype(1, dtype::QuantizedS8(3.f))
.set_dtype(2, dtype::QuantizedS32(6.f))
.set_dtype(4, dtype::QuantizedS32(6.f))
.set_rng(0, &rng)
.set_rng(1, &rng)
.set_rng(2, &rng)
.set_epsilon(1e-3)
.set_param(param)
.execs({src_tensor_shape,
weight_tensor_shape,
bias_tensor_shape,
{},
{}});
{
// test normal conv
ConvBias::Param param;
param.format = ConvBias::Param::Format::NCHW44_DOT;
param.sparse = ConvBias::Param::Sparse::DENSE;
param.pad_h = 1;
param.pad_w = 1;
checker.set_param(param).exect(
Testcase{TensorValue({1, 1, 2, 2, 4}, dtype::Float32(),
{7, 2, 2, 1, 7, 5, 6, 3, 1, 2, 8, 3, 7, 7,
6, 4}),
TensorValue(
{1, 1, 3, 3, 4, 4}, dtype::Float32(),
{3, 0, 3, 1, 5, 1, 5, 7, 5, 4, 0, 0, 2, 8, 7,
7, 6, 5, 7, 3, 4, 2, 6, 2, 7, 2, 6, 2, 7, 4,
3, 8, 5, 0, 0, 7, 0, 5, 4, 7, 4, 1, 8, 2, 4,
0, 4, 0, 4, 6, 0, 1, 8, 2, 6, 4, 7, 3, 4, 3,
3, 0, 4, 8, 8, 2, 3, 7, 8, 5, 2, 0, 7, 5, 8,
2, 2, 1, 1, 7, 1, 0, 2, 4, 6, 6, 4, 2, 1, 3,
1, 7, 5, 0, 1, 5, 7, 5, 3, 0, 8, 7, 2, 1, 4,
0, 8, 4, 5, 3, 6, 6, 6, 2, 1, 5, 6, 4, 7, 2,
0, 4, 8, 8, 1, 1, 2, 3, 8, 6, 3, 1, 3, 3, 7,
1, 5, 4, 2, 1, 0, 3, 8, 4}),
TensorValue({1, 1, 1, 1, 4}, dtype::Float32(),
{7, 2, 8, 1}),
TensorValue({1, 1, 2, 2, 4}, dtype::Float32(),
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0}),
{}},
Testcase{
{},
{},
{},
{},
TensorValue({1, 1, 2, 2, 4}, dtype::Float32(),
{264, 338, 309, 195, 276, 332, 390, 199,
224, 268, 311, 218, 288, 311, 346, 277})});
}
{
// test dw conv
ConvBias::Param param;
param.format = ConvBias::Param::Format::NCHW44_DOT;
param.sparse = ConvBias::Param::Sparse::GROUP;
param.pad_h = 1;
param.pad_w = 1;
checker.set_param(param).exect(
Testcase{TensorValue({1, 1, 2, 2, 4}, dtype::Float32(),
{5, 8, 3, 2, 4, 6, 1, 5, 0, 8, 2, 6, 8, 6,
5, 7}),
TensorValue({1, 1, 1, 3, 3, 4}, dtype::Float32(),
{3, 0, 3, 1, 6, 5, 7, 3, 5, 0, 0, 7,
4, 6, 0, 1, 8, 2, 3, 7, 1, 0, 2, 4,
7, 5, 3, 0, 6, 2, 1, 5, 8, 6, 3, 1}),
TensorValue({1, 1, 1, 1, 4}, dtype::Float32(),
{4, 3, 5, 6}),
TensorValue({1, 1, 2, 2, 4}, dtype::Float32(),
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0}),
{}},
Testcase{{},
{},
{},
{},
TensorValue({1, 1, 2, 2, 4}, dtype::Float32(),
{112, 71, 33, 77, 104, 115, 19, 78, 62, 59,
42, 117, 107, 93, 36, 78})});
}
{
// test group conv
ConvBias::Param param;
param.format = ConvBias::Param::Format::NCHW44_DOT;
param.sparse = ConvBias::Param::Sparse::GROUP;
param.pad_h = 1;
param.pad_w = 1;
checker.set_param(param).exect(
Testcase{TensorValue({1, 2, 2, 2, 4}, dtype::Float32(),
{6, 3, 2, 7, 7, 6, 4, 5, 8, 6, 3,
1, 1, 2, 8, 3, 1, 0, 6, 1, 3, 3,
6, 0, 0, 5, 6, 7, 2, 2, 4, 4}),
TensorValue(
{2, 1, 1, 3, 3, 4, 4}, dtype::Float32(),
{3, 0, 3, 1, 5, 1, 5, 7, 5, 4, 0, 0, 2, 8, 7,
7, 6, 5, 7, 3, 4, 2, 6, 2, 7, 2, 6, 2, 7, 4,
3, 8, 5, 0, 0, 7, 0, 5, 4, 7, 4, 1, 8, 2, 4,
0, 4, 0, 4, 6, 0, 1, 8, 2, 6, 4, 7, 3, 4, 3,
3, 0, 4, 8, 8, 2, 3, 7, 8, 5, 2, 0, 7, 5, 8,
2, 2, 1, 1, 7, 1, 0, 2, 4, 6, 6, 4, 2, 1, 3,
1, 7, 5, 0, 1, 5, 7, 5, 3, 0, 8, 7, 2, 1, 4,
0, 8, 4, 5, 3, 6, 6, 6, 2, 1, 5, 6, 4, 7, 2,
0, 4, 8, 8, 1, 1, 2, 3, 8, 6, 3, 1, 3, 3, 7,
1, 5, 4, 2, 1, 0, 3, 8, 4, 7, 6, 8, 3, 4, 8,
1, 0, 5, 7, 3, 0, 0, 4, 5, 3, 7, 8, 1, 3, 7,
1, 1, 0, 7, 2, 2, 0, 3, 0, 1, 1, 1, 6, 4, 0,
3, 3, 1, 2, 0, 0, 4, 1, 5, 5, 7, 6, 7, 1, 3,
5, 8, 6, 2, 1, 0, 7, 7, 1, 2, 6, 6, 1, 2, 3,
1, 2, 4, 8, 3, 2, 6, 0, 7, 4, 3, 7, 3, 3, 5,
3, 0, 3, 5, 1, 4, 5, 6, 2, 0, 5, 3, 3, 3, 5,
2, 4, 7, 1, 3, 5, 2, 8, 1, 8, 1, 2, 5, 1, 0,
6, 7, 7, 8, 7, 8, 8, 1, 8, 4, 4, 1, 4, 4, 5,
0, 2, 2, 2, 0, 1, 8, 4, 4, 7, 6, 8, 0, 1, 5,
4, 2, 6}),
TensorValue({1, 2, 1, 1, 4}, dtype::Float32(),
{1, 8, 5, 6, 2, 8, 7, 7}),
TensorValue({1, 2, 2, 2, 4}, dtype::Float32(),
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0}),
{}},
Testcase{
{},
{},
{},
{},
TensorValue({1, 2, 2, 2, 4}, dtype::Float32(),
{260, 342, 244, 241, 293, 385, 362, 257,
278, 301, 303, 226, 273, 306, 318, 307,
180, 244, 169, 156, 210, 244, 206, 167,
126, 165, 156, 207, 191, 141, 209, 172})});
}
}
// vim: syntax=cpp.doxygen