未验证 提交 8448ad6c 编写于 作者: H HappyAngel 提交者: GitHub

[arm] fix 3x3s2p0 max in un-equal padding compute error (#4122)

* fix pooling bug and speed

* fix build error

* delete VLOGin pool, test=develop

* add openmp, test=develop

* fix lite/kernels/arm/pool_compute_test basic_pooling compute error bug, test=develop

* update pooling 2-pad to 4-pad, test=develop

* fix 2-pad to 4-pad in operators/pool_op.h, AttachKernel will set param, so 2-pad to 4-pad funcs should put in AttachKernel. test=ddevellop

* put 2-pad to 4-pad in AttachImpl, test=develop

* according to reviews, fix some format error. test=develop

* fix format errorr, add (). test=develop

* change paddings type to support dynamically modify, test=develop

* update padding type int other devices, test=develop

* fix x8d build error on shared_ptr, test=ddevelop

* fix formmat in operators pool_op.cc, test=develop

* fix conflict

* fix conv_dw

* add relu

* fix build

* fix build

* fix compute

* fix compute error

* fix conv3x3 compute error

* fix conv3x3s2

* fix conv3x3s2  kl

* fix format, test=develop

* add some op infershape implement, test=develop

* add reshape infershape, test=develop

* fix format, test=develop

* fix format, test=develop

* fix space format. test=develop

* add conv_transpose+bn fusion. test=develop

* delete note, test=develop

* fix format, test=develop

* fix format space, test=develop

* fix opt run error, test=develop

* add boxcoder opencl kernel, test=develop

* fix format, test=develop

* add cmake, test=develop

* fix format. test=develop

* fix format. test=develop

* fix format aa. test=develop

* fix , test=develop

* update profile info(add new element), test=develop

* fix clang ut build error

* add gemm+relu6

* fix build error

* fix .h

* fix gemm_s8

* fix ut conv+leakyRelu

* improve 3x3s1 direct profile

* fix format, test=develop

* add gemv+relu6/lleakyRelu

* fix v7 build bug

* fix relu6 bug

* fix gemm ut bug

* fix ut

* fix ut

* fi format. test=develop

* fix format. test=develop

* fic format. test=develop

* ff. test=develop

* fix v7 clang build error, test=develop

* fix v7 build register error, test=develop

* fix format.  test=develop

* add 2x2s2p1 pooling. test=develop

* fix conflict, test=develop

* fix test=develop

* fix conflict, test=develop

* rm other info, test=develop

* rm other info, test=develop

* fix build register error, ttest=develop

* fix format, test=develop

* ff format,test=develop

* fix relu6 problem, test=develop

* fix form, test=develop

* fix format, test=develop

* ff, test=develop

* add deformable conv op

* add six / scale , test=develop

* fix 1x1 deformable conv

* add convparam to deformabelconv param.

* fix other conv kernel size

* delete printf info

* fix format, test=develop

* fix formatt. test=develop

* delete exttra info, test=develop

* test=develop

* ff, test=develop

* fix ut error. test=develop

* fix ut, test=develop

* fix format. test=develop

* test=develop

* fix pooling overflow, test=develop

* fix conflict test=develop

* delete unuseful message. test=develop

* add grouup_norm

* fix format. test=develop

* fix foormat, test=develop

* fix format. test=develop

* fix ff.test=develoop

* fix xiaodu crash. test=develop

* format. test=develop

* fix concatt axis < 0 errorr,ttest=develop

* fix format. test=develop

* fix conv int8 kernel choose and sooftmax compute bug

* change axis_size = 4 kernel choose, test=develop

* fix format. test=develop

* uupdate sequence_pool and sequence_conv profiler, test=develop

* uupdate sequence_pool and sequence_conv profiler, test=develop

* fix format, testt=develop

* fix format, test=develop

* fix format test=develop

* fix compute error. test=develop

* dd

* fix compute error

* fix compute error, test=develop

* delete warning and extra info, test=develop

* update sequence_conv profile

* delete extra file, test=develop

* delete extra file test=develop

* fix format test=develop

* d

* delete test=develop

* test=develop

* add elu act

* update build

* fix elu act not find error, test=develop

* fix format.test=develop

* test=develop

* add sequence_pool_grad op on arm

* test=develop

* fix build error

* fix int8 model opt error iin conv+conv fusion, test=develop

* fix format. test=develop

* fiix build error, test=develop

* fix format, test=develop

* fix build,test=develop

* test=develop

* fix arm winograd compute segment. test=develop

* fix ttfnet bug. test=develop

* fix format. test=develop

* fix format. test=develop

* fix compute error

* fix format, test=develop

* fix 3x3s2p0 max compute error when input_padding is not equal. test=develop

* fix format test=develop

* fix format test=develop

* fix compute error. test=develop

* test=develop

* test=develop

* fix compute. test=develop

* add coonv+conv fusion requirment. test=develop

* test=develop

* fix format. test=develop

* refresh coonv+conv. test=develop
上级 a6a65a52
...@@ -1132,6 +1132,11 @@ void pooling2x2s2p0_max(const float* din, ...@@ -1132,6 +1132,11 @@ void pooling2x2s2p0_max(const float* din,
float* dr_out = data_out_channel; float* dr_out = data_out_channel;
auto dr0 = r0; auto dr0 = r0;
auto dr1 = r1; auto dr1 = r1;
if (h * S + K - P > hin + 1) {
memset(dr_out, 0.f, sizeof(float) * wout);
data_out_channel += wout;
continue;
}
if (h * S + K - P > hin) { if (h * S + K - P > hin) {
dr1 = r0; dr1 = r0;
} }
...@@ -1164,7 +1169,7 @@ void pooling2x2s2p0_max(const float* din, ...@@ -1164,7 +1169,7 @@ void pooling2x2s2p0_max(const float* din,
int wstart = 0; int wstart = 0;
for (int j = 0; j < w_unroll_remian; ++j) { for (int j = 0; j < w_unroll_remian; ++j) {
int wend = std::min(wstart + K, rem); int wend = std::min(wstart + K, rem);
float tmp = dr0[wstart]; float tmp = wstart < rem ? dr0[wstart] : 0.f;
for (int i = wstart; i < wend; i++) { for (int i = wstart; i < wend; i++) {
tmp = std::max(tmp, dr0[i]); tmp = std::max(tmp, dr0[i]);
tmp = std::max(tmp, dr1[i]); tmp = std::max(tmp, dr1[i]);
...@@ -1222,9 +1227,20 @@ void pooling2x2s2p0_avg(const float* din, ...@@ -1222,9 +1227,20 @@ void pooling2x2s2p0_avg(const float* din,
float* dr_out = data_out_channel; float* dr_out = data_out_channel;
auto dr0 = r0; auto dr0 = r0;
auto dr1 = r1; auto dr1 = r1;
if (h * S + K - P > hin + 1) {
memset(dr_out, 0.f, sizeof(float) * wout);
data_out_channel += wout;
continue;
}
if (h * S + K - P > hin) { if (h * S + K - P > hin) {
dr1 = zero_ptr; dr1 = zero_ptr;
if (exclusive) {
vcoef = vdupq_n_f32(0.5f); vcoef = vdupq_n_f32(0.5f);
} else {
if (pad_bottom == 0) {
vcoef = vdupq_n_f32(0.5f);
}
}
} }
int cnt_num = w_unroll_size; int cnt_num = w_unroll_size;
if (w_unroll_size > 0) { if (w_unroll_size > 0) {
...@@ -1257,12 +1273,21 @@ void pooling2x2s2p0_avg(const float* din, ...@@ -1257,12 +1273,21 @@ void pooling2x2s2p0_avg(const float* din,
int wend = std::min(wstart + K, rem); int wend = std::min(wstart + K, rem);
float coef = 0.25f; float coef = 0.25f;
float tmp = 0.f; float tmp = 0.f;
if (exclusive) {
if (wend - wstart == 1) {
coef *= 2;
}
if (h * S + K - P > hin) {
coef *= 2;
}
} else {
if (wend - wstart == 1 && pad_right == 0) { if (wend - wstart == 1 && pad_right == 0) {
coef *= 2; coef *= 2;
} }
if (h * S + K - P > hin && pad_bottom == 0) { if (h * S + K - P > hin && pad_bottom == 0) {
coef *= 2; coef *= 2;
} }
}
for (int i = wstart; i < wend; i++) { for (int i = wstart; i < wend; i++) {
tmp += dr0[i] + dr1[i]; tmp += dr0[i] + dr1[i];
} }
...@@ -1442,7 +1467,7 @@ void pooling2x2s2p1_avg(const float* din, ...@@ -1442,7 +1467,7 @@ void pooling2x2s2p1_avg(const float* din,
} }
if (h * S + K - P > hin) { if (h * S + K - P > hin) {
dr1 = zero_ptr; dr1 = zero_ptr;
if (exclusive) { if (exclusive || pad_bottom == 0) {
coef_h = 1.f; coef_h = 1.f;
} }
if (h * S + K - P > hin + 1) { if (h * S + K - P > hin + 1) {
...@@ -1490,9 +1515,15 @@ void pooling2x2s2p1_avg(const float* din, ...@@ -1490,9 +1515,15 @@ void pooling2x2s2p1_avg(const float* din,
int st = wstart > 0 ? wstart : 0; int st = wstart > 0 ? wstart : 0;
float tmp = 0.f; float tmp = 0.f;
float coef = coef_h / 2; float coef = coef_h / 2;
if (exclusive && wend - st == 1) { if (exclusive) {
if (wend - st == 1) {
coef = coef_h;
}
} else {
if (wend - st == 1 && wstart > 0 && pad_right == 0) {
coef = coef_h; coef = coef_h;
} }
}
for (int i = 0; i < wend - st; i++) { for (int i = 0; i < wend - st; i++) {
tmp += dr0[i] + dr1[i]; tmp += dr0[i] + dr1[i];
} }
...@@ -2193,13 +2224,7 @@ void pooling3x3s2p1_max(const float* din, ...@@ -2193,13 +2224,7 @@ void pooling3x3s2p1_max(const float* din,
w_unroll_size -= 1; w_unroll_size -= 1;
w_unroll_remian = wout - w_unroll_size * 4; w_unroll_remian = wout - w_unroll_size * 4;
} }
int w_needed = wout * 2 + 1; float32x4_t vmin = vdupq_n_f32(std::numeric_limits<float>::lowest());
int pad_right_ = w_needed - win - pad_bottom;
int w_2 = pad_right_ > 0 ? w_unroll_remian : w_unroll_remian + 1;
w_2 = w_unroll_size <= 0 ? w_2 - 1 : w_2;
float minval = std::numeric_limits<float>::lowest();
float32x4_t vmin = vdupq_n_f32(minval);
for (int n = 0; n < num; ++n) { for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out; float* data_out_batch = data_out + n * chout * size_channel_out;
...@@ -2238,11 +2263,6 @@ void pooling3x3s2p1_max(const float* din, ...@@ -2238,11 +2263,6 @@ void pooling3x3s2p1_max(const float* din,
break; break;
} }
} }
auto pr0 = dr0;
auto pr1 = dr1;
auto pr2 = dr2;
int cnt_num = w_unroll_size; int cnt_num = w_unroll_size;
if (w_unroll_size > 0) { if (w_unroll_size > 0) {
#ifdef __aarch64__ #ifdef __aarch64__
...@@ -2296,53 +2316,27 @@ void pooling3x3s2p1_max(const float* din, ...@@ -2296,53 +2316,27 @@ void pooling3x3s2p1_max(const float* din,
"q11", "q11",
"q15"); "q15");
#endif #endif
dr0 -= 8; dr0 -= 8;
dr1 -= 8; dr1 -= 8;
dr2 -= 8; dr2 -= 8;
} else {
float tmp = minval;
for (int i = 0; i < 2; i++) {
tmp = std::max(tmp, std::max(dr0[i], dr1[i]));
tmp = std::max(tmp, dr2[i]);
}
dr_out[0] = tmp;
dr0++;
dr1++;
dr2++;
dr_out++;
}
for (int w = 0; w < w_2 - 1; w += 1) {
float32x4_t vr0 = vld1q_f32(dr0);
float32x4_t vr1 = vld1q_f32(dr1);
float32x4_t vr2 = vld1q_f32(dr2);
vr0 = vsetq_lane_f32(minval, vr0, 3);
vr1 = vsetq_lane_f32(minval, vr1, 3);
vr2 = vsetq_lane_f32(minval, vr2, 3);
float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
vmax1 = vmaxq_f32(vmax1, vr2);
float32x2_t vmax2 =
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
float32x2_t vmax = vpmax_f32(vmax2, vmax2);
dr_out[0] = vget_lane_f32(vmax, 0);
dr_out++;
dr0 += 2;
dr1 += 2;
dr2 += 2;
} }
// deal with right pad
if (pad_right_) { int wstart = w_unroll_size * 4 * S - P;
float tmp = minval; for (int j = 0; j < w_unroll_remian; ++j) {
for (int i = 1; i < 3; i++) { int wend = std::min(wstart + K, win);
tmp = std::max(tmp, std::max(pr0[win - i], pr1[win - i])); int st = wstart > 0 ? wstart : 0;
tmp = std::max(tmp, pr2[win - i]); float tmp = dr0[0];
for (int i = 0; i < wend - st; i++) {
tmp = std::max(tmp, dr0[i]);
tmp = std::max(tmp, dr1[i]);
tmp = std::max(tmp, dr2[i]);
} }
dr_out[0] = tmp; *(dr_out++) = tmp;
dr0 += S - (st - wstart);
dr1 += S - (st - wstart);
dr2 += S - (st - wstart);
wstart += S;
} }
data_out_channel += wout; data_out_channel += wout;
} }
} }
...@@ -2575,11 +2569,10 @@ void pooling3x3s2p0_max(const float* din, ...@@ -2575,11 +2569,10 @@ void pooling3x3s2p0_max(const float* din,
int remain = w_unroll_remian - 1; int remain = w_unroll_remian - 1;
int right = wout * 2 + 1 - win; // if need right pad int right = wout * 2 + 1 - win; // if need right pad
int tmp_val = (w_unroll_size * 4 + remain) * S;
int w_2 = right > 0 ? w_unroll_remian : w_unroll_remian + 1; int wend = std::min(tmp_val + K, win) - tmp_val;
w_2 = w_unroll_size <= 0 ? w_2 - 1 : w_2;
float minval = std::numeric_limits<float>::lowest(); float minval = std::numeric_limits<float>::lowest();
remain = right > 0 ? remain : remain + 1;
for (int n = 0; n < num; ++n) { for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out; float* data_out_batch = data_out + n * chout * size_channel_out;
const float* data_in_batch = data_in + n * chin * size_channel_in; const float* data_in_batch = data_in + n * chin * size_channel_in;
...@@ -2630,59 +2623,12 @@ void pooling3x3s2p0_max(const float* din, ...@@ -2630,59 +2623,12 @@ void pooling3x3s2p0_max(const float* din,
"v9", "v9",
"v10", "v10",
"v11"); "v11");
dr0 -= 8;
dr1 -= 8;
dr2 -= 8;
for (int w = 0; w < w_2 - 1; w += 1) {
float32x4_t vr0 = vld1q_f32(dr0);
float32x4_t vr1 = vld1q_f32(dr1);
float32x4_t vr2 = vld1q_f32(dr2);
vr0 = vsetq_lane_f32(minval, vr0, 3);
vr1 = vsetq_lane_f32(minval, vr1, 3);
vr2 = vsetq_lane_f32(minval, vr2, 3);
float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
vmax1 = vmaxq_f32(vmax1, vr2);
float32x2_t vmax2 =
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
float32x2_t vmax = vpmax_f32(vmax2, vmax2);
dr_out[0] = vget_lane_f32(vmax, 0);
dr_out++;
dr0 += 2;
dr1 += 2;
dr2 += 2;
}
#else #else
asm volatile( asm volatile(P3x3S2P0_INIT P3x3S2P0_MAX
P3x3S2P0_INIT P3x3S2P0_MAX
"cmp %[remain], #0 @cmp cnt_num\n"
"sub %[dr0], #32 @sub - 8\n"
"sub %[dr1], #32 @sub - 8\n"
"sub %[dr2], #32 @sub - 8\n"
"ble 4f @ble exit1\n"
"2: @mid loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load \n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load \n"
"vld1.f32 {d4-d5}, [%[dr2]]! @load \n"
"vmov.f32 s3,s2 @mov \n"
"vmov.f32 s7,s6 @mov \n"
"vmov.f32 s11,s10 @mov \n"
"vmax.f32 q0, q0, q1 @max n"
"sub %[dr0], #8 @add w \n"
"sub %[dr1], #8 @add w \n"
"sub %[dr2], #8 @add w \n"
"vmax.f32 q0, q0, q2 @max \n"
"vpmax.f32 d0, d0, d1 @pmax \n"
"vpmax.f32 d0, d0, d0 @pmax \n"
"subs %[remain], #1 @subs \n"
"vst1.f32 d0[0], [%[dr_out]]! @vst \n"
"bne 2b @bne \n"
"4: @exit\n"
: [dr0] "+r"(dr0), : [dr0] "+r"(dr0),
[dr1] "+r"(dr1), [dr1] "+r"(dr1),
[dr2] "+r"(dr2), [dr2] "+r"(dr2),
[dr_out] "+r"(dr_out), [dr_out] "+r"(dr_out),
[remain] "+r"(cnt_remain),
[cnt_num] "+r"(cnt_num) [cnt_num] "+r"(cnt_num)
: :
: "cc", : "cc",
...@@ -2699,19 +2645,37 @@ void pooling3x3s2p0_max(const float* din, ...@@ -2699,19 +2645,37 @@ void pooling3x3s2p0_max(const float* din,
"q9", "q9",
"q10", "q10",
"q11"); "q11");
#endif
dr0 -= 8;
dr1 -= 8;
dr2 -= 8;
}
for (int w = 0; w < remain; w += 1) {
float32x4_t vr0 = vld1q_f32(dr0);
float32x4_t vr1 = vld1q_f32(dr1);
float32x4_t vr2 = vld1q_f32(dr2);
vr0 = vsetq_lane_f32(minval, vr0, 3);
vr1 = vsetq_lane_f32(minval, vr1, 3);
vr2 = vsetq_lane_f32(minval, vr2, 3);
float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
vmax1 = vmaxq_f32(vmax1, vr2);
float32x2_t vmax2 =
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
float32x2_t vmax = vpmax_f32(vmax2, vmax2);
dr_out[0] = vget_lane_f32(vmax, 0);
dr_out++;
dr0 += 2;
dr1 += 2;
dr2 += 2;
}
if (right) { if (right) {
int wstart = (w_unroll_size * 4 + remain) * S; float tmp = dr0[0]; // std::numeric_limits<float>::min();
int wend = std::min(wstart + K, win); for (int i = 0; i < wend; i++) {
float tmp = dr0[wstart]; // std::numeric_limits<float>::min();
for (int i = wstart; i < wend; i++) {
tmp = std::max(tmp, std::max(dr0[i], dr1[i])); tmp = std::max(tmp, std::max(dr0[i], dr1[i]));
tmp = std::max(tmp, dr2[i]); tmp = std::max(tmp, dr2[i]);
} }
*(dr_out++) = tmp; *(dr_out++) = tmp;
} }
#endif
}
r0 = r2; r0 = r2;
r1 = r0 + win; r1 = r0 + win;
r2 = r1 + win; r2 = r1 + win;
...@@ -2748,6 +2712,7 @@ void pooling3x3s2p0_avg(const float* din, ...@@ -2748,6 +2712,7 @@ void pooling3x3s2p0_avg(const float* din,
w_unroll_size -= 1; w_unroll_size -= 1;
w_unroll_remian = wout - w_unroll_size * 4; w_unroll_remian = wout - w_unroll_size * 4;
} }
// do overflow process // do overflow process
w_unroll_size -= 1; w_unroll_size -= 1;
w_unroll_remian += 4; w_unroll_remian += 4;
...@@ -2861,6 +2826,7 @@ void pooling3x3s2p0_avg(const float* din, ...@@ -2861,6 +2826,7 @@ void pooling3x3s2p0_avg(const float* din,
dr2 -= 8; dr2 -= 8;
} }
// deal with right pad // deal with right pad
w_unroll_size = w_unroll_size < 0 ? 0 : w_unroll_size;
int wstart = w_unroll_size * 4 * S - P; int wstart = w_unroll_size * 4 * S - P;
for (int j = 0; j < w_unroll_remian; ++j) { for (int j = 0; j < w_unroll_remian; ++j) {
int wend = wstart + K; // std::min(wstart + K, win); int wend = wstart + K; // std::min(wstart + K, win);
......
...@@ -108,10 +108,12 @@ void ConvConvFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { ...@@ -108,10 +108,12 @@ void ConvConvFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
if (!(kw == 1 && kh == 1)) { if (!(kw == 1 && kh == 1)) {
LOG(FATAL) << "The kernel size of the second conv must be 1x1"; LOG(FATAL) << "The kernel size of the second conv must be 1x1";
} }
auto channel0_out = weight0_t->dims()[0];
auto channel1_in = weight1_t->dims()[1] * groups1;
CHECK_EQ(enable0_int8, enable1_int8) << "The Conv compute type must be same"; CHECK_EQ(enable0_int8, enable1_int8) << "The Conv compute type must be same";
CHECK_EQ(groups1, 1) << "The groups of weight1_dim must be 1"; CHECK_EQ(groups1, 1) << "The groups of weight1_dim must be 1";
CHECK_EQ(weight0_t->dims()[0], weight1_t->dims()[1]) CHECK_EQ(channel0_out, channel1_in) << "channel0_out == channel1_in";
<< "weight0_dims[0] == weight1_dim[1]";
for (int i = 0; i < strides1.size(); i++) { for (int i = 0; i < strides1.size(); i++) {
CHECK_EQ(strides1[i], 1) << "strides[" << i << "]: " << strides1[i] CHECK_EQ(strides1[i], 1) << "strides[" << i << "]: " << strides1[i]
<< " must be 1"; << " must be 1";
......
...@@ -57,7 +57,9 @@ void PoolCompute::Run() { ...@@ -57,7 +57,9 @@ void PoolCompute::Run() {
(ksize[0] == ksize[1]) && (strides[0] == strides[1]) && pads_less; (ksize[0] == ksize[1]) && (strides[0] == strides[1]) && pads_less;
bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) && bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) &&
(ksize[1] == in_dims[3]) && kps_equal && pads_equal; (ksize[1] == in_dims[3]) && kps_equal && pads_equal;
bool win_ksize = (in_dims[2] > ksize[0]) && (in_dims[3] > ksize[1]);
global_pooling = param.global_pooling || global_pooling; global_pooling = param.global_pooling || global_pooling;
kps_equal = kps_equal && win_ksize;
if (global_pooling) { if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
......
...@@ -435,7 +435,7 @@ TEST(TestPoolRand, test_pool_rand) { ...@@ -435,7 +435,7 @@ TEST(TestPoolRand, test_pool_rand) {
adaptive, adaptive,
use_quantizer, use_quantizer,
pooling_type, pooling_type,
{1, 2, 4}, {4},
{FLAGS_power_mode}); {FLAGS_power_mode});
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册