diff --git a/lite/backends/arm/math/pooling.cc b/lite/backends/arm/math/pooling.cc index aff4e56124319016b7014f874e5281b61526e0a9..c3652217ededa10b57e211ba7f5d3dc76e235978 100644 --- a/lite/backends/arm/math/pooling.cc +++ b/lite/backends/arm/math/pooling.cc @@ -1132,6 +1132,11 @@ void pooling2x2s2p0_max(const float* din, float* dr_out = data_out_channel; auto dr0 = r0; auto dr1 = r1; + if (h * S + K - P > hin + 1) { + memset(dr_out, 0.f, sizeof(float) * wout); + data_out_channel += wout; + continue; + } if (h * S + K - P > hin) { dr1 = r0; } @@ -1164,7 +1169,7 @@ void pooling2x2s2p0_max(const float* din, int wstart = 0; for (int j = 0; j < w_unroll_remian; ++j) { int wend = std::min(wstart + K, rem); - float tmp = dr0[wstart]; + float tmp = wstart < rem ? dr0[wstart] : 0.f; for (int i = wstart; i < wend; i++) { tmp = std::max(tmp, dr0[i]); tmp = std::max(tmp, dr1[i]); @@ -1222,9 +1227,20 @@ void pooling2x2s2p0_avg(const float* din, float* dr_out = data_out_channel; auto dr0 = r0; auto dr1 = r1; + if (h * S + K - P > hin + 1) { + memset(dr_out, 0.f, sizeof(float) * wout); + data_out_channel += wout; + continue; + } if (h * S + K - P > hin) { dr1 = zero_ptr; - vcoef = vdupq_n_f32(0.5f); + if (exclusive) { + vcoef = vdupq_n_f32(0.5f); + } else { + if (pad_bottom == 0) { + vcoef = vdupq_n_f32(0.5f); + } + } } int cnt_num = w_unroll_size; if (w_unroll_size > 0) { @@ -1257,11 +1273,20 @@ void pooling2x2s2p0_avg(const float* din, int wend = std::min(wstart + K, rem); float coef = 0.25f; float tmp = 0.f; - if (wend - wstart == 1 && pad_right == 0) { - coef *= 2; - } - if (h * S + K - P > hin && pad_bottom == 0) { - coef *= 2; + if (exclusive) { + if (wend - wstart == 1) { + coef *= 2; + } + if (h * S + K - P > hin) { + coef *= 2; + } + } else { + if (wend - wstart == 1 && pad_right == 0) { + coef *= 2; + } + if (h * S + K - P > hin && pad_bottom == 0) { + coef *= 2; + } } for (int i = wstart; i < wend; i++) { tmp += dr0[i] + dr1[i]; @@ -1442,7 +1467,7 @@ void pooling2x2s2p1_avg(const float* din, } if (h * S + K - P > hin) { dr1 = zero_ptr; - if (exclusive) { + if (exclusive || pad_bottom == 0) { coef_h = 1.f; } if (h * S + K - P > hin + 1) { @@ -1490,8 +1515,14 @@ void pooling2x2s2p1_avg(const float* din, int st = wstart > 0 ? wstart : 0; float tmp = 0.f; float coef = coef_h / 2; - if (exclusive && wend - st == 1) { - coef = coef_h; + if (exclusive) { + if (wend - st == 1) { + coef = coef_h; + } + } else { + if (wend - st == 1 && wstart > 0 && pad_right == 0) { + coef = coef_h; + } } for (int i = 0; i < wend - st; i++) { tmp += dr0[i] + dr1[i]; @@ -2193,13 +2224,7 @@ void pooling3x3s2p1_max(const float* din, w_unroll_size -= 1; w_unroll_remian = wout - w_unroll_size * 4; } - int w_needed = wout * 2 + 1; - int pad_right_ = w_needed - win - pad_bottom; - int w_2 = pad_right_ > 0 ? w_unroll_remian : w_unroll_remian + 1; - w_2 = w_unroll_size <= 0 ? w_2 - 1 : w_2; - - float minval = std::numeric_limits::lowest(); - float32x4_t vmin = vdupq_n_f32(minval); + float32x4_t vmin = vdupq_n_f32(std::numeric_limits::lowest()); for (int n = 0; n < num; ++n) { float* data_out_batch = data_out + n * chout * size_channel_out; @@ -2238,11 +2263,6 @@ void pooling3x3s2p1_max(const float* din, break; } } - - auto pr0 = dr0; - auto pr1 = dr1; - auto pr2 = dr2; - int cnt_num = w_unroll_size; if (w_unroll_size > 0) { #ifdef __aarch64__ @@ -2296,53 +2316,27 @@ void pooling3x3s2p1_max(const float* din, "q11", "q15"); #endif - dr0 -= 8; dr1 -= 8; dr2 -= 8; - } else { - float tmp = minval; - for (int i = 0; i < 2; i++) { - tmp = std::max(tmp, std::max(dr0[i], dr1[i])); - tmp = std::max(tmp, dr2[i]); - } - - dr_out[0] = tmp; - dr0++; - dr1++; - dr2++; - dr_out++; - } - - for (int w = 0; w < w_2 - 1; w += 1) { - float32x4_t vr0 = vld1q_f32(dr0); - float32x4_t vr1 = vld1q_f32(dr1); - float32x4_t vr2 = vld1q_f32(dr2); - vr0 = vsetq_lane_f32(minval, vr0, 3); - vr1 = vsetq_lane_f32(minval, vr1, 3); - vr2 = vsetq_lane_f32(minval, vr2, 3); - float32x4_t vmax1 = vmaxq_f32(vr0, vr1); - vmax1 = vmaxq_f32(vmax1, vr2); - float32x2_t vmax2 = - vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); - float32x2_t vmax = vpmax_f32(vmax2, vmax2); - dr_out[0] = vget_lane_f32(vmax, 0); - dr_out++; - - dr0 += 2; - dr1 += 2; - dr2 += 2; } - - if (pad_right_) { - float tmp = minval; - for (int i = 1; i < 3; i++) { - tmp = std::max(tmp, std::max(pr0[win - i], pr1[win - i])); - tmp = std::max(tmp, pr2[win - i]); + // deal with right pad + int wstart = w_unroll_size * 4 * S - P; + for (int j = 0; j < w_unroll_remian; ++j) { + int wend = std::min(wstart + K, win); + int st = wstart > 0 ? wstart : 0; + float tmp = dr0[0]; + for (int i = 0; i < wend - st; i++) { + tmp = std::max(tmp, dr0[i]); + tmp = std::max(tmp, dr1[i]); + tmp = std::max(tmp, dr2[i]); } - dr_out[0] = tmp; + *(dr_out++) = tmp; + dr0 += S - (st - wstart); + dr1 += S - (st - wstart); + dr2 += S - (st - wstart); + wstart += S; } - data_out_channel += wout; } } @@ -2575,11 +2569,10 @@ void pooling3x3s2p0_max(const float* din, int remain = w_unroll_remian - 1; int right = wout * 2 + 1 - win; // if need right pad - - int w_2 = right > 0 ? w_unroll_remian : w_unroll_remian + 1; - w_2 = w_unroll_size <= 0 ? w_2 - 1 : w_2; + int tmp_val = (w_unroll_size * 4 + remain) * S; + int wend = std::min(tmp_val + K, win) - tmp_val; float minval = std::numeric_limits::lowest(); - + remain = right > 0 ? remain : remain + 1; for (int n = 0; n < num; ++n) { float* data_out_batch = data_out + n * chout * size_channel_out; const float* data_in_batch = data_in + n * chin * size_channel_in; @@ -2630,88 +2623,59 @@ void pooling3x3s2p0_max(const float* din, "v9", "v10", "v11"); +#else + asm volatile(P3x3S2P0_INIT P3x3S2P0_MAX + : [dr0] "+r"(dr0), + [dr1] "+r"(dr1), + [dr2] "+r"(dr2), + [dr_out] "+r"(dr_out), + [cnt_num] "+r"(cnt_num) + : + : "cc", + "memory", + "q0", + "q1", + "q2", + "q3", + "q4", + "q5", + "q6", + "q7", + "q8", + "q9", + "q10", + "q11"); +#endif dr0 -= 8; dr1 -= 8; dr2 -= 8; - - for (int w = 0; w < w_2 - 1; w += 1) { - float32x4_t vr0 = vld1q_f32(dr0); - float32x4_t vr1 = vld1q_f32(dr1); - float32x4_t vr2 = vld1q_f32(dr2); - vr0 = vsetq_lane_f32(minval, vr0, 3); - vr1 = vsetq_lane_f32(minval, vr1, 3); - vr2 = vsetq_lane_f32(minval, vr2, 3); - float32x4_t vmax1 = vmaxq_f32(vr0, vr1); - vmax1 = vmaxq_f32(vmax1, vr2); - float32x2_t vmax2 = - vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); - float32x2_t vmax = vpmax_f32(vmax2, vmax2); - dr_out[0] = vget_lane_f32(vmax, 0); - dr_out++; - dr0 += 2; - dr1 += 2; - dr2 += 2; - } -#else - asm volatile( - P3x3S2P0_INIT P3x3S2P0_MAX - "cmp %[remain], #0 @cmp cnt_num\n" - "sub %[dr0], #32 @sub - 8\n" - "sub %[dr1], #32 @sub - 8\n" - "sub %[dr2], #32 @sub - 8\n" - "ble 4f @ble exit1\n" - "2: @mid loop\n" - "vld1.f32 {d0-d1}, [%[dr0]]! @load \n" - "vld1.f32 {d2-d3}, [%[dr1]]! @load \n" - "vld1.f32 {d4-d5}, [%[dr2]]! @load \n" - "vmov.f32 s3,s2 @mov \n" - "vmov.f32 s7,s6 @mov \n" - "vmov.f32 s11,s10 @mov \n" - "vmax.f32 q0, q0, q1 @max n" - "sub %[dr0], #8 @add w \n" - "sub %[dr1], #8 @add w \n" - "sub %[dr2], #8 @add w \n" - "vmax.f32 q0, q0, q2 @max \n" - "vpmax.f32 d0, d0, d1 @pmax \n" - "vpmax.f32 d0, d0, d0 @pmax \n" - "subs %[remain], #1 @subs \n" - "vst1.f32 d0[0], [%[dr_out]]! @vst \n" - "bne 2b @bne \n" - "4: @exit\n" - : [dr0] "+r"(dr0), - [dr1] "+r"(dr1), - [dr2] "+r"(dr2), - [dr_out] "+r"(dr_out), - [remain] "+r"(cnt_remain), - [cnt_num] "+r"(cnt_num) - : - : "cc", - "memory", - "q0", - "q1", - "q2", - "q3", - "q4", - "q5", - "q6", - "q7", - "q8", - "q9", - "q10", - "q11"); - if (right) { - int wstart = (w_unroll_size * 4 + remain) * S; - int wend = std::min(wstart + K, win); - float tmp = dr0[wstart]; // std::numeric_limits::min(); - for (int i = wstart; i < wend; i++) { - tmp = std::max(tmp, std::max(dr0[i], dr1[i])); - tmp = std::max(tmp, dr2[i]); - } - *(dr_out++) = tmp; + } + for (int w = 0; w < remain; w += 1) { + float32x4_t vr0 = vld1q_f32(dr0); + float32x4_t vr1 = vld1q_f32(dr1); + float32x4_t vr2 = vld1q_f32(dr2); + vr0 = vsetq_lane_f32(minval, vr0, 3); + vr1 = vsetq_lane_f32(minval, vr1, 3); + vr2 = vsetq_lane_f32(minval, vr2, 3); + float32x4_t vmax1 = vmaxq_f32(vr0, vr1); + vmax1 = vmaxq_f32(vmax1, vr2); + float32x2_t vmax2 = + vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1)); + float32x2_t vmax = vpmax_f32(vmax2, vmax2); + dr_out[0] = vget_lane_f32(vmax, 0); + dr_out++; + dr0 += 2; + dr1 += 2; + dr2 += 2; + } + if (right) { + float tmp = dr0[0]; // std::numeric_limits::min(); + for (int i = 0; i < wend; i++) { + tmp = std::max(tmp, std::max(dr0[i], dr1[i])); + tmp = std::max(tmp, dr2[i]); } -#endif + *(dr_out++) = tmp; } - r0 = r2; r1 = r0 + win; r2 = r1 + win; @@ -2748,6 +2712,7 @@ void pooling3x3s2p0_avg(const float* din, w_unroll_size -= 1; w_unroll_remian = wout - w_unroll_size * 4; } + // do overflow process w_unroll_size -= 1; w_unroll_remian += 4; @@ -2861,6 +2826,7 @@ void pooling3x3s2p0_avg(const float* din, dr2 -= 8; } // deal with right pad + w_unroll_size = w_unroll_size < 0 ? 0 : w_unroll_size; int wstart = w_unroll_size * 4 * S - P; for (int j = 0; j < w_unroll_remian; ++j) { int wend = wstart + K; // std::min(wstart + K, win); diff --git a/lite/core/mir/fusion/conv_conv_fuser.cc b/lite/core/mir/fusion/conv_conv_fuser.cc index 267770e047a21035bda7cca4d4d54c48e5ffc89d..2e369774957b036428092f562622ade9b77ceb41 100644 --- a/lite/core/mir/fusion/conv_conv_fuser.cc +++ b/lite/core/mir/fusion/conv_conv_fuser.cc @@ -108,10 +108,12 @@ void ConvConvFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { if (!(kw == 1 && kh == 1)) { LOG(FATAL) << "The kernel size of the second conv must be 1x1"; } + auto channel0_out = weight0_t->dims()[0]; + auto channel1_in = weight1_t->dims()[1] * groups1; CHECK_EQ(enable0_int8, enable1_int8) << "The Conv compute type must be same"; CHECK_EQ(groups1, 1) << "The groups of weight1_dim must be 1"; - CHECK_EQ(weight0_t->dims()[0], weight1_t->dims()[1]) - << "weight0_dims[0] == weight1_dim[1]"; + CHECK_EQ(channel0_out, channel1_in) << "channel0_out == channel1_in"; + for (int i = 0; i < strides1.size(); i++) { CHECK_EQ(strides1[i], 1) << "strides[" << i << "]: " << strides1[i] << " must be 1"; diff --git a/lite/kernels/arm/pool_compute.cc b/lite/kernels/arm/pool_compute.cc index 5cfca8f1b7d9a286d24dda5af5664aa381c8e0f1..ae993b52372305a252daeda4280edbce9b2965ce 100644 --- a/lite/kernels/arm/pool_compute.cc +++ b/lite/kernels/arm/pool_compute.cc @@ -57,7 +57,9 @@ void PoolCompute::Run() { (ksize[0] == ksize[1]) && (strides[0] == strides[1]) && pads_less; bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) && (ksize[1] == in_dims[3]) && kps_equal && pads_equal; + bool win_ksize = (in_dims[2] > ksize[0]) && (in_dims[3] > ksize[1]); global_pooling = param.global_pooling || global_pooling; + kps_equal = kps_equal && win_ksize; if (global_pooling) { for (size_t i = 0; i < ksize.size(); ++i) { diff --git a/lite/tests/math/pool_compute_test.cc b/lite/tests/math/pool_compute_test.cc index e0d4de61747d5772edd94f7ad66cfe99e8cf0457..890973a0dfbb47d3943e7c463ef6fc7092fd367f 100644 --- a/lite/tests/math/pool_compute_test.cc +++ b/lite/tests/math/pool_compute_test.cc @@ -435,7 +435,7 @@ TEST(TestPoolRand, test_pool_rand) { adaptive, use_quantizer, pooling_type, - {1, 2, 4}, + {4}, {FLAGS_power_mode}); } }