未验证 提交 f61c4676 编写于 作者: H HappyAngel 提交者: GitHub

[arm] fix 3x3s2p0 max compute error (#4121)

* fix 3x3s2p0 max pooling compute error when the input paddings are not equal

* fix format. test=develop
上级 f3c93688
...@@ -1132,6 +1132,11 @@ void pooling2x2s2p0_max(const float* din, ...@@ -1132,6 +1132,11 @@ void pooling2x2s2p0_max(const float* din,
float* dr_out = data_out_channel; float* dr_out = data_out_channel;
auto dr0 = r0; auto dr0 = r0;
auto dr1 = r1; auto dr1 = r1;
if (h * S + K - P > hin + 1) {
memset(dr_out, 0.f, sizeof(float) * wout);
data_out_channel += wout;
continue;
}
if (h * S + K - P > hin) { if (h * S + K - P > hin) {
dr1 = r0; dr1 = r0;
} }
...@@ -1164,7 +1169,7 @@ void pooling2x2s2p0_max(const float* din, ...@@ -1164,7 +1169,7 @@ void pooling2x2s2p0_max(const float* din,
int wstart = 0; int wstart = 0;
for (int j = 0; j < w_unroll_remian; ++j) { for (int j = 0; j < w_unroll_remian; ++j) {
int wend = std::min(wstart + K, rem); int wend = std::min(wstart + K, rem);
float tmp = dr0[wstart]; float tmp = wstart < rem ? dr0[wstart] : 0.f;
for (int i = wstart; i < wend; i++) { for (int i = wstart; i < wend; i++) {
tmp = std::max(tmp, dr0[i]); tmp = std::max(tmp, dr0[i]);
tmp = std::max(tmp, dr1[i]); tmp = std::max(tmp, dr1[i]);
...@@ -1222,10 +1227,21 @@ void pooling2x2s2p0_avg(const float* din, ...@@ -1222,10 +1227,21 @@ void pooling2x2s2p0_avg(const float* din,
float* dr_out = data_out_channel; float* dr_out = data_out_channel;
auto dr0 = r0; auto dr0 = r0;
auto dr1 = r1; auto dr1 = r1;
if (h * S + K - P > hin + 1) {
memset(dr_out, 0.f, sizeof(float) * wout);
data_out_channel += wout;
continue;
}
if (h * S + K - P > hin) { if (h * S + K - P > hin) {
dr1 = zero_ptr; dr1 = zero_ptr;
if (exclusive) {
vcoef = vdupq_n_f32(0.5f);
} else {
if (pad_bottom == 0) {
vcoef = vdupq_n_f32(0.5f); vcoef = vdupq_n_f32(0.5f);
} }
}
}
int cnt_num = w_unroll_size; int cnt_num = w_unroll_size;
if (w_unroll_size > 0) { if (w_unroll_size > 0) {
#ifdef __aarch64__ #ifdef __aarch64__
...@@ -1257,12 +1273,21 @@ void pooling2x2s2p0_avg(const float* din, ...@@ -1257,12 +1273,21 @@ void pooling2x2s2p0_avg(const float* din,
int wend = std::min(wstart + K, rem); int wend = std::min(wstart + K, rem);
float coef = 0.25f; float coef = 0.25f;
float tmp = 0.f; float tmp = 0.f;
if (exclusive) {
if (wend - wstart == 1) {
coef *= 2;
}
if (h * S + K - P > hin) {
coef *= 2;
}
} else {
if (wend - wstart == 1 && pad_right == 0) { if (wend - wstart == 1 && pad_right == 0) {
coef *= 2; coef *= 2;
} }
if (h * S + K - P > hin && pad_bottom == 0) { if (h * S + K - P > hin && pad_bottom == 0) {
coef *= 2; coef *= 2;
} }
}
for (int i = wstart; i < wend; i++) { for (int i = wstart; i < wend; i++) {
tmp += dr0[i] + dr1[i]; tmp += dr0[i] + dr1[i];
} }
...@@ -1442,7 +1467,7 @@ void pooling2x2s2p1_avg(const float* din, ...@@ -1442,7 +1467,7 @@ void pooling2x2s2p1_avg(const float* din,
} }
if (h * S + K - P > hin) { if (h * S + K - P > hin) {
dr1 = zero_ptr; dr1 = zero_ptr;
if (exclusive) { if (exclusive || pad_bottom == 0) {
coef_h = 1.f; coef_h = 1.f;
} }
if (h * S + K - P > hin + 1) { if (h * S + K - P > hin + 1) {
...@@ -1490,9 +1515,15 @@ void pooling2x2s2p1_avg(const float* din, ...@@ -1490,9 +1515,15 @@ void pooling2x2s2p1_avg(const float* din,
int st = wstart > 0 ? wstart : 0; int st = wstart > 0 ? wstart : 0;
float tmp = 0.f; float tmp = 0.f;
float coef = coef_h / 2; float coef = coef_h / 2;
if (exclusive && wend - st == 1) { if (exclusive) {
if (wend - st == 1) {
coef = coef_h; coef = coef_h;
} }
} else {
if (wend - st == 1 && wstart > 0 && pad_right == 0) {
coef = coef_h;
}
}
for (int i = 0; i < wend - st; i++) { for (int i = 0; i < wend - st; i++) {
tmp += dr0[i] + dr1[i]; tmp += dr0[i] + dr1[i];
} }
...@@ -2044,7 +2075,7 @@ void pooling3x3s1p0_avg(const float* din, ...@@ -2044,7 +2075,7 @@ void pooling3x3s1p0_avg(const float* din,
} else { } else {
if (pad_bottom > 1) { if (pad_bottom > 1) {
coef_h = 1.f / 3; coef_h = 1.f / 3;
} else if (pad_bottom = 1) { } else if (pad_bottom == 1) {
coef_h = 0.5f; coef_h = 0.5f;
} else { } else {
coef_h = 1.f; coef_h = 1.f;
...@@ -2538,7 +2569,10 @@ void pooling3x3s2p0_max(const float* din, ...@@ -2538,7 +2569,10 @@ void pooling3x3s2p0_max(const float* din,
int remain = w_unroll_remian - 1; int remain = w_unroll_remian - 1;
int right = wout * 2 + 1 - win; // if need right pad int right = wout * 2 + 1 - win; // if need right pad
int tmp_val = (w_unroll_size * 4 + remain) * S;
int wend = std::min(tmp_val + K, win) - tmp_val;
float minval = std::numeric_limits<float>::lowest();
remain = right > 0 ? remain : remain + 1;
for (int n = 0; n < num; ++n) { for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out; float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in; const float* data_in_batch = data_in + n * chin * size_channel_in;
...@@ -2589,53 +2623,12 @@ void pooling3x3s2p0_max(const float* din, ...@@ -2589,53 +2623,12 @@ void pooling3x3s2p0_max(const float* din,
"v9", "v9",
"v10", "v10",
"v11"); "v11");
dr0 -= 8;
dr1 -= 8;
dr2 -= 8;
int rem = win - (w_unroll_size * 4) * S;
int wstart = 0;
for (int j = 0; j < w_unroll_remian; ++j) {
int wend = std::min(wstart + K, rem);
float tmp = dr0[wstart]; // std::numeric_limits<float>::min();
for (int i = wstart; i < wend; i++) {
tmp = std::max(tmp, dr0[i]);
tmp = std::max(tmp, dr1[i]);
tmp = std::max(tmp, dr2[i]);
}
*(dr_out++) = tmp;
wstart += S;
}
#else #else
asm volatile( asm volatile(P3x3S2P0_INIT P3x3S2P0_MAX
P3x3S2P0_INIT P3x3S2P0_MAX
"cmp %[remain], #0 @cmp cnt_num\n"
"sub %[dr0], #32 @sub - 8\n"
"sub %[dr1], #32 @sub - 8\n"
"sub %[dr2], #32 @sub - 8\n"
"ble 4f @ble exit1\n"
"2: @mid loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load \n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load \n"
"vld1.f32 {d4-d5}, [%[dr2]]! @load \n"
"vmov.f32 s3,s2 @mov \n"
"vmov.f32 s7,s6 @mov \n"
"vmov.f32 s11,s10 @mov \n"
"vmax.f32 q0, q0, q1 @max n"
"sub %[dr0], #8 @add w \n"
"sub %[dr1], #8 @add w \n"
"sub %[dr2], #8 @add w \n"
"vmax.f32 q0, q0, q2 @max \n"
"vpmax.f32 d0, d0, d1 @pmax \n"
"vpmax.f32 d0, d0, d0 @pmax \n"
"subs %[remain], #1 @subs \n"
"vst1.f32 d0[0], [%[dr_out]]! @vst \n"
"bne 2b @bne \n"
"4: @exit\n"
: [dr0] "+r"(dr0), : [dr0] "+r"(dr0),
[dr1] "+r"(dr1), [dr1] "+r"(dr1),
[dr2] "+r"(dr2), [dr2] "+r"(dr2),
[dr_out] "+r"(dr_out), [dr_out] "+r"(dr_out),
[remain] "+r"(cnt_remain),
[cnt_num] "+r"(cnt_num) [cnt_num] "+r"(cnt_num)
: :
: "cc", : "cc",
...@@ -2652,19 +2645,37 @@ void pooling3x3s2p0_max(const float* din, ...@@ -2652,19 +2645,37 @@ void pooling3x3s2p0_max(const float* din,
"q9", "q9",
"q10", "q10",
"q11"); "q11");
#endif
dr0 -= 8;
dr1 -= 8;
dr2 -= 8;
}
for (int w = 0; w < remain; w += 1) {
float32x4_t vr0 = vld1q_f32(dr0);
float32x4_t vr1 = vld1q_f32(dr1);
float32x4_t vr2 = vld1q_f32(dr2);
vr0 = vsetq_lane_f32(minval, vr0, 3);
vr1 = vsetq_lane_f32(minval, vr1, 3);
vr2 = vsetq_lane_f32(minval, vr2, 3);
float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
vmax1 = vmaxq_f32(vmax1, vr2);
float32x2_t vmax2 =
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
float32x2_t vmax = vpmax_f32(vmax2, vmax2);
dr_out[0] = vget_lane_f32(vmax, 0);
dr_out++;
dr0 += 2;
dr1 += 2;
dr2 += 2;
}
if (right) { if (right) {
int wstart = (w_unroll_size * 4 + remain) * S; float tmp = dr0[0]; // std::numeric_limits<float>::min();
int wend = std::min(wstart + K, win); for (int i = 0; i < wend; i++) {
float tmp = dr0[wstart]; // std::numeric_limits<float>::min();
for (int i = wstart; i < wend; i++) {
tmp = std::max(tmp, std::max(dr0[i], dr1[i])); tmp = std::max(tmp, std::max(dr0[i], dr1[i]));
tmp = std::max(tmp, dr2[i]); tmp = std::max(tmp, dr2[i]);
} }
*(dr_out++) = tmp; *(dr_out++) = tmp;
} }
#endif
}
r0 = r2; r0 = r2;
r1 = r0 + win; r1 = r0 + win;
r2 = r1 + win; r2 = r1 + win;
...@@ -2815,6 +2826,7 @@ void pooling3x3s2p0_avg(const float* din, ...@@ -2815,6 +2826,7 @@ void pooling3x3s2p0_avg(const float* din,
dr2 -= 8; dr2 -= 8;
} }
// deal with right pad // deal with right pad
w_unroll_size = w_unroll_size < 0 ? 0 : w_unroll_size;
int wstart = w_unroll_size * 4 * S - P; int wstart = w_unroll_size * 4 * S - P;
for (int j = 0; j < w_unroll_remian; ++j) { for (int j = 0; j < w_unroll_remian; ++j) {
int wend = wstart + K; // std::min(wstart + K, win); int wend = wstart + K; // std::min(wstart + K, win);
......
...@@ -57,6 +57,8 @@ void PoolCompute::Run() { ...@@ -57,6 +57,8 @@ void PoolCompute::Run() {
(ksize[0] == ksize[1]) && (strides[0] == strides[1]) && pads_less; (ksize[0] == ksize[1]) && (strides[0] == strides[1]) && pads_less;
bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) && bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) &&
(ksize[1] == in_dims[3]) && kps_equal && pads_equal; (ksize[1] == in_dims[3]) && kps_equal && pads_equal;
bool win_ksize = (in_dims[2] > ksize[0]) && (in_dims[3] > ksize[1]);
kps_equal = kps_equal && win_ksize;
global_pooling = param.global_pooling || global_pooling; global_pooling = param.global_pooling || global_pooling;
if (global_pooling) { if (global_pooling) {
......
...@@ -435,7 +435,7 @@ TEST(TestPoolRand, test_pool_rand) { ...@@ -435,7 +435,7 @@ TEST(TestPoolRand, test_pool_rand) {
adaptive, adaptive,
use_quantizer, use_quantizer,
pooling_type, pooling_type,
{1, 2, 4}, {4},
{FLAGS_power_mode}); {FLAGS_power_mode});
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册