Commit 2bbf01d1 authored by hjchen2

Fix depthwise conv5x5 bug for padding 2

Parent a07503a7
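Summary of the change, as read from the diff below: the elementwise-add kernel previously initialized remain to elementwise_num and fell through to a scalar tail loop; it now declares remain inside the NEON branch, hoists the loop-invariant bias broadcast rb out of the 16-wide main loop, and finishes the tail with 8-, 4- and 1-to-3-element vector stores, keeping the plain scalar loop behind #else for non-NEON builds. In DepthwiseConv5x5S1, the border loops accumulated into the row registers themselves (row0 = vmulq_f32(row0, _ker[0]); ...), destroying input data that later iterations of the w loop still read. With padding_w == 2 the left-border loop runs twice (w = 1, then w = 0), so the second column was computed from clobbered rows; dedicated accumulators (acc0/acc1, or acc in the single-row tail) fix this. Debug DLOG output is removed and a #pragma omp parallel for is added over the per-channel loop.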
@@ -59,12 +59,11 @@ inline void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) {
       const float *input = input_data + offset;
       const float bias = bias_data[j];
       float *output = output_data + offset;
-      int remain = elementwise_num;
 #if defined(__ARM_NEON__) || defined(__ARM_NEON)
       int loop = elementwise_num >> 0x4;
-      remain = elementwise_num & 0xF;
+      int remain = elementwise_num & 0xF;
+      float32x4_t rb = vdupq_n_f32(bias);
       for (int k = 0; k < loop; ++k) {
-        float32x4_t rb = vdupq_n_f32(bias);
         float32x4_t r0 = vld1q_f32(input);
         float32x4_t r1 = vld1q_f32(input + 4);
         float32x4_t r2 = vld1q_f32(input + 8);
@@ -80,10 +79,46 @@ inline void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) {
         input += 16;
         output += 16;
       }
-#endif
-      for (int k = 0; k < remain; ++k) {
+      if (remain >= 8) {
+        float32x4_t r0 = vld1q_f32(input);
+        float32x4_t r1 = vld1q_f32(input + 4);
+        r0 = vaddq_f32(r0, rb);
+        r1 = vaddq_f32(r1, rb);
+        vst1q_f32(output, r0);
+        vst1q_f32(output + 4, r1);
+        input += 8;
+        output += 8;
+        remain -= 8;
+      }
+      if (remain >= 4) {
+        float32x4_t r0 = vld1q_f32(input);
+        r0 = vaddq_f32(r0, rb);
+        vst1q_f32(output, r0);
+        input += 4;
+        output += 4;
+        remain -= 4;
+      }
+      if (remain > 0) {
+        float32x4_t r0 = vld1q_f32(input);
+        r0 = vaddq_f32(r0, rb);
+        switch (remain) {
+          case 1:
+            vst1q_lane_f32(output, r0, 0);
+            break;
+          case 2:
+            vst1_f32(output, vget_low_f32(r0));
+            break;
+          case 3:
+            vst1_f32(output, vget_low_f32(r0));
+            vst1_lane_f32(output + 2, vget_high_f32(r0), 0);
+            break;
+        }
+      }
+#else
+      for (int k = 0; k < elementwise_num; ++k) {
         output[k] = input[k] + bias;
       }
+#endif // __ARM_NEON__
     }
   }
 }
...
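For reference, every branch of the NEON tail above has to agree with the plain scalar path (case 3 stores the two low lanes, then the third element at output + 2). A minimal self-contained check, a hypothetical harness rather than part of this patch, that exercises one 16-wide pass plus the 3-element switch tail:

#include <cstdio>

// Scalar reference: what every path above must compute.
static void elementwise_add_ref(const float *input, float bias,
                                float *output, int elementwise_num) {
  for (int k = 0; k < elementwise_num; ++k) {
    output[k] = input[k] + bias;
  }
}

int main() {
  const int n = 19;  // 16 (main loop) + 3 (case 3 of the switch)
  float in[n], out[n];
  for (int i = 0; i < n; ++i) in[i] = 0.5f * i;
  elementwise_add_ref(in, 1.f, out, n);
  printf("%.1f %.1f\n", out[0], out[n - 1]);  // prints: 1.0 10.0
  return 0;
}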
@@ -160,11 +160,8 @@ void DepthwiseConv5x5S1<float, float>(const framework::Tensor &input,
   int valid_w_start = padding_w;
   int valid_w_end = output_w - valid_w_start;
   int valid_w = valid_w_end - valid_w_start;
-  DLOG << "valid_h_start: " << valid_h_start;
-  DLOG << "valid_h_end: " << valid_h_end;
-  DLOG << "valid_w_start: " << valid_w_start;
-  DLOG << "valid_w_end: " << valid_w_end;
 
+  #pragma omp parallel for
   for (int g = 0; g < input.dims()[1]; ++g) {
     const float *input_ptr = input_data + g * image_size;
     const float *filter_ptr = filter_data + g * 25;
@@ -214,25 +211,26 @@ void DepthwiseConv5x5S1<float, float>(const framework::Tensor &input,
         float32x4_t row4 = vld1q_f32(input_ptr4);
         float32x4_t row5 = vld1q_f32(input_ptr5);
         float32x4_t zero = vdupq_n_f32(0.f);
+        float32x4_t acc0, acc1;
         for (int w = valid_w_start - 1; w >= 0; --w) {
           int padding = padding_w - w;
           if (padding >= 5) {
             output_ptr0[w] = 0.f;
             output_ptr1[w] = 0.f;
           } else {
-            row0 = vmulq_f32(row0, _ker[0]);
-            row0 = vmlaq_f32(row0, row1, _ker[1]);
-            row0 = vmlaq_f32(row0, row2, _ker[2]);
-            row0 = vmlaq_f32(row0, row3, _ker[3]);
-            row0 = vmlaq_f32(row0, row4, _ker[4]);
-            row1 = vmulq_f32(row1, _ker[0]);
-            row1 = vmlaq_f32(row1, row2, _ker[1]);
-            row1 = vmlaq_f32(row1, row3, _ker[2]);
-            row1 = vmlaq_f32(row1, row4, _ker[3]);
-            row1 = vmlaq_f32(row1, row5, _ker[4]);
-            row0 = vpaddq_f32(row0, row1);
+            acc0 = vmulq_f32(row0, _ker[0]);
+            acc0 = vmlaq_f32(acc0, row1, _ker[1]);
+            acc0 = vmlaq_f32(acc0, row2, _ker[2]);
+            acc0 = vmlaq_f32(acc0, row3, _ker[3]);
+            acc0 = vmlaq_f32(acc0, row4, _ker[4]);
+            acc1 = vmulq_f32(row1, _ker[0]);
+            acc1 = vmlaq_f32(acc1, row2, _ker[1]);
+            acc1 = vmlaq_f32(acc1, row3, _ker[2]);
+            acc1 = vmlaq_f32(acc1, row4, _ker[3]);
+            acc1 = vmlaq_f32(acc1, row5, _ker[4]);
+            acc0 = vpaddq_f32(acc0, acc1);
             float32x2_t sum =
-                vpadd_f32(vget_low_f32(row0), vget_high_f32(row0));
+                vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
             vst1_lane_f32(output_ptr0 + w, sum, 0);
             vst1_lane_f32(output_ptr1 + w, sum, 1);
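The three pairwise adds at the end of this hunk compute both border outputs at once. As a standalone sketch of that reduction (illustrative helper name; note vpaddq_f32 is an AArch64 intrinsic, so an equivalent for 32-bit ARM is assumed to be provided elsewhere in the codebase):

#include <arm_neon.h>

// Collapse two 4-lane accumulators into {sum(acc0), sum(acc1)}:
// lane 0 feeds output_ptr0[w], lane 1 feeds output_ptr1[w].
static inline float32x2_t hsum2(float32x4_t acc0, float32x4_t acc1) {
  float32x4_t p = vpaddq_f32(acc0, acc1);  // {a0+a1, a2+a3, b0+b1, b2+b3}
  return vpadd_f32(vget_low_f32(p), vget_high_f32(p));  // {sum a, sum b}
}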
@@ -456,6 +454,7 @@ void DepthwiseConv5x5S1<float, float>(const framework::Tensor &input,
         float32x4_t row4 = vld1q_f32(input_ptr4);
         float32x4_t row5 = vld1q_f32(input_ptr5);
         float32x4_t zero = vdupq_n_f32(0.f);
+        float32x4_t acc0, acc1;
         for (int w = valid_w_end; w < output_w; ++w) {
           int padding = w + 5 - (padding_w + input_w);
           if (padding >= 5) {
@@ -479,19 +478,19 @@ void DepthwiseConv5x5S1<float, float>(const framework::Tensor &input,
             row3 = vextq_f32(row3, zero, 1);
             row4 = vextq_f32(row4, zero, 1);
             row5 = vextq_f32(row5, zero, 1);
-            row0 = vmulq_f32(row0, _ker[0]);
-            row0 = vmlaq_f32(row0, row1, _ker[1]);
-            row0 = vmlaq_f32(row0, row2, _ker[2]);
-            row0 = vmlaq_f32(row0, row3, _ker[3]);
-            row0 = vmlaq_f32(row0, row4, _ker[4]);
-            row1 = vmulq_f32(row1, _ker[0]);
-            row1 = vmlaq_f32(row1, row2, _ker[1]);
-            row1 = vmlaq_f32(row1, row3, _ker[2]);
-            row1 = vmlaq_f32(row1, row4, _ker[3]);
-            row1 = vmlaq_f32(row1, row5, _ker[4]);
-            row0 = vpaddq_f32(row0, row1);
+            acc0 = vmulq_f32(row0, _ker[0]);
+            acc0 = vmlaq_f32(acc0, row1, _ker[1]);
+            acc0 = vmlaq_f32(acc0, row2, _ker[2]);
+            acc0 = vmlaq_f32(acc0, row3, _ker[3]);
+            acc0 = vmlaq_f32(acc0, row4, _ker[4]);
+            acc1 = vmulq_f32(row1, _ker[0]);
+            acc1 = vmlaq_f32(acc1, row2, _ker[1]);
+            acc1 = vmlaq_f32(acc1, row3, _ker[2]);
+            acc1 = vmlaq_f32(acc1, row4, _ker[3]);
+            acc1 = vmlaq_f32(acc1, row5, _ker[4]);
+            acc0 = vpaddq_f32(acc0, acc1);
             float32x2_t sum =
-                vpadd_f32(vget_low_f32(row0), vget_high_f32(row0));
+                vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0));
             sum0 += vget_lane_f32(sum, 0);
             sum1 += vget_lane_f32(sum, 1);
             *output_ptr0 = sum0;
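In this right-border loop, each output column first shifts every row one lane to the left with vextq_f32 so that lanes falling past the padded input contribute zero to the sum. A minimal sketch of that shift (illustrative helper name):

#include <arm_neon.h>

// vextq_f32(row, zero, 1) concatenates {row, zero} and takes 4 lanes
// starting at lane 1: {r0, r1, r2, r3} -> {r1, r2, r3, 0}.
static inline float32x4_t shift_in_zero(float32x4_t row) {
  return vextq_f32(row, vdupq_n_f32(0.f), 1);
}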
@@ -519,18 +518,18 @@ void DepthwiseConv5x5S1<float, float>(const framework::Tensor &input,
         float32x4_t row3 = vld1q_f32(input_ptr3);
         float32x4_t row4 = vld1q_f32(input_ptr4);
         float32x4_t zero = vdupq_n_f32(0.f);
+        float32x4_t acc;
         for (int w = valid_w_start - 1; w >= 0; --w) {
           int padding = padding_w - w;
           if (padding >= 5) {
             output_ptr0[w] = 0.f;
           } else {
-            row0 = vmulq_f32(row0, _ker[0]);
-            row0 = vmlaq_f32(row0, row1, _ker[1]);
-            row0 = vmlaq_f32(row0, row2, _ker[2]);
-            row0 = vmlaq_f32(row0, row3, _ker[3]);
-            row0 = vmlaq_f32(row0, row4, _ker[4]);
-            float32x2_t sum =
-                vpadd_f32(vget_low_f32(row0), vget_high_f32(row0));
+            acc = vmulq_f32(row0, _ker[0]);
+            acc = vmlaq_f32(acc, row1, _ker[1]);
+            acc = vmlaq_f32(acc, row2, _ker[2]);
+            acc = vmlaq_f32(acc, row3, _ker[3]);
+            acc = vmlaq_f32(acc, row4, _ker[4]);
+            float32x2_t sum = vpadd_f32(vget_low_f32(acc), vget_high_f32(acc));
             sum = vpadd_f32(sum, sum);
             vst1_lane_f32(output_ptr0 + w, sum, 0);
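The single-row tail reduces one accumulator instead of two: the first vpadd_f32 folds 4 lanes to 2, the second folds 2 to 1, after which both lanes hold the full sum. As a standalone sketch (illustrative name):

#include <arm_neon.h>

// Full horizontal sum of one 4-lane accumulator.
static inline float hsum1(float32x4_t acc) {
  float32x2_t s = vpadd_f32(vget_low_f32(acc), vget_high_f32(acc));
  s = vpadd_f32(s, s);  // both lanes now hold a0+a1+a2+a3
  return vget_lane_f32(s, 0);
}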
@@ -687,6 +686,7 @@ void DepthwiseConv5x5S1<float, float>(const framework::Tensor &input,
         float32x4_t row3 = vld1q_f32(input_ptr3);
         float32x4_t row4 = vld1q_f32(input_ptr4);
         float32x4_t zero = vdupq_n_f32(0.f);
+        float32x4_t acc;
         for (int w = valid_w_end; w < output_w; ++w) {
           int padding = w + 5 - (padding_w + input_w);
           if (padding >= 5) {
@@ -703,13 +703,12 @@ void DepthwiseConv5x5S1<float, float>(const framework::Tensor &input,
             row2 = vextq_f32(row2, zero, 1);
             row3 = vextq_f32(row3, zero, 1);
             row4 = vextq_f32(row4, zero, 1);
-            row0 = vmulq_f32(row0, _ker[0]);
-            row0 = vmlaq_f32(row0, row1, _ker[1]);
-            row0 = vmlaq_f32(row0, row2, _ker[2]);
-            row0 = vmlaq_f32(row0, row3, _ker[3]);
-            row0 = vmlaq_f32(row0, row4, _ker[4]);
-            float32x2_t sum =
-                vpadd_f32(vget_low_f32(row0), vget_high_f32(row0));
+            acc = vmulq_f32(row0, _ker[0]);
+            acc = vmlaq_f32(acc, row1, _ker[1]);
+            acc = vmlaq_f32(acc, row2, _ker[2]);
+            acc = vmlaq_f32(acc, row3, _ker[3]);
+            acc = vmlaq_f32(acc, row4, _ker[4]);
+            float32x2_t sum = vpadd_f32(vget_low_f32(acc), vget_high_f32(acc));
             sum = vpadd_f32(sum, sum);
             sum0 += vget_lane_f32(sum, 0);
             *output_ptr0 = sum0;
...