diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp index 8499fef2bc3a877a4af27dfaa850c705c3933e09..f6635cd01bb4031df05b26ead49fcaafc8f91875 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_algo.cpp @@ -207,7 +207,7 @@ bool ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::usable( (fm.format == param::Convolution::Format::NCHW44); bool ok_src_dst = fm.icpg < 4 && (oc % 4 == 0 && oc >= 4) && fm.group == 1; bool ok_filter = fm.spatial_ndim == 2 && fh == fm.spatial[1] && - (fh == 3 || fh == 5 || fh == 7); + (fh == 2 || fh == 3 || fh == 5 || fh == 7); bool ok_slide = fm.dilation[0] == 1 && fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2; bool ok_conv = !fm.should_flip && param.bias_mode != BiasMode::BIAS; @@ -267,6 +267,9 @@ ConvBiasImpl::AlgoF32DirectStride2NCHWNCHW44::dispatch_kerns( #define DISPATCH_CONV_KERN() \ switch (param.filter_meta.spatial[0]) { \ + case 2: \ + GET_BIAS_MODE_PARAM(2) \ + break; \ case 3: \ GET_BIAS_MODE_PARAM(3) \ break; \ diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp index fc9aca9f94a759af896e677e599b66a7cde5a88d..ba3b62721472864046a9a8f5b5196d2b8d146e4e 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_stride2_nchw_nchw44_kern.cpp @@ -207,24 +207,27 @@ struct KerNeonXXs2NchwNchw44FP32 { float32x4_t src[src_reg_size]; float32x4_t weight[c_dim][filter_size]; // row 0 - load_helper<5, 0, simd_len, 0, Vld1q_f32>(src, src_ptr, 0); - load_helper<3, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, - ld_weight_oc); + load_helper(src, src_ptr, + 0); + load_helper( + weight, weight_ptr, ld_weight_oc); cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); // row 1 - load_helper<5, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + iw, 0); - load_helper<3, 0, oc_step, c_dim, Vld1q_f32>( + load_helper( + src, src_ptr + iw, 0); + load_helper( weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); cal_helper<2, 2, c_dim, Vfmaq_laneq_f32>(c, src, weight); // row 2 - load_helper<5, 0, simd_len, 0, Vld1q_f32>(src, src_ptr + 2 * iw, 0); - load_helper<3, 0, oc_step, c_dim, Vld1q_f32>( + load_helper( + src, src_ptr + 2 * iw, 0); + load_helper( weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc); cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); @@ -238,6 +241,52 @@ struct KerNeonXXs2NchwNchw44FP32 { } }; +template +struct KerNeonXXs2NchwNchw44FP32 { + static void impl(const float32_t* src_ptr, const float32_t* weight_ptr, + const float32_t* bias_ptr, float32_t* dst_ptr, int ic, + int ih, int iw, int ld_dst_oc, const Op& op) { + constexpr int loop_ic_step = 1; + constexpr int filter_size = 2; + constexpr int oc_step = 4; + constexpr int simd_len = 4; + constexpr int src_reg_size = 4; + + constexpr int ld_weight_fw = oc_step * filter_size; + const int ld_weight_oc = oc_step * filter_size * filter_size * ic; + const int ld_weight_ic = oc_step * filter_size * filter_size; + const int ld_src_ic = ih * iw; + constexpr int c_dim = OCHelper::val; + float32x4_t c[c_dim][8]; + init_ocx_ow8(c, bias_ptr, oc_step); + + for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { + float32x4_t src[src_reg_size]; + float32x4_t weight[c_dim][filter_size]; + // row 0 + load_helper(src, src_ptr, + 0); + load_helper( + weight, weight_ptr, ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); + cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); + + // row 1 + load_helper( + src, src_ptr + iw, 0); + load_helper( + weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); + cal_helper<0, 0, c_dim, Vfmaq_laneq_f32>(c, src, weight); + cal_helper<1, 1, c_dim, Vfmaq_laneq_f32>(c, src, weight); + + src_ptr += ld_src_ic; + weight_ptr += ld_weight_ic; + } + store_ocx_ow8_remain_static(c, op, dst_ptr, + ld_dst_oc); + } +}; + } // namespace void conv_bias::pack_weight_fp32_nchw_nchw44(const float32_t* in_ptr, @@ -383,19 +432,12 @@ static void conv_direct_stride2_fp32_nchw_nchw44( ow, op, ph, pw); \ } +CONSTRUCT_FUNC(2); CONSTRUCT_FUNC(3); CONSTRUCT_FUNC(5); CONSTRUCT_FUNC(7); #undef CONSTRUCT_FUNC -template -void conv_bias::conv_direct_stride2_2x2_fp32_nchw_nchw44( - const float32_t*, const float32_t*, const float32_t*, float32_t*, - float32_t*, const int, const int, const int, const int, const int, - const int, const int, const Op&, const int, const int) { - megdnn_assert(0, "not imple nchw_nchw44 2x2s2 conv"); -} - #define INSTANTIATION(stride, i, bias, Op) \ template void conv_bias:: \ conv_direct_##stride##_##i##x##i##_fp32_nchw_nchw44( \ diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index 11badbbc5db0ba12d60a65dc702ce5290248d3b2..6289fff717c2fe0918820cabb1052083a6743c4a 100644 --- a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -195,6 +195,7 @@ static void benchmark_convbias(Handle* handle, bool is_fp32 = false) { }; if (is_fp32) { + run(1, 1, 4, 112, 112, 2, 2, true); run(1, 3, 32, 224, 224, 3, 2, true); run(1, 3, 64, 224, 224, 7, 2, true); } else { @@ -1806,12 +1807,12 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_WINOGRAD_VS_IM2COL_INT8) { auto opr = handle()->create_operator(); opr->param() = arg.param; opr->deduce_layout({arg.src, dtype::Float32()}, - {arg.filter, dtype::Float32()}, - {arg.bias, dtype::Float32()}, {}, dst_layout); + {arg.filter, dtype::Float32()}, + {arg.bias, dtype::Float32()}, {}, dst_layout); //! dst.nr_elems * IC * FH * FW * 2 float computations = dst_layout.total_nr_elems() * arg.filter[1] * - arg.filter[2] * arg.filter[3] * 2.0 / - (1024 * 1024 * 1024) * 1e3; + arg.filter[2] * arg.filter[3] * 2.0 / + (1024 * 1024 * 1024) * 1e3; benchmark_im2col.set_param(arg.param); auto im2col_used = @@ -1828,11 +1829,11 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_WINOGRAD_VS_IM2COL_INT8) { RUN; printf("%s %s: im2col: %f ms %f Gflops winograd: %f ms %f GFlops " - "speedup: " - "%f\n", - arg.src.to_string().c_str(), arg.filter.to_string().c_str(), - im2col_used, computations / im2col_used, winograd_used, - computations / winograd_used, im2col_used / winograd_used); + "speedup: " + "%f\n", + arg.src.to_string().c_str(), arg.filter.to_string().c_str(), + im2col_used, computations / im2col_used, winograd_used, + computations / winograd_used, im2col_used / winograd_used); } } diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp index 7740c329a28f5458b67cdde443fbf0119b4d1ac1..b11d3df9d1cb71c90a64cdf22218c98ab370fa1f 100644 --- a/dnn/test/arm_common/conv_bias_multi_thread.cpp +++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp @@ -342,9 +342,9 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_DIRECT_FP32_STR2_SMALL_GROUP) { handle(), "F32STRD2_SMALL_GROUP"); } TEST_F(ARM_COMMON_MULTI_THREADS, CONVBIAS_NCHW_NCHW44_F32) { - check_conv_bias( - get_nchw44_conv_bias_args({3, 5, 7}, 2, false, false, false, true), - handle(), "F32_CONV_NCHW_NCHW44"); + check_conv_bias(get_nchw44_conv_bias_args({2, 3, 5, 7}, 2, false, false, + false, true), + handle(), "F32_CONV_NCHW_NCHW44"); } /**********************************F16 direct************************/ #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC