diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp b/dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp index 5b00b5dd0c5f77c249e1e9a0c5bdd1a27234ee66..89350047b77d6209f6d156e756a3e1876cb6e902 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp @@ -32,29 +32,37 @@ constexpr size_t pack_size = 4; struct InputTransformF23_NCHW44 { template - static void prepare(const float* input, float* patch, float* patchT, - int ih_start, int iw_start, size_t IH, size_t IW, - size_t ic, size_t IC) { - MEGDNN_MARK_USED_VAR(patch); + static void transform(float* patchT, const float* input, + float* input_transform_buf, size_t ih_start, + size_t iw_start, size_t IH, size_t IW, + size_t unit_idx, size_t nr_units_in_tile, size_t ic, + size_t IC) { size_t IW4 = IW * pack_size; - size_t iw4_start = iw_start * pack_size; size_t icb = ic / pack_size; + size_t iw4_start = iw_start * pack_size; + size_t ICB = IC / pack_size; + +#define cb(m, n) Vector d##m##n; + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); +#undef cb + if (!(inner && ic + pack_size < IC)) { memset(patchT, 0, sizeof(float) * pack_size * alpha * alpha); } if (inner) { + MEGDNN_MARK_USED_VAR(patchT); const float* input_ptr = input + icb * IH * IW4 + ih_start * IW4 + iw4_start; - for (size_t ih = 0; ih < alpha; ih++) { -#define cb(i) auto v##i = vld1q_f32(input_ptr + pack_size * i); - UNROLL_CALL_NOWRAPPER(4, cb); -#undef cb - -#define cb(i) vst1q_f32(patchT + ih * alpha * pack_size + i * pack_size, v##i); - UNROLL_CALL_NOWRAPPER(4, cb); +#define cb(n, m) d##m##n = Vector::load(input_ptr + pack_size * n); + + UNROLL_CALL_RAW(4, cb, 0); + input_ptr += IW4; + UNROLL_CALL_RAW(4, cb, 1); + input_ptr += IW4; + UNROLL_CALL_RAW(4, cb, 2); + input_ptr += IW4; + UNROLL_CALL_RAW(4, cb, 3); #undef cb - input_ptr += IW4; - } } else { int ih0_act = std::max(ih_start, 0), ih1_act = std::min(ih_start + alpha, IH), @@ -71,19 +79,12 @@ struct InputTransformF23_NCHW44 { src); } } - } - } - - static void transform(const float* patchT, float* input_transform_buf, - size_t unit_idx, size_t nr_units_in_tile, size_t ic, - size_t IC) { - // BT * d * B -#define cb(m, n) \ - Vector d##m##n = Vector::load( \ - patchT + m * alpha * pack_size + n * pack_size); - UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); +#define cb(m, n) \ + d##m##n = Vector::load(patchT + m * alpha * pack_size + \ + n * pack_size); + UNROLL_CALL_NOWRAPPER_D2(4, 4, cb); #undef cb - + } //! 1 0 -1 0 d00 d01 d02 d03 1 0 0 0 //! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1 //! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0 @@ -106,8 +107,6 @@ struct InputTransformF23_NCHW44 { UNROLL_CALL_NOWRAPPER(4, cb); #undef cb - size_t ICB = IC / 4; - size_t icb = ic / 4; #define cb(m, n) \ d##m##n.save(input_transform_buf + \ (m * alpha + n) * ICB * nr_units_in_tile * pack_size + \ @@ -273,7 +272,6 @@ void winograd_F23_mk4_f_nchw44::input(const float* input, // OW = IW + 2 * PW - KERNEL_SIZE + 1 auto units_w = div_ceil(IW + 2 * PW - KERNEL_SIZE + 1, OUTPUT_BLOCK_SIZE); - float* patch = transform_mid_buf; float* patchT = transform_mid_buf + 4 * alpha * alpha; for (size_t ic = 0; ic < IC; ic += 4) { @@ -285,20 +283,13 @@ void winograd_F23_mk4_f_nchw44::input(const float* input, int iw_start = nw * OUTPUT_BLOCK_SIZE - PW; if (ih_start >= 0 && ih_start + alpha <= static_cast(IH) && iw_start >= 0 && iw_start + alpha <= static_cast(IW)) { - InputTransformF23_NCHW44::prepare(input, patch, patchT, - ih_start, iw_start, IH, - IW, ic, IC); - InputTransformF23_NCHW44::transform(patchT, input_transform_buf, - unit_idx, nr_units_in_tile, - ic, IC); - + InputTransformF23_NCHW44::transform( + patchT, input, input_transform_buf, ih_start, iw_start, + IH, IW, unit_idx, nr_units_in_tile, ic, IC); } else { - InputTransformF23_NCHW44::prepare(input, patch, patchT, - ih_start, iw_start, IH, - IW, ic, IC); - InputTransformF23_NCHW44::transform(patchT, input_transform_buf, - unit_idx, nr_units_in_tile, - ic, IC); + InputTransformF23_NCHW44::transform( + patchT, input, input_transform_buf, ih_start, iw_start, + IH, IW, unit_idx, nr_units_in_tile, ic, IC); } } } @@ -311,9 +302,21 @@ void winograd_F23_mk4_f_nchw44::output(const float* output_transform_buf, size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, size_t nr_units_in_tile) { -#define cb(_bmode, _nonline_op, ...) \ - OutputTransformF23_NCHW44<_bmode MEGDNN_COMMA _nonline_op>::transform( \ - __VA_ARGS__); +#define cb(_bmode, _nonline_op, ...) \ + for (size_t oc = oc_start; oc < oc_end; oc += 4) { \ + size_t oc_index = oc - oc_start; \ + rep(unit_idx, nr_units_in_tile) { \ + size_t index = unit_start_idx + unit_idx; \ + auto nh = index / units_w; \ + auto nw = index % units_w; \ + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; \ + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; \ + OutputTransformF23_NCHW44<_bmode, _nonline_op>::transform( \ + output_transform_buf, bias, output, transform_mid_buf, \ + oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, \ + unit_idx, nr_units_in_tile, src_dtype, dst_dtype); \ + } \ + } auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); constexpr size_t pack_size = 4; @@ -323,22 +326,8 @@ void winograd_F23_mk4_f_nchw44::output(const float* output_transform_buf, oc_end % pack_size == 0, "NCHW44 Winograd filter transform requires OC is times of 4"); - for (size_t oc = oc_start; oc < oc_end; oc += 4) { - size_t oc_index = oc - oc_start; - rep(unit_idx, nr_units_in_tile) { - size_t index = unit_start_idx + unit_idx; - auto nh = index / units_w; - auto nw = index % units_w; - size_t oh_start = nh * OUTPUT_BLOCK_SIZE; - size_t ow_start = nw * OUTPUT_BLOCK_SIZE; - DISPATCH_CONV_WINOGRAD_BIAS( - megdnn_arm_common_winograd_nchw44_fp32_F23_mk4, cb, float, - float, bmode, nonline_mode, output_transform_buf, bias, - output, transform_mid_buf, oh_start, ow_start, OH, OW, - oc_start, oc_end, oc_index, unit_idx, nr_units_in_tile, - src_dtype, dst_dtype); - } - } + DISPATCH_CONV_WINOGRAD_BIAS(megdnn_arm_common_winograd_nchw44_fp32_F23_mk4, + cb, float, float, bmode, nonline_mode); #undef cb } diff --git a/dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp b/dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp index df5aa713ef7680e3ce0df37968536a559751f751..b861a23dcdfb82662bcdca04c0277ad3ceda5dcd 100644 --- a/dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp @@ -31,6 +31,8 @@ namespace { constexpr size_t alpha = 6 + 3 - 1; constexpr size_t pack_size = 4; +constexpr float input_parameters[12] = {5.25f, 4.25f, 0.5f, 0.25f, 2.5f, 1.25f, + 2.0f, 4.0f, 5.0f, 0.0f, 0.0f, 0.0f}; struct InputTransformF63_NCHW44 { template @@ -80,12 +82,14 @@ struct InputTransformF63_NCHW44 { size_t unit_idx, size_t nr_units_in_tile, size_t ic, size_t IC) { // BT * d * B -#define cb(m, n) \ - Vector d##m##n = Vector::load( \ - patchT + m * alpha * pack_size + n * pack_size); - UNROLL_CALL_NOWRAPPER_D2(8, 8, cb); -#undef cb + size_t ICB = IC / pack_size; + size_t icb = ic / pack_size; + + float32x4_t d0, d1, d2, d3, d4, d5, d6, d7; + float32x4_t v0 = vld1q_f32(input_parameters + 0); + float32x4_t v1 = vld1q_f32(input_parameters + 4); + float32x4_t v2 = vld1q_f32(input_parameters + 8); //! B //! 1 0 0 0 0 0 0 0 @@ -96,49 +100,147 @@ struct InputTransformF63_NCHW44 { //! 0 1 -1 2 -2 0.5 -0.5 -5.25 //! -1 1 1 1 1 1 1 0 //! 0 0 0 0 0 0 0 1 -#define cb(m) \ - auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m; \ - auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f; \ - auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f; \ - auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f + \ - d5##m * 2.f + d6##m; \ - auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f - \ - d4##m * 1.25f - d5##m * 2.f + d6##m; \ - auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f + \ - d5##m * 0.5f + d6##m; \ - auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f - \ - d5##m * 0.5f + d6##m; \ - auto t7##m = (d7##m - d1##m) + (d3##m - d5##m) * 5.25f; - UNROLL_CALL_NOWRAPPER(8, cb); -#undef cb - -#define cb(m) \ - d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6; \ - d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 - \ - (t##m##3 + t##m##4) * 4.25f; \ - d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) + \ - (t##m##3 - t##m##4) * 4.25f; \ - d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f - \ - t##m##4 * 1.25f + t##m##5 * 2.f + t##m##6; \ - d##m##4 = t##m##1 * (-0.5f) + t##m##2 * 0.25f + t##m##3 * 2.5f - \ - t##m##4 * 1.25f - t##m##5 * 2.f + t##m##6; \ - d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f + \ - t##m##5 * 0.5f + t##m##6; \ - d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f - \ - t##m##4 * 5.f - t##m##5 * 0.5f + t##m##6; \ - d##m##7 = (t##m##7 - t##m##1) + (t##m##3 - t##m##5) * 5.25f; - - UNROLL_CALL_NOWRAPPER(8, cb); +#define cb(i) \ + d1 = vld1q_f32(patchT + i * alpha * pack_size + 1 * pack_size); \ + d2 = vld1q_f32(patchT + i * alpha * pack_size + 2 * pack_size); \ + d3 = vld1q_f32(patchT + i * alpha * pack_size + 3 * pack_size); \ + d4 = vld1q_f32(patchT + i * alpha * pack_size + 4 * pack_size); \ + d5 = vld1q_f32(patchT + i * alpha * pack_size + 5 * pack_size); \ + d6 = vld1q_f32(patchT + i * alpha * pack_size + 6 * pack_size); \ + auto t##i##0 = vld1q_f32(patchT + i * alpha * pack_size + 0 * pack_size); \ + auto t##i##7 = vld1q_f32(patchT + i * alpha * pack_size + 7 * pack_size); \ + auto t##i##1 = d6; \ + auto t##i##2 = d6; \ + auto t##i##3 = d6; \ + auto t##i##4 = d6; \ + auto t##i##5 = d6; \ + auto t##i##6 = d6; \ + t##i##0 = t##i##0 - d6; \ + t##i##1 = t##i##1 + d1; \ + t##i##2 = t##i##2 - d1; \ + t##i##3 = vfmaq_laneq_f32(t##i##3, d1, v0, 2); \ + t##i##4 = vfmsq_laneq_f32(t##i##4, d1, v0, 2); \ + t##i##5 = vfmaq_laneq_f32(t##i##5, d1, v1, 2); \ + t##i##6 = vfmsq_laneq_f32(t##i##6, d1, v1, 2); \ + t##i##7 = t##i##7 - d1; \ + t##i##0 = vfmsq_laneq_f32(t##i##0, d2, v0, 0); \ + t##i##1 = t##i##1 + d2; \ + t##i##2 = t##i##2 + d2; \ + t##i##3 = vfmaq_laneq_f32(t##i##3, d2, v0, 3); \ + t##i##4 = vfmaq_laneq_f32(t##i##4, d2, v0, 3); \ + t##i##5 = vfmaq_laneq_f32(t##i##5, d2, v1, 3); \ + t##i##6 = vfmaq_laneq_f32(t##i##6, d2, v1, 3); \ + t##i##1 = vfmsq_laneq_f32(t##i##1, d3, v0, 1); \ + t##i##2 = vfmaq_laneq_f32(t##i##2, d3, v0, 1); \ + t##i##3 = vfmsq_laneq_f32(t##i##3, d3, v1, 0); \ + t##i##4 = vfmaq_laneq_f32(t##i##4, d3, v1, 0); \ + t##i##5 = vfmsq_laneq_f32(t##i##5, d3, v1, 0); \ + t##i##6 = vfmaq_laneq_f32(t##i##6, d3, v1, 0); \ + t##i##7 = vfmaq_laneq_f32(t##i##7, d3, v0, 0); \ + t##i##0 = vfmaq_laneq_f32(t##i##0, d4, v0, 0); \ + t##i##1 = vfmsq_laneq_f32(t##i##1, d4, v0, 1); \ + t##i##2 = vfmsq_laneq_f32(t##i##2, d4, v0, 1); \ + t##i##3 = vfmsq_laneq_f32(t##i##3, d4, v1, 1); \ + t##i##4 = vfmsq_laneq_f32(t##i##4, d4, v1, 1); \ + t##i##5 = vfmsq_laneq_f32(t##i##5, d4, v2, 0); \ + t##i##6 = vfmsq_laneq_f32(t##i##6, d4, v2, 0); \ + t##i##1 = t##i##1 + d5; \ + t##i##2 = t##i##2 - d5; \ + t##i##3 = vfmaq_laneq_f32(t##i##3, d5, v1, 2); \ + t##i##4 = vfmsq_laneq_f32(t##i##4, d5, v1, 2); \ + t##i##5 = vfmaq_laneq_f32(t##i##5, d5, v0, 2); \ + t##i##6 = vfmsq_laneq_f32(t##i##6, d5, v0, 2); \ + t##i##7 = vfmsq_laneq_f32(t##i##7, d5, v0, 0); + UNROLL_CALL_RAW(8, cb); #undef cb - size_t ICB = IC / pack_size; - size_t icb = ic / pack_size; -#define cb(m, n) \ - d##m##n.save(input_transform_buf + \ - (m * alpha + n) * ICB * nr_units_in_tile * pack_size + \ - icb * nr_units_in_tile * pack_size + unit_idx * pack_size); - UNROLL_CALL_NOWRAPPER_D2(8, 8, cb) +#define cb(i) \ + d0 = t0##i; \ + d1 = t6##i; \ + d2 = t6##i; \ + d3 = t6##i; \ + d4 = t6##i; \ + d5 = t6##i; \ + d6 = t6##i; \ + d7 = t7##i; \ + d0 = d0 - t6##i; \ + d1 = d1 + t1##i; \ + d2 = d2 - t1##i; \ + d3 = vfmaq_laneq_f32(d3, t1##i, v0, 2); \ + d4 = vfmsq_laneq_f32(d4, t1##i, v0, 2); \ + d5 = vfmaq_laneq_f32(d5, t1##i, v1, 2); \ + d6 = vfmsq_laneq_f32(d6, t1##i, v1, 2); \ + d7 = d7 - t1##i; \ + d0 = vfmsq_laneq_f32(d0, t2##i, v0, 0); \ + d1 = d1 + t2##i; \ + d2 = d2 + t2##i; \ + d3 = vfmaq_laneq_f32(d3, t2##i, v0, 3); \ + d4 = vfmaq_laneq_f32(d4, t2##i, v0, 3); \ + d5 = vfmaq_laneq_f32(d5, t2##i, v1, 3); \ + d6 = vfmaq_laneq_f32(d6, t2##i, v1, 3); \ + d1 = vfmsq_laneq_f32(d1, t3##i, v0, 1); \ + d2 = vfmaq_laneq_f32(d2, t3##i, v0, 1); \ + d3 = vfmsq_laneq_f32(d3, t3##i, v1, 0); \ + d4 = vfmaq_laneq_f32(d4, t3##i, v1, 0); \ + d5 = vfmsq_laneq_f32(d5, t3##i, v1, 0); \ + d6 = vfmaq_laneq_f32(d6, t3##i, v1, 0); \ + d7 = vfmaq_laneq_f32(d7, t3##i, v0, 0); \ + d0 = vfmaq_laneq_f32(d0, t4##i, v0, 0); \ + d1 = vfmsq_laneq_f32(d1, t4##i, v0, 1); \ + d2 = vfmsq_laneq_f32(d2, t4##i, v0, 1); \ + d3 = vfmsq_laneq_f32(d3, t4##i, v1, 1); \ + d4 = vfmsq_laneq_f32(d4, t4##i, v1, 1); \ + d5 = vfmsq_laneq_f32(d5, t4##i, v2, 0); \ + d6 = vfmsq_laneq_f32(d6, t4##i, v2, 0); \ + d1 = d1 + t5##i; \ + d2 = d2 - t5##i; \ + d3 = vfmaq_laneq_f32(d3, t5##i, v1, 2); \ + d4 = vfmsq_laneq_f32(d4, t5##i, v1, 2); \ + d5 = vfmaq_laneq_f32(d5, t5##i, v0, 2); \ + d6 = vfmsq_laneq_f32(d6, t5##i, v0, 2); \ + d7 = vfmsq_laneq_f32(d7, t5##i, v0, 0); \ + vst1q_f32(input_transform_buf + \ + (0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + \ + unit_idx * pack_size, \ + d0); \ + vst1q_f32(input_transform_buf + \ + (1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + \ + unit_idx * pack_size, \ + d1); \ + vst1q_f32(input_transform_buf + \ + (2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + \ + unit_idx * pack_size, \ + d2); \ + vst1q_f32(input_transform_buf + \ + (3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + \ + unit_idx * pack_size, \ + d3); \ + vst1q_f32(input_transform_buf + \ + (4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + \ + unit_idx * pack_size, \ + d4); \ + vst1q_f32(input_transform_buf + \ + (5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + \ + unit_idx * pack_size, \ + d5); \ + vst1q_f32(input_transform_buf + \ + (6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + \ + unit_idx * pack_size, \ + d6); \ + vst1q_f32(input_transform_buf + \ + (7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \ + icb * nr_units_in_tile * pack_size + \ + unit_idx * pack_size, \ + d7); + UNROLL_CALL_RAW(8, cb); #undef cb } }; @@ -178,7 +280,7 @@ struct OutputTransformF63_NCHW44 { * 1 -2 4 -8 16 -32 * 1 0.5 0.25 0.125 0.0625 0.03125 * 1 -0.5 0.25 -0.125 0.0625 -0.03125 - * 0 0.0 0 0 0 1 + * 0 0 0 0 0 1 */ Vector v1addv2, v1subv2, v3addv4, v3subv4, v5addv6, v5subv6; @@ -378,28 +480,33 @@ void winograd_F63_mk4_f_nchw44::output(const float* output_transform_buf, size_t OW, size_t oc_start, size_t oc_end, size_t unit_start_idx, size_t nr_units_in_tile) { - constexpr size_t pack_size = 4; -#define cb(_bmode, _nonline_op, ...) \ - OutputTransformF63_NCHW44<_bmode MEGDNN_COMMA _nonline_op>::transform( \ - __VA_ARGS__); +#define cb(_bmode, _nonline_op, ...) \ + for (size_t oc = oc_start; oc < oc_end; oc += pack_size) { \ + size_t oc_index = oc - oc_start; \ + rep(unit_idx, nr_units_in_tile) { \ + size_t index = unit_start_idx + unit_idx; \ + auto nh = index / units_w; \ + auto nw = index % units_w; \ + size_t oh_start = nh * OUTPUT_BLOCK_SIZE; \ + size_t ow_start = nw * OUTPUT_BLOCK_SIZE; \ + OutputTransformF63_NCHW44<_bmode MEGDNN_COMMA _nonline_op>:: \ + transform(output_transform_buf, bias, output, \ + transform_mid_buf, oh_start, ow_start, OH, OW, \ + oc_start, oc_end, oc_index, unit_idx, \ + nr_units_in_tile, src_dtype, dst_dtype); \ + } \ + } auto units_w = div_ceil(OW, OUTPUT_BLOCK_SIZE); - for (size_t oc = oc_start; oc < oc_end; oc += pack_size) { - size_t oc_index = oc - oc_start; - rep(unit_idx, nr_units_in_tile) { - size_t index = unit_start_idx + unit_idx; - auto nh = index / units_w; - auto nw = index % units_w; - size_t oh_start = nh * OUTPUT_BLOCK_SIZE; - size_t ow_start = nw * OUTPUT_BLOCK_SIZE; - DISPATCH_CONV_WINOGRAD_BIAS( - megdnn_arm_common_winograd_fp32_F63_mk4, cb, float, float, - bmode, nonline_mode, output_transform_buf, bias, output, - transform_mid_buf, oh_start, ow_start, OH, OW, oc_start, - oc_end, oc_index, unit_idx, nr_units_in_tile, src_dtype, - dst_dtype); - } - } + constexpr size_t pack_size = 4; + + size_t OC = oc_end - oc_start; + megdnn_assert(OC % pack_size == 0 && oc_start % pack_size == 0 && + oc_end % pack_size == 0, + "NCHW44 Winograd filter transform requires OC is times of 4"); + + DISPATCH_CONV_WINOGRAD_BIAS(megdnn_arm_common_winograd_fp32_F63_mk4, cb, + float, float, bmode, nonline_mode); #undef cb } diff --git a/dnn/src/arm_common/simd_macro/marm_neon.h b/dnn/src/arm_common/simd_macro/marm_neon.h index 9deffc4d24c56031ee0acb84aa0e53dfdf5c578e..b07d87c0a4b89245d4837f058da83e062ce357f8 100644 --- a/dnn/src/arm_common/simd_macro/marm_neon.h +++ b/dnn/src/arm_common/simd_macro/marm_neon.h @@ -538,10 +538,43 @@ struct Vfmaq_laneq_f32_armv7<3> { return vmlaq_lane_f32(a, b, vget_high_f32(v), 1); } }; + +template +struct Vfmsq_laneq_f32_armv7 { + __ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v); +}; + +template <> +struct Vfmsq_laneq_f32_armv7<0> { + __ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { + return vmlsq_lane_f32(a, b, vget_low_f32(v), 0); + } +}; +template <> +struct Vfmsq_laneq_f32_armv7<1> { + __ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { + return vmlsq_lane_f32(a, b, vget_low_f32(v), 1); + } +}; +template <> +struct Vfmsq_laneq_f32_armv7<2> { + __ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { + return vmlsq_lane_f32(a, b, vget_high_f32(v), 0); + } +}; +template <> +struct Vfmsq_laneq_f32_armv7<3> { + __ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { + return vmlsq_lane_f32(a, b, vget_high_f32(v), 1); + } +}; } // namespace #define vfmaq_laneq_f32(a, b, v, lane) \ Vfmaq_laneq_f32_armv7::impl(a, b, v) +#define vfmsq_laneq_f32(a, b, v, lane) \ + Vfmsq_laneq_f32_armv7::impl(a, b, v) + #if __ARM_FEATURE_DOTPROD namespace { template @@ -582,7 +615,6 @@ struct Vdotq_laneq_s32_armv7<3> { //! GCC split fmla with lane to dup+fmla when version < 9 //! https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101 -#if !defined(__clang__) && __GNUC__ < 9 #if MEGDNN_AARCH64 namespace { @@ -630,13 +662,59 @@ struct Vfmaq_laneq_f32_armv8<3> { return a; } }; + +template +struct Vfmsq_laneq_f32_armv8 { + __ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v); +}; +template <> +struct Vfmsq_laneq_f32_armv8<0> { + __ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { + asm volatile("fmls %0.4s, %1.4s, %2.s[0]\n" + : "+w"(a) + : "w"(b), "w"(v) + :); + return a; + } +}; +template <> +struct Vfmsq_laneq_f32_armv8<1> { + __ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { + asm volatile("fmls %0.4s, %1.4s, %2.s[1]\n" + : "+w"(a) + : "w"(b), "w"(v) + :); + return a; + } +}; +template <> +struct Vfmsq_laneq_f32_armv8<2> { + __ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { + asm volatile("fmls %0.4s, %1.4s, %2.s[2]\n" + : "+w"(a) + : "w"(b), "w"(v) + :); + return a; + } +}; +template <> +struct Vfmsq_laneq_f32_armv8<3> { + __ai float32x4_t impl(float32x4_t a, float32x4_t b, float32x4_t v) { + asm volatile("fmls %0.4s, %1.4s, %2.s[3]\n" + : "+w"(a) + : "w"(b), "w"(v) + :); + return a; + } +}; } // namespace #undef vfmaq_laneq_f32 #define vfmaq_laneq_f32(a, b, v, lane) \ Vfmaq_laneq_f32_armv8::impl(a, b, v) -#endif - +#undef vfmsq_laneq_f32 +#define vfmsq_laneq_f32(a, b, v, lane) \ + Vfmsq_laneq_f32_armv8::impl(a, b, v) #endif __ai int8x16_t vld_dup_tbl_s32(const int8_t* ptr, uint8x16_t& idx) { @@ -678,6 +756,16 @@ __ai int16x8_t vld1_dup_s8_s16(const int8_t* ptr) { return vmovl_s8(vld1_dup_s8(ptr)); } +//! we add this because we found that cpu=aarch64_android cann't compile fmsq into fmls. +//! it use dup+fmla instead +__ai float32x4_t Vfmsq_f32(float32x4_t& a, float32x4_t& b, float32x4_t& v) { + asm volatile("fmls %0.4s, %1.4s, %2.4s\n" + : "+w"(a) + : "w"(b), "w"(v) + :); + return a; +} + #undef __ai #pragma GCC diagnostic pop diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp index 79f8a55dab1e458ef451f87f82705721ed1e2526..f89bbe23caf4cb28c6a08c7c9e5ad5afd30f6554 100644 --- a/dnn/test/arm_common/conv_bias.cpp +++ b/dnn/test/arm_common/conv_bias.cpp @@ -791,8 +791,8 @@ void benchmark_winograd_nchw_vs_nchw44(const char* algo_name, Handle* handle) { std::vector nonlinemode = {NLMode::IDENTITY}; for (auto nlmode : nonlinemode) - for (size_t n : {1, 2}) - for (size_t group = 1; group <= 2; ++group) { + for (size_t n : {1}) + for (size_t group = 1; group <= 1; ++group) { pack(n, 512, 512, 15, 15, group, nlmode); pack(n, 512, 256, 15, 15, group, nlmode); pack(n, 256, 256, 29, 29, group, nlmode);