diff --git a/src/operators/math/im2col.cpp b/src/operators/math/im2col.cpp index 4065f7d9c4934bce8285ea99fe4f14c4e2cc990c..090ccdf24e214fc86b8a4032df228d50caa65ef9 100644 --- a/src/operators/math/im2col.cpp +++ b/src/operators/math/im2col.cpp @@ -74,7 +74,7 @@ class Im2ColFunctor { const int isize = im_height; bool pad1 = padding[0] > 0; bool pad2 = - (pad1 && + (pad1 && padding[1] && (((isize - 2 * padding[0] + filter_height) % stride[0] == 0) ? 1 : 0)); int fill = isize % 2; if (stride[0] == 1 && filter_height == 3 && pad1 && pad2 && diff --git a/src/operators/math/pool_3x3.cpp b/src/operators/math/pool_3x3.cpp index 05d3017f635a040a52d2cc377c8f384dbbd8086c..f8b52c59f5689461ef9b4171b9e33c0d49529eed 100644 --- a/src/operators/math/pool_3x3.cpp +++ b/src/operators/math/pool_3x3.cpp @@ -31,186 +31,43 @@ using std::min; using std::vector; void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) { #if __ARM_NEON - const int batch_size = input->dims()[0]; + const int batch_size = static_cast(input->dims()[0]); + const int input_channel = static_cast(input->dims()[1]); - const int h_in = input->dims()[2]; + const int input_height = static_cast(input->dims()[2]); + const int input_width = static_cast(input->dims()[3]); + const int output_height = static_cast(output->dims()[2]); + const int output_width = static_cast(output->dims()[3]); - const int w_in = input->dims()[3]; - - const int output_channels = output->dims()[1]; + const int hxw = input_height * input_width; - const int h_out = output->dims()[2]; - const int w_out = output->dims()[3]; - const int outputdata_channel_stride = h_out * w_out; - const int inputdata_channel_stride = h_in * w_in; - const int input_batch_stride = output_channels * inputdata_channel_stride; - const int output_batch_stride = output_channels * outputdata_channel_stride; - float *out_data = output->data(); - const float *input_data = input->data(); + const int l = input_height; const float coef = 1.0 / 9.0; - for (int k = 0; k < batch_size; ++k) { -#pragma omp parallel for - for (int c = 0; c < output_channels; ++c) { - const float *input_seg = input_data + c * inputdata_channel_stride; - float *output_seg = out_data + c * outputdata_channel_stride; - // four corner point - output_seg[0] = (input_seg[0] + input_seg[1] + input_seg[w_in] + - input_seg[w_in + 1]) * - coef; - output_seg[w_out - 1] = - (input_seg[w_in - 2] + input_seg[w_in - 1] + input_seg[w_in * 2 - 2] + - input_seg[2 * w_in - 1]) * - coef; - output_seg[(h_out - 1) * w_out] = - (input_seg[(h_in - 2) * w_in] + input_seg[(h_in - 2) * w_in + 1] + - input_seg[(h_in - 1) * w_in] + input_seg[(h_in - 1) * w_in + 1]) * - coef; - output_seg[h_out * w_out - 1] = - (input_seg[h_in * w_in - 1] + input_seg[h_in * w_in - 2] + - input_seg[(h_in - 1) * w_in - 1] + - input_seg[(h_in - 1) * w_in - 2]) * - coef; - // left side & right side - for (int i = 1; i < h_in - 1; ++i) { - output_seg[i * w_out] = - (input_seg[i * w_in - w_in] + input_seg[i * w_in - w_in + 1] + - input_seg[i * w_in] + input_seg[i * w_in + 1] + - input_seg[i * w_in + w_in] + input_seg[i * w_in + w_in + 1]) * - coef; - output_seg[i * w_out + w_out - 1] = - (input_seg[i * w_in - w_in + w_in - 2] + - input_seg[i * w_in - w_in + 1 + w_in - 2] + - input_seg[i * w_in + w_in - 2] + - input_seg[i * w_in + 1 + w_in - 2] + - input_seg[i * w_in + w_in + w_in - 2] + - input_seg[i * w_in + w_in + 1 + w_in - 2]) * - coef; - } - // top 1 row & bottom 1 row - const float *input_tmp = input_seg; - - float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, 
tmp2, - tmp3, tmp4, tmp5, sum, out0; - float32x4_t v_coef = vdupq_n_f32(coef); - in0 = vld1q_f32(input_tmp); - in2 = vld1q_f32(input_tmp + w_in); - const float *input_tmp_end = input_tmp + (h_in - 2) * w_in; - in4 = vld1q_f32(input_tmp_end); - in6 = vld1q_f32(input_tmp_end + w_in); - int c_mid = w_out - 2; - auto output_ptr = output_seg + 1; - for (; c_mid > 3; c_mid -= 4) { - in1 = vld1q_f32(input_tmp + 4); - in3 = vld1q_f32(input_tmp + w_in + 4); - - tmp0 = vextq_f32(in0, in1, 1); - tmp1 = vextq_f32(in0, in1, 2); - - tmp2 = vextq_f32(in2, in3, 1); - tmp3 = vextq_f32(in2, in3, 2); - - sum = vaddq_f32(in0, tmp0); - sum = vaddq_f32(sum, tmp1); - sum = vaddq_f32(sum, in2); - sum = vaddq_f32(sum, tmp2); - sum = vaddq_f32(sum, tmp3); - - vst1q_f32(output_ptr, vmulq_f32(sum, v_coef)); - - in5 = vld1q_f32(input_tmp_end + 4); - in7 = vld1q_f32(input_tmp_end + w_in + 4); - - tmp0 = vextq_f32(in4, in5, 1); - tmp1 = vextq_f32(in4, in5, 2); - tmp2 = vextq_f32(in6, in7, 1); - tmp3 = vextq_f32(in6, in7, 2); - - sum = vaddq_f32(in0, tmp0); - sum = vaddq_f32(sum, tmp1); - sum = vaddq_f32(sum, in2); - sum = vaddq_f32(sum, tmp2); - sum = vaddq_f32(sum, tmp3); - - vst1q_f32(output_ptr + (h_out - 1) * w_out, vmulq_f32(sum, v_coef)); - - // can optimize to each 8 stride. - input_tmp += 4; - input_tmp_end += 4; - output_ptr += 4; - in0 = in1; - in2 = in3; - in4 = in5; - in6 = in7; - } - // top right remain - float32x4_t pad0 = vdupq_n_f32(input_seg[w_in - 1]); - float32x4_t pad1 = vdupq_n_f32(input_seg[2 * w_in - 1]); - - tmp0 = vextq_f32(in0, pad0, 1); - tmp1 = vextq_f32(in0, pad0, 2); - tmp2 = vextq_f32(in2, pad1, 2); - tmp3 = vextq_f32(in2, pad1, 2); - - sum = vaddq_f32(in0, tmp0); - sum = vaddq_f32(sum, tmp1); - sum = vaddq_f32(sum, in2); - sum = vaddq_f32(sum, tmp2); - sum = vaddq_f32(sum, tmp3); - out0 = vmulq_f32(sum, v_coef); - - for (int i = 0; i < c_mid; ++i) { - if (i == 0) { - vst1q_lane_f32(output_ptr + i, out0, 0); - } - if (i == 1) { - vst1q_lane_f32(output_ptr + i, out0, 1); - } - if (i == 2) { - vst1q_lane_f32(output_ptr + i, out0, 2); - } - } - - // bottom_right remain - float32x4_t pad2 = vdupq_n_f32(input_seg[(h_in - 1) * w_in - 1]); - float32x4_t pad3 = vdupq_n_f32(input_seg[h_in * w_in - 1]); - - tmp0 = vextq_f32(in4, pad2, 1); - tmp1 = vextq_f32(in4, pad2, 2); - tmp2 = vextq_f32(in6, pad3, 2); - tmp3 = vextq_f32(in6, pad3, 2); - - sum = vaddq_f32(in4, tmp0); - sum = vaddq_f32(sum, tmp1); - sum = vaddq_f32(sum, in6); - sum = vaddq_f32(sum, tmp2); - sum = vaddq_f32(sum, tmp3); - out0 = vmulq_f32(sum, v_coef); + const float coef1 = 1.0 / 6.0; + const float coef2 = 1.0 / 4.0; - for (int i = 0; i < c_mid; ++i) { - if (i == 0) { - vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 0); - } - if (i == 1) { - vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 1); - } - if (i == 2) { - vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 2); - } - } - // mid - for (int j = 0; j < h_out - 2; ++j) { - output_ptr = output_seg + w_out * (j + 1) + 1; - input_tmp = input_seg + j * w_in; + float32x4_t v_coef = vdupq_n_f32(coef); + float32x4_t v_coef1 = vdupq_n_f32(coef1); - in0 = vld1q_f32(input_tmp); - in2 = vld1q_f32(input_tmp + w_in); - in4 = vld1q_f32(input_tmp + 2 * w_in); - c_mid = w_out - 2; - for (; c_mid > 3; c_mid -= 4) { - in1 = vld1q_f32(input_tmp + 4); - in3 = vld1q_f32(input_tmp + w_in + 4); - in5 = vld1q_f32(input_tmp + 2 * w_in + 4); + for (int b = 0; b < batch_size; b++) { +#pragma omp parallel for + for (int c = 0; c < input_channel; c++) { + const float 
*input_data = input->data() + c * hxw; + float *output_data = output->data() + c * hxw; + + for (int i = 1; i < output_height - 1; i++) { + float *output_ptr; + float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3, tmp4, + tmp5, out0; + for (int m = 1; m < output_width - 4; m += 4) { + output_ptr = output_data + i * output_width + m; + in0 = vld1q_f32(input_data + (i - 1) * input_width + m - 1); + in1 = vld1q_f32(input_data + (i - 1) * input_width + m + 3); + in2 = vld1q_f32(input_data + i * input_width + m - 1); + in3 = vld1q_f32(input_data + i * input_width + m + 3); + in4 = vld1q_f32(input_data + (i + 1) * input_width + m - 1); + in5 = vld1q_f32(input_data + (i + 1) * input_width + m + 3); tmp0 = vextq_f32(in0, in1, 1); tmp1 = vextq_f32(in0, in1, 2); @@ -219,63 +76,383 @@ void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) { tmp4 = vextq_f32(in4, in5, 1); tmp5 = vextq_f32(in4, in5, 2); - sum = vaddq_f32(in0, tmp0); - sum = vaddq_f32(sum, tmp1); - sum = vaddq_f32(sum, in2); - sum = vaddq_f32(sum, tmp2); - sum = vaddq_f32(sum, tmp3); - sum = vaddq_f32(sum, in4); - sum = vaddq_f32(sum, tmp4); - sum = vaddq_f32(sum, tmp5); - - out0 = vmulq_f32(sum, v_coef); - vst1q_f32(output_ptr, out0); - output_ptr += 4; - input_tmp += 4; - in0 = in1; - in2 = in3; - in4 = in5; + out0 = in0; + out0 = vaddq_f32(out0, tmp0); + out0 = vaddq_f32(out0, tmp1); + out0 = vaddq_f32(out0, in2); + out0 = vaddq_f32(out0, tmp2); + out0 = vaddq_f32(out0, tmp3); + out0 = vaddq_f32(out0, in4); + out0 = vaddq_f32(out0, tmp4); + out0 = vaddq_f32(out0, tmp5); + + vst1q_f32(output_ptr, vmulq_f32(out0, v_coef)); + } + int m; + for (m = 1; (m + 3) < output_width - 1; m = m + 4) { } - // mid remain - float32x4_t pad0 = vdupq_n_f32(input_seg[(j + 1) * w_in - 1]); - float32x4_t pad1 = vdupq_n_f32(input_seg[(j + 2) * w_in - 1]); - float32x4_t pad2 = vdupq_n_f32(input_seg[(j + 2) * w_in - 1]); - tmp0 = vextq_f32(in0, pad0, 1); - tmp1 = vextq_f32(in0, pad0, 2); - tmp2 = vextq_f32(in2, pad1, 1); - tmp3 = vextq_f32(in2, pad1, 2); - tmp4 = vextq_f32(in4, pad2, 1); - tmp5 = vextq_f32(in4, pad2, 2); + for (int j = m; j < output_width - 1; j++) { + output_data[i * output_width + j] = + input_data[(i - 1) * input_width + j - 1] + + input_data[(i - 1) * input_width + j] + + input_data[(i - 1) * input_width + j + 1] + + input_data[(i)*input_width + j - 1] + + input_data[(i)*input_width + j] + + input_data[(i)*input_width + j + 1] + + input_data[(i + 1) * input_width + j - 1] + + input_data[(i + 1) * input_width + j] + + input_data[(i + 1) * input_width + j + 1]; + output_data[i * output_width + j] = + output_data[i * output_width + j] * coef; + } + } - sum = vaddq_f32(in0, tmp0); - sum = vaddq_f32(sum, tmp1); - sum = vaddq_f32(sum, in2); - sum = vaddq_f32(sum, tmp2); - sum = vaddq_f32(sum, tmp3); - sum = vaddq_f32(sum, in4); - sum = vaddq_f32(sum, tmp4); - sum = vaddq_f32(sum, tmp5); - out0 = vmulq_f32(sum, v_coef); + output_data[0] = + input_data[0] + input_data[1] + input_data[l] + input_data[l + 1]; + output_data[l - 1] = input_data[l - 2] + input_data[l - 1] + + input_data[2 * l - 2] + input_data[2 * l - 1]; + output_data[(l - 1) * l] = + input_data[(l - 2) * l] + input_data[(l - 2) * l + 1] + + input_data[(l - 1) * l] + input_data[(l - 1) * l + 1]; + output_data[l * l - 1] = input_data[(l - 2) * (l + 1)] + + input_data[(l - 2) * (l + 1) + 1] + + input_data[l * l - 2] + input_data[l * l - 1]; + output_data[0] = output_data[0] * coef2; + output_data[l - 1] = output_data[l - 1] * coef2; + output_data[(l - 1) * l] = 
output_data[(l - 1) * l] * coef2; + output_data[l * l - 1] = output_data[l * l - 1] * coef2; + + for (int i = 1; i < l - 1; ++i) { + output_data[i * l] = input_data[i * l - l] + input_data[i * l - l + 1] + + input_data[i * l] + input_data[i * l + 1] + + input_data[i * l + l] + input_data[i * l + l + 1]; + + output_data[i * l + l - 1] = + input_data[i * l + l - 1 - l - 1] + input_data[i * l + l - 1 - l] + + input_data[i * l + l - 1 - 1] + input_data[i * l + l - 1] + + input_data[i * l + l - 1 + l - 1] + input_data[i * l + l - 1 + l]; + output_data[i * l] = output_data[i * l] * coef1; + output_data[i * l + l - 1] = output_data[i * l + l - 1] * coef1; + } - for (int i = 0; i < c_mid; ++i) { - if (i == 0) { - vst1q_lane_f32(output_ptr + i, out0, 0); - } - if (i == 1) { - vst1q_lane_f32(output_ptr + i, out0, 1); - } - if (i == 2) { - vst1q_lane_f32(output_ptr + i, out0, 2); - } - } + int m; + for (m = 1; m < output_width - 4; m += 4) { + float *output_ptr = output_data + m; + float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0; + in0 = vld1q_f32(input_data + m - 1); + in1 = vld1q_f32(input_data + m + 3); + in2 = vld1q_f32(input_data + input_width + m - 1); + in3 = vld1q_f32(input_data + input_width + m + 3); + tmp0 = vextq_f32(in0, in1, 1); + tmp1 = vextq_f32(in0, in1, 2); + tmp2 = vextq_f32(in2, in3, 1); + tmp3 = vextq_f32(in2, in3, 2); + out0 = in0; + out0 = vaddq_f32(out0, tmp0); + out0 = vaddq_f32(out0, tmp1); + out0 = vaddq_f32(out0, in2); + out0 = vaddq_f32(out0, tmp2); + out0 = vaddq_f32(out0, tmp3); + + vst1q_f32(output_ptr, vmulq_f32(out0, v_coef1)); + } + + for (m = 1; (m + 3) < output_width - 1; m += 4) { + } + for (int j = m; j < output_width - 1; j++) { + output_data[j] = input_data[j - 1] + input_data[j] + input_data[j + 1] + + input_data[input_width + j - 1] + + input_data[input_width + j] + + input_data[input_width + j + 1]; + output_data[j] = output_data[j] * coef1; + } + + for (m = 1; m < output_width - 4; m += 4) { + float *output_ptr = + output_data + (output_height - 1) * output_width + m; + + float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0; + in0 = vld1q_f32(input_data + (output_height - 2) * input_width + m - 1); + in1 = vld1q_f32(input_data + (output_height - 2) * input_width + m + 3); + in2 = vld1q_f32(input_data + (output_height - 1) * input_width + m - 1); + in3 = vld1q_f32(input_data + (output_height - 1) * input_width + m + 3); + tmp0 = vextq_f32(in0, in1, 1); + tmp1 = vextq_f32(in0, in1, 2); + tmp2 = vextq_f32(in2, in3, 1); + tmp3 = vextq_f32(in2, in3, 2); + out0 = in0; + out0 = vaddq_f32(out0, tmp0); + out0 = vaddq_f32(out0, tmp1); + out0 = vaddq_f32(out0, in2); + out0 = vaddq_f32(out0, tmp2); + out0 = vaddq_f32(out0, tmp3); + + vst1q_f32(output_ptr, vmulq_f32(out0, v_coef1)); + } + for (m = 1; (m + 3) < output_width - 1; m = m + 4) { + } + for (int j = m; j < output_width - 1; j++) { + output_data[(output_height - 1) * input_width + j] = + input_data[(output_height - 2) * input_width + j - 1] + + input_data[(output_height - 2) * input_width + j] + + input_data[(output_height - 2) * input_width + j + 1] + + input_data[(output_height - 1) * input_width + j - 1] + + input_data[(output_height - 1) * input_width + j] + + input_data[(output_height - 1) * input_width + j + 1]; + output_data[(output_height - 1) * output_width + j] = + output_data[(output_height - 1) * output_width + j] * coef1; } - // input_data += inputdata_channel_stride; - // out_data += outputdata_channel_stride; } - input_data += input_batch_stride; - out_data += 
output_batch_stride; } + +// const int batch_size = input->dims()[0]; +// +// const int h_in = input->dims()[2]; +// +// const int w_in = input->dims()[3]; +// +// const int output_channels = output->dims()[1]; +// +// const int h_out = output->dims()[2]; +// const int w_out = output->dims()[3]; +// const int outputdata_channel_stride = h_out * w_out; +// const int inputdata_channel_stride = h_in * w_in; +// const int input_batch_stride = output_channels * inputdata_channel_stride; +// const int output_batch_stride = output_channels * +// outputdata_channel_stride; float *out_data = output->data(); const +// float *input_data = input->data(); +// +// const float coef = 1.0 / 9.0; +// for (int k = 0; k < batch_size; ++k) { +//#pragma omp parallel for +// for (int c = 0; c < output_channels; ++c) { +// const float *input_seg = input_data + c * inputdata_channel_stride; +// float *output_seg = out_data + c * outputdata_channel_stride; +// // four corner point +// output_seg[0] = (input_seg[0] + input_seg[1] + input_seg[w_in] + +// input_seg[w_in + 1]) * +// coef; +// output_seg[w_out - 1] = +// (input_seg[w_in - 2] + input_seg[w_in - 1] + input_seg[w_in * 2 - +// 2] + +// input_seg[2 * w_in - 1]) * +// coef; +// output_seg[(h_out - 1) * w_out] = +// (input_seg[(h_in - 2) * w_in] + input_seg[(h_in - 2) * w_in + 1] + +// input_seg[(h_in - 1) * w_in] + input_seg[(h_in - 1) * w_in + 1]) +// * +// coef; +// output_seg[h_out * w_out - 1] = +// (input_seg[h_in * w_in - 1] + input_seg[h_in * w_in - 2] + +// input_seg[(h_in - 1) * w_in - 1] + +// input_seg[(h_in - 1) * w_in - 2]) * +// coef; +// // left side & right side +// for (int i = 1; i < h_in - 1; ++i) { +// output_seg[i * w_out] = +// (input_seg[i * w_in - w_in] + input_seg[i * w_in - w_in + 1] + +// input_seg[i * w_in] + input_seg[i * w_in + 1] + +// input_seg[i * w_in + w_in] + input_seg[i * w_in + w_in + 1]) * +// coef; +// output_seg[i * w_out + w_out - 1] = +// (input_seg[i * w_in - w_in + w_in - 2] + +// input_seg[i * w_in - w_in + 1 + w_in - 2] + +// input_seg[i * w_in + w_in - 2] + +// input_seg[i * w_in + 1 + w_in - 2] + +// input_seg[i * w_in + w_in + w_in - 2] + +// input_seg[i * w_in + w_in + 1 + w_in - 2]) * +// coef; +// } +// // top 1 row & bottom 1 row +// const float *input_tmp = input_seg; +// +// float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2, +// tmp3, tmp4, tmp5, sum, out0; +// float32x4_t v_coef = vdupq_n_f32(coef); +// in0 = vld1q_f32(input_tmp); +// in2 = vld1q_f32(input_tmp + w_in); +// const float *input_tmp_end = input_tmp + (h_in - 2) * w_in; +// in4 = vld1q_f32(input_tmp_end); +// in6 = vld1q_f32(input_tmp_end + w_in); +// int c_mid = w_out - 2; +// auto output_ptr = output_seg + 1; +// for (; c_mid > 3; c_mid -= 4) { +// in1 = vld1q_f32(input_tmp + 4); +// in3 = vld1q_f32(input_tmp + w_in + 4); +// +// tmp0 = vextq_f32(in0, in1, 1); +// tmp1 = vextq_f32(in0, in1, 2); +// +// tmp2 = vextq_f32(in2, in3, 1); +// tmp3 = vextq_f32(in2, in3, 2); +// +// sum = vaddq_f32(in0, tmp0); +// sum = vaddq_f32(sum, tmp1); +// sum = vaddq_f32(sum, in2); +// sum = vaddq_f32(sum, tmp2); +// sum = vaddq_f32(sum, tmp3); +// +// vst1q_f32(output_ptr, vmulq_f32(sum, v_coef)); +// +// in5 = vld1q_f32(input_tmp_end + 4); +// in7 = vld1q_f32(input_tmp_end + w_in + 4); +// +// tmp0 = vextq_f32(in4, in5, 1); +// tmp1 = vextq_f32(in4, in5, 2); +// tmp2 = vextq_f32(in6, in7, 1); +// tmp3 = vextq_f32(in6, in7, 2); +// +// sum = vaddq_f32(in0, tmp0); +// sum = vaddq_f32(sum, tmp1); +// sum = vaddq_f32(sum, in2); +// sum = 
vaddq_f32(sum, tmp2); +// sum = vaddq_f32(sum, tmp3); +// +// vst1q_f32(output_ptr + (h_out - 1) * w_out, vmulq_f32(sum, v_coef)); +// +// // can optimize to each 8 stride. +// input_tmp += 4; +// input_tmp_end += 4; +// output_ptr += 4; +// in0 = in1; +// in2 = in3; +// in4 = in5; +// in6 = in7; +// } +// // top right remain +// float32x4_t pad0 = vdupq_n_f32(input_seg[w_in - 1]); +// float32x4_t pad1 = vdupq_n_f32(input_seg[2 * w_in - 1]); +// +// tmp0 = vextq_f32(in0, pad0, 1); +// tmp1 = vextq_f32(in0, pad0, 2); +// tmp2 = vextq_f32(in2, pad1, 2); +// tmp3 = vextq_f32(in2, pad1, 2); +// +// sum = vaddq_f32(in0, tmp0); +// sum = vaddq_f32(sum, tmp1); +// sum = vaddq_f32(sum, in2); +// sum = vaddq_f32(sum, tmp2); +// sum = vaddq_f32(sum, tmp3); +// out0 = vmulq_f32(sum, v_coef); +// +// for (int i = 0; i < c_mid; ++i) { +// if (i == 0) { +// vst1q_lane_f32(output_ptr + i, out0, 0); +// } +// if (i == 1) { +// vst1q_lane_f32(output_ptr + i, out0, 1); +// } +// if (i == 2) { +// vst1q_lane_f32(output_ptr + i, out0, 2); +// } +// } +// +// // bottom_right remain +// float32x4_t pad2 = vdupq_n_f32(input_seg[(h_in - 1) * w_in - 1]); +// float32x4_t pad3 = vdupq_n_f32(input_seg[h_in * w_in - 1]); +// +// tmp0 = vextq_f32(in4, pad2, 1); +// tmp1 = vextq_f32(in4, pad2, 2); +// tmp2 = vextq_f32(in6, pad3, 2); +// tmp3 = vextq_f32(in6, pad3, 2); +// +// sum = vaddq_f32(in4, tmp0); +// sum = vaddq_f32(sum, tmp1); +// sum = vaddq_f32(sum, in6); +// sum = vaddq_f32(sum, tmp2); +// sum = vaddq_f32(sum, tmp3); +// out0 = vmulq_f32(sum, v_coef); +// +// for (int i = 0; i < c_mid; ++i) { +// if (i == 0) { +// vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 0); +// } +// if (i == 1) { +// vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 1); +// } +// if (i == 2) { +// vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 2); +// } +// } +// // mid +// for (int j = 0; j < h_out - 2; ++j) { +// output_ptr = output_seg + w_out * (j + 1) + 1; +// input_tmp = input_seg + j * w_in; +// +// in0 = vld1q_f32(input_tmp); +// in2 = vld1q_f32(input_tmp + w_in); +// in4 = vld1q_f32(input_tmp + 2 * w_in); +// c_mid = w_out - 2; +// for (; c_mid > 3; c_mid -= 4) { +// in1 = vld1q_f32(input_tmp + 4); +// in3 = vld1q_f32(input_tmp + w_in + 4); +// in5 = vld1q_f32(input_tmp + 2 * w_in + 4); +// +// tmp0 = vextq_f32(in0, in1, 1); +// tmp1 = vextq_f32(in0, in1, 2); +// tmp2 = vextq_f32(in2, in3, 1); +// tmp3 = vextq_f32(in2, in3, 2); +// tmp4 = vextq_f32(in4, in5, 1); +// tmp5 = vextq_f32(in4, in5, 2); +// +// sum = vaddq_f32(in0, tmp0); +// sum = vaddq_f32(sum, tmp1); +// sum = vaddq_f32(sum, in2); +// sum = vaddq_f32(sum, tmp2); +// sum = vaddq_f32(sum, tmp3); +// sum = vaddq_f32(sum, in4); +// sum = vaddq_f32(sum, tmp4); +// sum = vaddq_f32(sum, tmp5); +// +// out0 = vmulq_f32(sum, v_coef); +// vst1q_f32(output_ptr, out0); +// output_ptr += 4; +// input_tmp += 4; +// in0 = in1; +// in2 = in3; +// in4 = in5; +// } +// // mid remain +// float32x4_t pad0 = vdupq_n_f32(input_seg[(j + 1) * w_in - 1]); +// float32x4_t pad1 = vdupq_n_f32(input_seg[(j + 2) * w_in - 1]); +// float32x4_t pad2 = vdupq_n_f32(input_seg[(j + 2) * w_in - 1]); +// +// tmp0 = vextq_f32(in0, pad0, 1); +// tmp1 = vextq_f32(in0, pad0, 2); +// tmp2 = vextq_f32(in2, pad1, 1); +// tmp3 = vextq_f32(in2, pad1, 2); +// tmp4 = vextq_f32(in4, pad2, 1); +// tmp5 = vextq_f32(in4, pad2, 2); +// +// sum = vaddq_f32(in0, tmp0); +// sum = vaddq_f32(sum, tmp1); +// sum = vaddq_f32(sum, in2); +// sum = vaddq_f32(sum, tmp2); +// sum = vaddq_f32(sum, 
tmp3); +// sum = vaddq_f32(sum, in4); +// sum = vaddq_f32(sum, tmp4); +// sum = vaddq_f32(sum, tmp5); +// out0 = vmulq_f32(sum, v_coef); +// +// for (int i = 0; i < c_mid; ++i) { +// if (i == 0) { +// vst1q_lane_f32(output_ptr + i, out0, 0); +// } +// if (i == 1) { +// vst1q_lane_f32(output_ptr + i, out0, 1); +// } +// if (i == 2) { +// vst1q_lane_f32(output_ptr + i, out0, 2); +// } +// } +// } +// // input_data += inputdata_channel_stride; +// // out_data += outputdata_channel_stride; +// } +// input_data += input_batch_stride; +// out_data += output_batch_stride; +// } #endif } @@ -662,6 +839,7 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input, wstart = max(wstart, 0); hend = min(hend, input_height); wend = min(wend, input_width); + const float *pos1 = input_seg + hstart * input_width + wstart; const float *pos2 = input_seg + (hstart + 1) * input_width + wstart; const float *pos3 = input_seg + (hstart + 2) * input_width + wstart; @@ -674,7 +852,8 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input, sum += input_seg[h * input_width + w]; } } - output_seg[ph * output_width + pw] = sum / 9.0; + output_seg[ph * output_width + pw] = + sum / ((hend - hstart) * (wend - wstart) * 1.0); } else { #if __aarch64__ #else
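Note (reviewer addition, not part of the patch): both the rewritten Pool3x3Avgs1p1 kernel and the Pool3x3Avg divisor change average each output over only the in-bounds elements of its 3x3 window, which is where the 1/9 (interior), 1/6 (edge) and 1/4 (corner) coefficients come from. A minimal scalar sketch of that boundary handling, assuming stride 1, pad 1 and a same-size single-channel plane; the helper name and free-function form are illustrative only:

#include <algorithm>

// Illustrative scalar reference (hypothetical helper, not in the patch):
// 3x3 average pooling, stride 1, pad 1, same-size output, dividing each
// window sum by the number of in-bounds elements (9, 6 or 4).
void AvgPool3x3S1P1Reference(const float *in, float *out, int height, int width) {
  for (int oh = 0; oh < height; ++oh) {
    for (int ow = 0; ow < width; ++ow) {
      const int hstart = std::max(oh - 1, 0);
      const int hend = std::min(oh + 2, height);
      const int wstart = std::max(ow - 1, 0);
      const int wend = std::min(ow + 2, width);
      float sum = 0.f;
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          sum += in[h * width + w];
        }
      }
      // Same divisor as the Pool3x3Avg change above: the actual window area.
      out[oh * width + ow] = sum / ((hend - hstart) * (wend - wstart));
    }
  }
}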