diff --git a/lite/backends/arm/math/pooling.cc b/lite/backends/arm/math/pooling.cc index 8303851ece9dd2f1d053f9f4b888e42f2fdc0aad..f9bf52d00a7f91b78996bed132a7b91c69fa46c7 100644 --- a/lite/backends/arm/math/pooling.cc +++ b/lite/backends/arm/math/pooling.cc @@ -2235,6 +2235,7 @@ void pooling3x3s2p1_max(const float* din, int cnt_num = w_unroll_size; if (w_unroll_size > 0) { #ifdef __aarch64__ +#if 0 asm volatile( /* preocess left */ P3x3S2_INIT P3x3S2P1_MAX P3x3S2P0_MAX "2: \n" /* end */ @@ -2259,6 +2260,81 @@ void pooling3x3s2p1_max(const float* din, "v10", "v11", "v31"); +#else + + float32x4_t vr0_1234 = vld1q_f32(dr0); + float32x4_t vr0_5678 = vld1q_f32(dr0 += 4); + float32x4_t vr0_9101112 = vld1q_f32(dr0 += 4); + float32x4_t vr1_1234 = vld1q_f32(dr1); + float32x4_t vr1_5678 = vld1q_f32(dr1 += 4); + float32x4_t vr1_9101112 = vld1q_f32(dr1 += 4); + float32x4_t vr2_1234 = vld1q_f32(dr2); + float32x4_t vr2_5678 = vld1q_f32(dr2 += 4); + float32x4_t vr2_9101112 = vld1q_f32(dr2 += 4); + + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678); + float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); + vmax_9101112 = vmaxq_f32(vmax_9101112, vr2_9101112); + float32x4_t vmax_0123 = vextq_f32(vmin, vmax_1234, 1); + float32x4_t vmax_4567 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_01_23 = + vpmax_f32(vget_low_f32(vmax_0123), vget_high_f32(vmax_0123)); + float32x2_t vmax_56_78 = + vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); + float32x2_t vmax_45_67 = + vpmax_f32(vget_low_f32(vmax_4567), vget_high_f32(vmax_4567)); + float32x2_t vmax_012_234 = vmax_f32(vmax_12_34, vmax_01_23); + float32x2_t vmax_456_678 = vmax_f32(vmax_56_78, vmax_45_67); + vst1_f32(dr_out, vmax_012_234); + dr_out += 2; + vst1_f32(dr_out, vmax_456_678); + dr_out += 2; + cnt_num--; + dr0--; + dr1--; + dr2--; + while (cnt_num--) { + float32x4_t vr0_1234 = vld1q_f32(dr0); + float32x4_t vr0_5678 = vld1q_f32(dr0 += 4); + float32x4_t vr0_9101112 = vld1q_f32(dr0 += 4); + float32x4_t vr1_1234 = vld1q_f32(dr1); + float32x4_t vr1_5678 = vld1q_f32(dr1 += 4); + float32x4_t vr1_9101112 = vld1q_f32(dr1 += 4); + float32x4_t vr2_1234 = vld1q_f32(dr2); + float32x4_t vr2_5678 = vld1q_f32(dr2 += 4); + float32x4_t vr2_9101112 = vld1q_f32(dr2 += 4); + float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234); + vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234); + float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678); + vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678); + float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112); + vmax_9101112 = vmaxq_f32(vmax_9101112, vr2_9101112); + float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1); + float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1); + float32x2_t vmax_12_34 = + vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234)); + float32x2_t vmax_23_45 = + vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345)); + float32x2_t vmax_56_78 = + vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678)); + float32x2_t vmax_67_89 = + vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789)); + float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45); + float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89); + vst1_f32(dr_out, vmax_123_345); + dr_out += 2; + vst1_f32(dr_out, vmax_567_789); + dr_out += 2; + } + dr0 += 8; + dr1 += 8; + dr2 += 8; +#endif #else asm volatile( /* preocess left */