From 7eaf8451e9e81e6a4468182a773cd9e951b734e3 Mon Sep 17 00:00:00 2001 From: yangfei Date: Fri, 10 Aug 2018 20:15:17 +0800 Subject: [PATCH] implement multithreading 3x3 s1 depthwise_conv --- src/operators/math/depthwise_conv_3x3.cpp | 578 ++++++++++++++++------ 1 file changed, 413 insertions(+), 165 deletions(-) diff --git a/src/operators/math/depthwise_conv_3x3.cpp b/src/operators/math/depthwise_conv_3x3.cpp index 7e353c29b8..a67a24a4de 100644 --- a/src/operators/math/depthwise_conv_3x3.cpp +++ b/src/operators/math/depthwise_conv_3x3.cpp @@ -529,42 +529,42 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, const float *newscale_data = new_scale->data(); const float *newbias_data = new_bias->data(); - const int h = static_cast(input->dims()[2]); - const int w = static_cast(input->dims()[3]); - const int l = h; - const int batch_size = static_cast(input->dims()[0]); - const int c = static_cast(input->dims()[1]); - const int hxw = h * w; + const int input_channel = static_cast(input->dims()[1]); + + const int input_height = static_cast(input->dims()[2]); + const int input_width = static_cast(input->dims()[3]); + const int output_height = static_cast(output->dims()[2]); + const int output_width = static_cast(output->dims()[3]); + + const int hxw = input_height * input_width; + + const int l = input_height; float32x4_t vnewbias = vdupq_n_f32(0.0); float32x4_t vnewscale = vdupq_n_f32(1.0); float32x4_t vzero = vdupq_n_f32(0); - for (int b = 0; b < batch_size; ++b) { - const float *filter_data_tmp = filter_data; - - for (int j = 0; j < c; ++j) { - vnewbias = vdupq_n_f32(newbias_data[j]); - vnewscale = vdupq_n_f32(newscale_data[j]); - - int l_mid = l - 2; // l=1->l_mid=-1,l=2->l_mid=0 - float w00 = filter_data_tmp[0]; - float w01 = filter_data_tmp[1]; - float w02 = filter_data_tmp[2]; - float w10 = filter_data_tmp[3]; - float w11 = filter_data_tmp[4]; - float w12 = filter_data_tmp[5]; - float w20 = filter_data_tmp[6]; - float w21 = filter_data_tmp[7]; - float w22 = filter_data_tmp[8]; + for (int b = 0; b < batch_size; b++) { + filter_data = filter->data(); + for (int c = 0; c < input_channel; c++) { + vnewbias = vdupq_n_f32(newbias_data[c]); + vnewscale = vdupq_n_f32(newscale_data[c]); + + float w00 = filter_data[0]; + float w01 = filter_data[1]; + float w02 = filter_data[2]; + float w10 = filter_data[3]; + float w11 = filter_data[4]; + float w12 = filter_data[5]; + float w20 = filter_data[6]; + float w21 = filter_data[7]; + float w22 = filter_data[8]; output_data[0] = w11 * input_data[0] + w12 * input_data[1] + w21 * input_data[l] + w22 * input_data[l + 1]; - output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l - 1] + w20 * input_data[2 * l - 2] + w21 * input_data[2 * l - 1]; - output_data[(l - 1) * l] = w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] + w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1]; @@ -572,13 +572,13 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, w01 * input_data[(l - 2) * (l + 1) + 1] + w10 * input_data[l * l - 2] + w11 * input_data[l * l - 1]; - output_data[0] = output_data[0] * newscale_data[j] + newbias_data[j]; + output_data[0] = output_data[0] * newscale_data[c] + newbias_data[c]; output_data[l - 1] = - output_data[l - 1] * newscale_data[j] + newbias_data[j]; + output_data[l - 1] * newscale_data[c] + newbias_data[c]; output_data[(l - 1) * l] = - output_data[(l - 1) * l] * newscale_data[j] + newbias_data[j]; + output_data[(l - 1) * l] * newscale_data[c] + newbias_data[c]; output_data[l * l - 1] = - output_data[l * l - 1] * newscale_data[j] + newbias_data[j]; + output_data[l * l - 1] * newscale_data[c] + newbias_data[c]; if (if_relu) { output_data[0] = output_data[0] < 0 ? 0 : output_data[0]; @@ -593,6 +593,7 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] + w11 * input_data[i * l] + w12 * input_data[i * l + 1] + w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1]; + output_data[i * l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] + w01 * input_data[i * l + l - 1 - l] + w10 * input_data[i * l + l - 1 - 1] + @@ -600,9 +601,9 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, w20 * input_data[i * l + l - 1 + l - 1] + w21 * input_data[i * l + l - 1 + l]; output_data[i * l] = - output_data[i * l] * newscale_data[j] + newbias_data[j]; + output_data[i * l] * newscale_data[c] + newbias_data[c]; output_data[i * l + l - 1] = - output_data[i * l + l - 1] * newscale_data[j] + newbias_data[j]; + output_data[i * l + l - 1] * newscale_data[c] + newbias_data[c]; if (if_relu) { output_data[i * l] = output_data[i * l] < 0 ? 0 : output_data[i * l]; @@ -611,28 +612,19 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, } } - // top 1 row and bottom 1 row - const float *input_tmp = input_data; - - float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2, - tmp3, tmp4, tmp5, out0; - in0 = vld1q_f32(input_tmp); - in2 = vld1q_f32(input_tmp + l); - const float *input_tmp_end = input_tmp + (l - 2) * l; - in4 = vld1q_f32(input_tmp_end); - in6 = vld1q_f32(input_tmp_end + l); - int c_mid = l_mid; - auto output_ptr = output_data + 1; - for (; c_mid > 3; c_mid -= 4) { - in1 = vld1q_f32(input_tmp + 4); - in3 = vld1q_f32(input_tmp + l + 4); + int m; + for (m = 1; m < output_width - 4; m += 4) { + float *output_ptr = output_data + m; + float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0; + in0 = vld1q_f32(input_data + m - 1); + in1 = vld1q_f32(input_data + m + 3); + in2 = vld1q_f32(input_data + input_width + m - 1); + in3 = vld1q_f32(input_data + input_width + m + 3); tmp0 = vextq_f32(in0, in1, 1); tmp1 = vextq_f32(in0, in1, 2); - tmp2 = vextq_f32(in2, in3, 1); tmp3 = vextq_f32(in2, in3, 2); - out0 = vmulq_n_f32(in0, w10); out0 = vmlaq_n_f32(out0, tmp0, w11); out0 = vmlaq_n_f32(out0, tmp1, w12); @@ -644,182 +636,438 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, out0 = vmaxq_f32(out0, vzero); } vst1q_f32(output_ptr, out0); + } + for (m = 1; (m + 3) < output_width - 1; m = m + 4) { + } + for (int j = m; j < output_width - 1; j++) { + output_data[j] = input_data[j - 1] * w10 + input_data[j] * w11 + + input_data[j + 1] * w12 + + input_data[input_width + j - 1] * w20 + + input_data[input_width + j] * w21 + + input_data[input_width + j + 1] * w22; + output_data[j] = output_data[j] * newscale_data[c] + newbias_data[c]; - in5 = vld1q_f32(input_tmp_end + 4); - in7 = vld1q_f32(input_tmp_end + l + 4); + if (if_relu) { + output_data[j] = output_data[j] < 0 ? 0 : output_data[j]; + } + } - tmp0 = vextq_f32(in4, in5, 1); - tmp1 = vextq_f32(in4, in5, 2); - tmp2 = vextq_f32(in6, in7, 1); - tmp3 = vextq_f32(in6, in7, 2); + for (m = 1; (m + 3) < output_width - 1; m = m + 4) { + float *output_ptr = + output_data + (output_height - 1) * output_width + m; - out0 = vmulq_n_f32(in4, w00); + float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0; + in0 = vld1q_f32(input_data + (output_height - 2) * input_width + m - 1); + in1 = vld1q_f32(input_data + (output_height - 2) * input_width + m + 3); + in2 = vld1q_f32(input_data + (output_height - 1) * input_width + m - 1); + in3 = vld1q_f32(input_data + (output_height - 1) * input_width + m + 3); + tmp0 = vextq_f32(in0, in1, 1); + tmp1 = vextq_f32(in0, in1, 2); + tmp2 = vextq_f32(in2, in3, 1); + tmp3 = vextq_f32(in2, in3, 2); + out0 = vmulq_n_f32(in0, w00); out0 = vmlaq_n_f32(out0, tmp0, w01); out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in6, w10); + out0 = vmlaq_n_f32(out0, in2, w10); out0 = vmlaq_n_f32(out0, tmp2, w11); out0 = vmlaq_n_f32(out0, tmp3, w12); out0 = vmlaq_f32(vnewbias, vnewscale, out0); if (if_relu) { out0 = vmaxq_f32(out0, vzero); } - vst1q_f32(output_ptr + (l - 1) * l, out0); - - // can optimize to each 8 stride. - input_tmp += 4; - input_tmp_end += 4; - output_ptr += 4; - in0 = in1; - in2 = in3; - in4 = in5; - in6 = in7; + vst1q_f32(output_ptr, out0); } + for (m = 1; (m + 3) < output_width - 1; m = m + 4) { + } + for (int j = m; j < output_width - 1; j++) { + output_data[(output_height - 1) * input_width + j] = + input_data[(output_height - 2) * input_width + j - 1] * w00 + + input_data[(output_height - 2) * input_width + j] * w01 + + input_data[(output_height - 2) * input_width + j + 1] * w02 + + input_data[(output_height - 1) * input_width + j - 1] * w10 + + input_data[(output_height - 1) * input_width + j] * w11 + + input_data[(output_height - 1) * input_width + j + 1] * w12; + output_data[(output_height - 1) * output_width + j] = + output_data[(output_height - 1) * output_width + j] * + newscale_data[c] + + newbias_data[c]; - // top right pad - float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]); - float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]); - - tmp0 = vextq_f32(in0, pad0, 1); - tmp1 = vextq_f32(in0, pad0, 2); - tmp2 = vextq_f32(in2, pad1, 1); - tmp3 = vextq_f32(in2, pad1, 2); - - out0 = vmulq_n_f32(in0, w10); - out0 = vmlaq_n_f32(out0, tmp0, w11); - out0 = vmlaq_n_f32(out0, tmp1, w12); - out0 = vmlaq_n_f32(out0, in2, w20); - out0 = vmlaq_n_f32(out0, tmp2, w21); - out0 = vmlaq_n_f32(out0, tmp3, w22); - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); + if (if_relu) { + output_data[(output_height - 1) * output_width + j] = + output_data[(output_height - 1) * output_width + j] < 0 + ? 0 + : output_data[(output_height - 1) * output_width + j]; + } } - for (int i = 0; i < c_mid; ++i) { - if (i == 0) { - vst1q_lane_f32(output_ptr + i, out0, 0); + #pragma omp parallel for + for (int i = 1; i < output_height - 1; i++) { + for (int m = 1; (m + 3) < output_width - 1; m = m + 4) { + float *output_ptr = output_data + i * output_width + m; + float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3, + tmp4, tmp5, out0; + in0 = vld1q_f32(input_data + (i - 1) * input_width + m - 1); + in1 = vld1q_f32(input_data + (i - 1) * input_width + m + 3); + in2 = vld1q_f32(input_data + i * input_width + m - 1); + in3 = vld1q_f32(input_data + i * input_width + m + 3); + in4 = vld1q_f32(input_data + (i + 1) * input_width + m - 1); + in5 = vld1q_f32(input_data + (i + 1) * input_width + m + 3); + + tmp0 = vextq_f32(in0, in1, 1); + tmp1 = vextq_f32(in0, in1, 2); + tmp2 = vextq_f32(in2, in3, 1); + tmp3 = vextq_f32(in2, in3, 2); + tmp4 = vextq_f32(in4, in5, 1); + tmp5 = vextq_f32(in4, in5, 2); + + out0 = vmulq_n_f32(in0, w00); + out0 = vmlaq_n_f32(out0, tmp0, w01); + out0 = vmlaq_n_f32(out0, tmp1, w02); + out0 = vmlaq_n_f32(out0, in2, w10); + out0 = vmlaq_n_f32(out0, tmp2, w11); + out0 = vmlaq_n_f32(out0, tmp3, w12); + out0 = vmlaq_n_f32(out0, in4, w20); + out0 = vmlaq_n_f32(out0, tmp4, w21); + out0 = vmlaq_n_f32(out0, tmp5, w22); + + out0 = vmlaq_f32(vnewbias, vnewscale, out0); + if (if_relu) { + out0 = vmaxq_f32(out0, vzero); + } + vst1q_f32(output_ptr, out0); } - if (i == 1) { - vst1q_lane_f32(output_ptr + i, out0, 1); + int m; + for (m = 1; (m + 3) < output_width - 1; m = m + 4) { } - if (i == 2) { - vst1q_lane_f32(output_ptr + i, out0, 2); + + for (int j = m; j < output_width - 1; j++) { + output_data[i * output_width + j] = + input_data[(i - 1) * input_width + j - 1] * w00 + + input_data[(i - 1) * input_width + j] * w01 + + input_data[(i - 1) * input_width + j + 1] * w02 + + input_data[(i)*input_width + j - 1] * w10 + + input_data[(i)*input_width + j] * w11 + + input_data[(i)*input_width + j + 1] * w12 + + input_data[(i + 1) * input_width + j - 1] * w20 + + input_data[(i + 1) * input_width + j] * w21 + + input_data[(i + 1) * input_width + j + 1] * w22; + output_data[i * output_width + j] = + newscale_data[c] * output_data[i * output_width + j] + + newbias_data[c]; + if (if_relu) { + output_data[i * output_width + j] = + output_data[i * output_width + j] < 0 + ? 0 + : output_data[i * output_width + j]; + } } } - // bottom right pad - float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]); - float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]); + input_data = input_data + hxw; + output_data = output_data + hxw; + filter_data = filter_data + 9; + } + } - tmp0 = vextq_f32(in4, pad2, 1); - tmp1 = vextq_f32(in4, pad2, 2); - tmp2 = vextq_f32(in6, pad3, 1); - tmp3 = vextq_f32(in6, pad3, 2); + /* + const float *input_data = input->data(); + const float *filter_data = filter->data(); + float *output_data = output->data(); + const float *newscale_data = new_scale->data(); + const float *newbias_data = new_bias->data(); + + const int h = static_cast(input->dims()[2]); + const int w = static_cast(input->dims()[3]); + const int l = h; + + const int batch_size = static_cast(input->dims()[0]); + const int c = static_cast(input->dims()[1]); + const int hxw = h * w; + float32x4_t vnewbias = vdupq_n_f32(0.0); + float32x4_t vnewscale = vdupq_n_f32(1.0); + float32x4_t vzero = vdupq_n_f32(0); + + for (int b = 0; b < batch_size; ++b) { + const float *filter_data_tmp = filter_data; + + for (int j = 0; j < c; ++j) { + vnewbias = vdupq_n_f32(newbias_data[j]); + vnewscale = vdupq_n_f32(newscale_data[j]); + + int l_mid = l - 2; // l=1->l_mid=-1,l=2->l_mid=0 + float w00 = filter_data_tmp[0]; + float w01 = filter_data_tmp[1]; + float w02 = filter_data_tmp[2]; + float w10 = filter_data_tmp[3]; + float w11 = filter_data_tmp[4]; + float w12 = filter_data_tmp[5]; + float w20 = filter_data_tmp[6]; + float w21 = filter_data_tmp[7]; + float w22 = filter_data_tmp[8]; + + output_data[0] = w11 * input_data[0] + w12 * input_data[1] + + w21 * input_data[l] + w22 * input_data[l + 1]; + + output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l - 1] + + w20 * input_data[2 * l - 2] + + w21 * input_data[2 * l - 1]; - out0 = vmulq_n_f32(in4, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in6, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); - } - for (int i = 0; i < c_mid; ++i) { - if (i == 0) { - vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0); - } - if (i == 1) { - vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1); + output_data[(l - 1) * l] = + w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] + + w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1]; + output_data[l * l - 1] = w00 * input_data[(l - 2) * (l + 1)] + + w01 * input_data[(l - 2) * (l + 1) + 1] + + w10 * input_data[l * l - 2] + + w11 * input_data[l * l - 1]; + output_data[0] = output_data[0] * newscale_data[j] + newbias_data[j]; + output_data[l - 1] = + output_data[l - 1] * newscale_data[j] + newbias_data[j]; + output_data[(l - 1) * l] = + output_data[(l - 1) * l] * newscale_data[j] + newbias_data[j]; + output_data[l * l - 1] = + output_data[l * l - 1] * newscale_data[j] + newbias_data[j]; + + if (if_relu) { + output_data[0] = output_data[0] < 0 ? 0 : output_data[0]; + output_data[l - 1] = output_data[l - 1] < 0 ? 0 : output_data[l - 1]; + output_data[(l - 1) * l] = + output_data[(l - 1) * l] < 0 ? 0 : output_data[(l - 1) * l]; + output_data[l * l - 1] = + output_data[l * l - 1] < 0 ? 0 : output_data[l * l - 1]; } - if (i == 2) { - vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2); + for (int i = 1; i < l - 1; ++i) { + output_data[i * l] = + w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] + + w11 * input_data[i * l] + w12 * input_data[i * l + 1] + + w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1]; + output_data[i * l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] + + w01 * input_data[i * l + l - 1 - l] + + w10 * input_data[i * l + l - 1 - 1] + + w11 * input_data[i * l + l - 1] + + w20 * input_data[i * l + l - 1 + l - 1] + + w21 * input_data[i * l + l - 1 + l]; + output_data[i * l] = + output_data[i * l] * newscale_data[j] + newbias_data[j]; + output_data[i * l + l - 1] = + output_data[i * l + l - 1] * newscale_data[j] + newbias_data[j]; + + if (if_relu) { + output_data[i * l] = output_data[i * l] < 0 ? 0 : output_data[i * + l]; output_data[i * l + l - 1] = + output_data[i * l + l - 1] < 0 ? 0 : output_data[i * l + l - 1]; + } } - } - // mid - for (int i = 0; i < l - 2; ++i) { - auto output_ptr = output_data + (i + 1) * l + 1; - input_tmp = input_data + i * l; - auto in0_tmp = vld1q_f32(input_tmp); - auto in2_tmp = vld1q_f32(input_tmp + l); - auto in4_tmp = vld1q_f32(input_tmp + l + l); - c_mid = l_mid; + // top 1 row and bottom 1 row + const float *input_tmp = input_data; + + float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2, + tmp3, tmp4, tmp5, out0; + in0 = vld1q_f32(input_tmp); + in2 = vld1q_f32(input_tmp + l); + const float *input_tmp_end = input_tmp + (l - 2) * l; + in4 = vld1q_f32(input_tmp_end); + in6 = vld1q_f32(input_tmp_end + l); + int c_mid = l_mid; + auto output_ptr = output_data + 1; for (; c_mid > 3; c_mid -= 4) { - auto in1_tmp = vld1q_f32(input_tmp + 4); - auto in3_tmp = vld1q_f32(input_tmp + l + 4); - auto in5_tmp = vld1q_f32(input_tmp + l + l + 4); + in1 = vld1q_f32(input_tmp + 4); + in3 = vld1q_f32(input_tmp + l + 4); - tmp0 = vextq_f32(in0_tmp, in1_tmp, 1); - tmp1 = vextq_f32(in0_tmp, in1_tmp, 2); - tmp2 = vextq_f32(in2_tmp, in3_tmp, 1); - tmp3 = vextq_f32(in2_tmp, in3_tmp, 2); - tmp4 = vextq_f32(in4_tmp, in5_tmp, 1); - tmp5 = vextq_f32(in4_tmp, in5_tmp, 2); + tmp0 = vextq_f32(in0, in1, 1); + tmp1 = vextq_f32(in0, in1, 2); - out0 = vmulq_n_f32(in0_tmp, w00); + tmp2 = vextq_f32(in2, in3, 1); + tmp3 = vextq_f32(in2, in3, 2); + + out0 = vmulq_n_f32(in0, w10); + out0 = vmlaq_n_f32(out0, tmp0, w11); + out0 = vmlaq_n_f32(out0, tmp1, w12); + out0 = vmlaq_n_f32(out0, in2, w20); + out0 = vmlaq_n_f32(out0, tmp2, w21); + out0 = vmlaq_n_f32(out0, tmp3, w22); + out0 = vmlaq_f32(vnewbias, vnewscale, out0); + if (if_relu) { + out0 = vmaxq_f32(out0, vzero); + } + vst1q_f32(output_ptr, out0); + + in5 = vld1q_f32(input_tmp_end + 4); + in7 = vld1q_f32(input_tmp_end + l + 4); + + tmp0 = vextq_f32(in4, in5, 1); + tmp1 = vextq_f32(in4, in5, 2); + tmp2 = vextq_f32(in6, in7, 1); + tmp3 = vextq_f32(in6, in7, 2); + + out0 = vmulq_n_f32(in4, w00); out0 = vmlaq_n_f32(out0, tmp0, w01); out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in2_tmp, w10); + out0 = vmlaq_n_f32(out0, in6, w10); out0 = vmlaq_n_f32(out0, tmp2, w11); out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_n_f32(out0, in4_tmp, w20); - out0 = vmlaq_n_f32(out0, tmp4, w21); - out0 = vmlaq_n_f32(out0, tmp5, w22); out0 = vmlaq_f32(vnewbias, vnewscale, out0); if (if_relu) { out0 = vmaxq_f32(out0, vzero); } - vst1q_f32(output_ptr, out0); + vst1q_f32(output_ptr + (l - 1) * l, out0); - output_ptr += 4; + // can optimize to each 8 stride. input_tmp += 4; - in0_tmp = in1_tmp; - in2_tmp = in3_tmp; - in4_tmp = in5_tmp; + input_tmp_end += 4; + output_ptr += 4; + in0 = in1; + in2 = in3; + in4 = in5; + in6 = in7; } - float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]); - float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]); - float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]); + // top right pad + float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]); + float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]); - tmp0 = vextq_f32(in0_tmp, pad0, 1); - tmp1 = vextq_f32(in0_tmp, pad0, 2); - tmp2 = vextq_f32(in2_tmp, pad1, 1); - tmp3 = vextq_f32(in2_tmp, pad1, 2); - tmp4 = vextq_f32(in4_tmp, pad2, 1); - tmp5 = vextq_f32(in4_tmp, pad2, 2); + tmp0 = vextq_f32(in0, pad0, 1); + tmp1 = vextq_f32(in0, pad0, 2); + tmp2 = vextq_f32(in2, pad1, 1); + tmp3 = vextq_f32(in2, pad1, 2); - out0 = vmulq_n_f32(in0_tmp, w00); + out0 = vmulq_n_f32(in0, w10); + out0 = vmlaq_n_f32(out0, tmp0, w11); + out0 = vmlaq_n_f32(out0, tmp1, w12); + out0 = vmlaq_n_f32(out0, in2, w20); + out0 = vmlaq_n_f32(out0, tmp2, w21); + out0 = vmlaq_n_f32(out0, tmp3, w22); + out0 = vmlaq_f32(vnewbias, vnewscale, out0); + if (if_relu) { + out0 = vmaxq_f32(out0, vzero); + } + for (int i = 0; i < c_mid; ++i) { + if (i == 0) { + vst1q_lane_f32(output_ptr + i, out0, 0); + } + if (i == 1) { + vst1q_lane_f32(output_ptr + i, out0, 1); + } + if (i == 2) { + vst1q_lane_f32(output_ptr + i, out0, 2); + } + } + + // bottom right pad + float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]); + float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]); + + tmp0 = vextq_f32(in4, pad2, 1); + tmp1 = vextq_f32(in4, pad2, 2); + tmp2 = vextq_f32(in6, pad3, 1); + tmp3 = vextq_f32(in6, pad3, 2); + + out0 = vmulq_n_f32(in4, w00); out0 = vmlaq_n_f32(out0, tmp0, w01); out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in2_tmp, w10); + out0 = vmlaq_n_f32(out0, in6, w10); out0 = vmlaq_n_f32(out0, tmp2, w11); out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_n_f32(out0, in4_tmp, w20); - out0 = vmlaq_n_f32(out0, tmp4, w21); - out0 = vmlaq_n_f32(out0, tmp5, w22); out0 = vmlaq_f32(vnewbias, vnewscale, out0); if (if_relu) { out0 = vmaxq_f32(out0, vzero); } for (int i = 0; i < c_mid; ++i) { if (i == 0) { - vst1q_lane_f32(output_ptr + i, out0, 0); + vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0); } if (i == 1) { - vst1q_lane_f32(output_ptr + i, out0, 1); + vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1); } if (i == 2) { - vst1q_lane_f32(output_ptr + i, out0, 2); + vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2); + } + } + // mid + + + for (int i = 0; i < l - 2; ++i) { + auto output_ptr = output_data + (i + 1) * l + 1; + input_tmp = input_data + i * l; + auto in0_tmp = vld1q_f32(input_tmp); + auto in2_tmp = vld1q_f32(input_tmp + l); + auto in4_tmp = vld1q_f32(input_tmp + l + l); + c_mid = l_mid; + for (; c_mid > 3; c_mid -= 4) { + auto in1_tmp = vld1q_f32(input_tmp + 4); + auto in3_tmp = vld1q_f32(input_tmp + l + 4); + auto in5_tmp = vld1q_f32(input_tmp + l + l + 4); + + tmp0 = vextq_f32(in0_tmp, in1_tmp, 1); + tmp1 = vextq_f32(in0_tmp, in1_tmp, 2); + tmp2 = vextq_f32(in2_tmp, in3_tmp, 1); + tmp3 = vextq_f32(in2_tmp, in3_tmp, 2); + tmp4 = vextq_f32(in4_tmp, in5_tmp, 1); + tmp5 = vextq_f32(in4_tmp, in5_tmp, 2); + + out0 = vmulq_n_f32(in0_tmp, w00); + out0 = vmlaq_n_f32(out0, tmp0, w01); + out0 = vmlaq_n_f32(out0, tmp1, w02); + out0 = vmlaq_n_f32(out0, in2_tmp, w10); + out0 = vmlaq_n_f32(out0, tmp2, w11); + out0 = vmlaq_n_f32(out0, tmp3, w12); + out0 = vmlaq_n_f32(out0, in4_tmp, w20); + out0 = vmlaq_n_f32(out0, tmp4, w21); + out0 = vmlaq_n_f32(out0, tmp5, w22); + out0 = vmlaq_f32(vnewbias, vnewscale, out0); + if (if_relu) { + out0 = vmaxq_f32(out0, vzero); + } + vst1q_f32(output_ptr, out0); + + output_ptr += 4; + input_tmp += 4; + in0_tmp = in1_tmp; + in2_tmp = in3_tmp; + in4_tmp = in5_tmp; + } + + float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]); + float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]); + float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]); + + tmp0 = vextq_f32(in0_tmp, pad0, 1); + tmp1 = vextq_f32(in0_tmp, pad0, 2); + tmp2 = vextq_f32(in2_tmp, pad1, 1); + tmp3 = vextq_f32(in2_tmp, pad1, 2); + tmp4 = vextq_f32(in4_tmp, pad2, 1); + tmp5 = vextq_f32(in4_tmp, pad2, 2); + + out0 = vmulq_n_f32(in0_tmp, w00); + out0 = vmlaq_n_f32(out0, tmp0, w01); + out0 = vmlaq_n_f32(out0, tmp1, w02); + out0 = vmlaq_n_f32(out0, in2_tmp, w10); + out0 = vmlaq_n_f32(out0, tmp2, w11); + out0 = vmlaq_n_f32(out0, tmp3, w12); + out0 = vmlaq_n_f32(out0, in4_tmp, w20); + out0 = vmlaq_n_f32(out0, tmp4, w21); + out0 = vmlaq_n_f32(out0, tmp5, w22); + out0 = vmlaq_f32(vnewbias, vnewscale, out0); + if (if_relu) { + out0 = vmaxq_f32(out0, vzero); + } + for (int i = 0; i < c_mid; ++i) { + if (i == 0) { + vst1q_lane_f32(output_ptr + i, out0, 0); + } + if (i == 1) { + vst1q_lane_f32(output_ptr + i, out0, 1); + } + if (i == 2) { + vst1q_lane_f32(output_ptr + i, out0, 2); + } } } + output_data += hxw; + input_data += hxw; + filter_data_tmp += 9; } - output_data += hxw; - input_data += hxw; - filter_data_tmp += 9; } - } + */ #endif } -- GitLab