提交 3de1fa5b 编写于 作者: M Megvii Engine Team

feat(arm/dnn): add support for nchw_nchw44 filter 2

GitOrigin-RevId: 013242911ec2e3acdf2e67676e3f4a821a9a5eb0
上级 f3547242
...@@ -207,7 +207,7 @@ bool ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::usable( ...@@ -207,7 +207,7 @@ bool ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::usable(
(fm.format == param::Convolution::Format::NCHW44); (fm.format == param::Convolution::Format::NCHW44);
bool ok_src_dst = fm.icpg < 4 && (oc % 4 == 0 && oc >= 4) && fm.group == 1; bool ok_src_dst = fm.icpg < 4 && (oc % 4 == 0 && oc >= 4) && fm.group == 1;
bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] &&
(fh == 3 || fh == 5 || fh == 7); (fh == 2 || fh == 3 || fh == 5 || fh == 7);
bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
fm.stride[0] == 2 && fm.stride[1] == 2; fm.stride[0] == 2 && fm.stride[1] == 2;
bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS; bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS;
...@@ -267,6 +267,9 @@ ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::dispatch_kerns( ...@@ -267,6 +267,9 @@ ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::dispatch_kerns(
#define DISPATCH_CONV_KERN() \ #define DISPATCH_CONV_KERN() \
switch (param.filter_meta.spatial[0]) { \ switch (param.filter_meta.spatial[0]) { \
case 2: \
GET_BIAS_MODE_PARAM(2) \
break; \
case 3: \ case 3: \
GET_BIAS_MODE_PARAM(3) \ GET_BIAS_MODE_PARAM(3) \
break; \ break; \
......
...@@ -207,24 +207,27 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block> { ...@@ -207,24 +207,27 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block> {
float32x4_t src[src_reg_size]; float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size]; float32x4_t weight[c_dim][filter_size];
// row 0 // row 0
load_helper<5, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr,
load_helper<3, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, 0);
ld_weight_oc); load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight);
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight);
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight);
// row 1 // row 1
load_helper<5, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + iw, 0); load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
load_helper<3, 0, oc_step, c_dim, Vld1q_f32>( src, src_ptr + iw, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight);
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight);
cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight);
// row 2 // row 2
load_helper<5, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + 2 * iw, 0); load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
load_helper<3, 0, oc_step, c_dim, Vld1q_f32>( src, src_ptr + 2 * iw, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc); weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight);
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight);
...@@ -238,6 +241,52 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block> { ...@@ -238,6 +241,52 @@ struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 3, oc_block> {
} }
}; };
template <BiasMode bias_mode, typename Op, int remain_w, int oc_block>
struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, remain_w, 2, oc_block> {
static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
int ih, int iw, int ld_dst_oc, const Op& op) {
constexpr int loop_ic_step = 1;
constexpr int filter_size = 2;
constexpr int oc_step = 4;
constexpr int simd_len = 4;
constexpr int src_reg_size = 4;
constexpr int ld_weight_fw = oc_step * filter_size;
const int ld_weight_oc = oc_step * filter_size * filter_size * ic;
const int ld_weight_ic = oc_step * filter_size * filter_size;
const int ld_src_ic = ih * iw;
constexpr int c_dim = OCHelper<oc_block>::val;
float32x4_t c[c_dim][8];
init_ocx_ow8<c_dim, bias_mode>(c, bias_ptr, oc_step);
for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
float32x4_t src[src_reg_size];
float32x4_t weight[c_dim][filter_size];
// row 0
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(src, src_ptr,
0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight);
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight);
// row 1
load_helper<src_reg_size, 0, simd_len, 0, Vld1q_f32>(
src, src_ptr + iw, 0);
load_helper<filter_size, 0, oc_step, c_dim, Vld1q_f32>(
weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc);
cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight);
cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight);
src_ptr += ld_src_ic;
weight_ptr += ld_weight_ic;
}
store_ocx_ow8_remain_static<c_dim, remain_w, Op>(c, op, dst_ptr,
ld_dst_oc);
}
};
} // namespace } // namespace
void conv_bias::pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, void conv_bias::pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr,
...@@ -383,19 +432,12 @@ static void conv_direct_stride2_fp32_nchw_nchw44( ...@@ -383,19 +432,12 @@ static void conv_direct_stride2_fp32_nchw_nchw44(
ow, op, ph, pw); \ ow, op, ph, pw); \
} }
CONSTRUCT_FUNC(2);
CONSTRUCT_FUNC(3); CONSTRUCT_FUNC(3);
CONSTRUCT_FUNC(5); CONSTRUCT_FUNC(5);
CONSTRUCT_FUNC(7); CONSTRUCT_FUNC(7);
#undef CONSTRUCT_FUNC #undef CONSTRUCT_FUNC
template <BiasMode bias_mode, typename Op>
void conv_bias::conv_direct_stride2_2x2_fp32_nchw_nchw44(
const float32_t*, const float32_t*, const float32_t*, float32_t*,
float32_t*, const int, const int, const int, const int, const int,
const int, const int, const Op&, const int, const int) {
megdnn_assert(0, "not imple nchw_nchw44 2x2s2 conv");
}
#define INSTANTIATION(stride, i, bias, Op) \ #define INSTANTIATION(stride, i, bias, Op) \
template void conv_bias:: \ template void conv_bias:: \
conv_direct_##stride##_##i##x##i##_fp32_nchw_nchw44<bias, Op>( \ conv_direct_##stride##_##i##x##i##_fp32_nchw_nchw44<bias, Op>( \
......
...@@ -195,6 +195,7 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) { ...@@ -195,6 +195,7 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) {
}; };
if (is_fp32) { if (is_fp32) {
run(1, 1, 4, 112, 112, 2, 2, true);
run(1, 3, 32, 224, 224, 3, 2, true); run(1, 3, 32, 224, 224, 3, 2, true);
run(1, 3, 64, 224, 224, 7, 2, true); run(1, 3, 64, 224, 224, 7, 2, true);
} else { } else {
...@@ -1806,12 +1807,12 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_WINOGRAD_VS_IM2COL_INT8) { ...@@ -1806,12 +1807,12 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_WINOGRAD_VS_IM2COL_INT8) {
auto opr = handle()->create_operator<ConvBias>(); auto opr = handle()->create_operator<ConvBias>();
opr->param() = arg.param; opr->param() = arg.param;
opr->deduce_layout({arg.src, dtype::Float32()}, opr->deduce_layout({arg.src, dtype::Float32()},
{arg.filter, dtype::Float32()}, {arg.filter, dtype::Float32()},
{arg.bias, dtype::Float32()}, {}, dst_layout); {arg.bias, dtype::Float32()}, {}, dst_layout);
//! dst.nr_elems * IC * FH * FW * 2 //! dst.nr_elems * IC * FH * FW * 2
float computations = dst_layout.total_nr_elems() * arg.filter[1] * float computations = dst_layout.total_nr_elems() * arg.filter[1] *
arg.filter[2] * arg.filter[3] * 2.0 / arg.filter[2] * arg.filter[3] * 2.0 /
(1024 * 1024 * 1024) * 1e3; (1024 * 1024 * 1024) * 1e3;
benchmark_im2col.set_param(arg.param); benchmark_im2col.set_param(arg.param);
auto im2col_used = auto im2col_used =
...@@ -1828,11 +1829,11 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_WINOGRAD_VS_IM2COL_INT8) { ...@@ -1828,11 +1829,11 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_WINOGRAD_VS_IM2COL_INT8) {
RUN; RUN;
printf("%s %s: im2col: %f ms %f Gflops winograd: %f ms %f GFlops " printf("%s %s: im2col: %f ms %f Gflops winograd: %f ms %f GFlops "
"speedup: " "speedup: "
"%f\n", "%f\n",
arg.src.to_string().c_str(), arg.filter.to_string().c_str(), arg.src.to_string().c_str(), arg.filter.to_string().c_str(),
im2col_used, computations / im2col_used, winograd_used, im2col_used, computations / im2col_used, winograd_used,
computations / winograd_used, im2col_used / winograd_used); computations / winograd_used, im2col_used / winograd_used);
} }
} }
......
...@@ -342,9 +342,9 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) { ...@@ -342,9 +342,9 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) {
handle(), "F32STRD2_SMALL_GROUP"); handle(), "F32STRD2_SMALL_GROUP");
} }
TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32) { TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32) {
check_conv_bias( check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false,
get_nchw44_conv_bias_args({3, 5, 7}, 2, false, false, false, true), false, true),
handle(), "F32_CONV_NCHW_NCHW44"); handle(), "F32_CONV_NCHW_NCHW44");
} }
/**********************************F16 direct************************/ /**********************************F16 direct************************/
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册