Commit 58a9e04b authored by zhangyang0701, committed by GitHub

Merge pull request #952 from yangfei963158659/develop

repair bug of pool3x3
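The bug being repaired: 3x3 average pooling divided every window sum by a fixed 9, even at the borders, where a padded window covers fewer than nine valid inputs. The rewrite scales each sum by the true coverage (coef2 = 1/4 at the corners, coef1 = 1/6 along the edges, coef = 1/9 in the interior), and the generic-stride path now divides by the actual window area. Below is a minimal scalar sketch of the corrected rule, for illustration only; AvgPool3x3Ref is a hypothetical reference, not code from this patch, and assumes stride 1, padding 1, and a single channel plane.

#include <algorithm>

// Hypothetical reference, not the committed NEON kernel: 3x3 average
// pooling with stride 1 and padding 1, dividing each sum by the number
// of input elements the window actually covers.
void AvgPool3x3Ref(const float *in, float *out, int height, int width) {
  for (int oh = 0; oh < height; ++oh) {
    for (int ow = 0; ow < width; ++ow) {
      const int hstart = std::max(oh - 1, 0);
      const int hend = std::min(oh + 2, height);
      const int wstart = std::max(ow - 1, 0);
      const int wend = std::min(ow + 2, width);
      float sum = 0.f;
      for (int h = hstart; h < hend; ++h) {
        for (int w = wstart; w < wend; ++w) {
          sum += in[h * width + w];
        }
      }
      // 4 elements at a corner, 6 on an edge, 9 in the interior.
      out[oh * width + ow] = sum / ((hend - hstart) * (wend - wstart));
    }
  }
}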
...
@@ -31,251 +31,428 @@ using std::min;
using std::vector;
void Pool3x3Avgs1p1(const Tensor *input, Tensor *output) {
#if __ARM_NEON
  const int batch_size = static_cast<int>(input->dims()[0]);
  const int input_channel = static_cast<int>(input->dims()[1]);
  const int input_height = static_cast<int>(input->dims()[2]);
  const int input_width = static_cast<int>(input->dims()[3]);
  const int output_height = static_cast<int>(output->dims()[2]);
  const int output_width = static_cast<int>(output->dims()[3]);
  const int hxw = input_height * input_width;
  const int l = input_height;

  // Scale factors: 1/9 for full 3x3 windows, 1/6 for edge windows
  // (six valid elements), 1/4 for corner windows (four valid elements).
  const float coef = 1.0 / 9.0;
  const float coef1 = 1.0 / 6.0;
  const float coef2 = 1.0 / 4.0;

  float32x4_t v_coef = vdupq_n_f32(coef);
  float32x4_t v_coef1 = vdupq_n_f32(coef1);

  for (int b = 0; b < batch_size; b++) {
#pragma omp parallel for
    for (int c = 0; c < input_channel; c++) {
      const float *input_data = input->data<float>() + c * hxw;
      float *output_data = output->data<float>() + c * hxw;

      // interior rows: every 3x3 window lies fully inside the input
      for (int i = 1; i < output_height - 1; i++) {
        float *output_ptr;
        float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3,
            tmp4, tmp5, out0;
        for (int m = 1; m < output_width - 4; m += 4) {
          output_ptr = output_data + i * output_width + m;
          in0 = vld1q_f32(input_data + (i - 1) * input_width + m - 1);
          in1 = vld1q_f32(input_data + (i - 1) * input_width + m + 3);
          in2 = vld1q_f32(input_data + i * input_width + m - 1);
          in3 = vld1q_f32(input_data + i * input_width + m + 3);
          in4 = vld1q_f32(input_data + (i + 1) * input_width + m - 1);
          in5 = vld1q_f32(input_data + (i + 1) * input_width + m + 3);

          tmp0 = vextq_f32(in0, in1, 1);
          tmp1 = vextq_f32(in0, in1, 2);
          tmp2 = vextq_f32(in2, in3, 1);
          tmp3 = vextq_f32(in2, in3, 2);
          tmp4 = vextq_f32(in4, in5, 1);
          tmp5 = vextq_f32(in4, in5, 2);

          out0 = in0;
          out0 = vaddq_f32(out0, tmp0);
          out0 = vaddq_f32(out0, tmp1);
          out0 = vaddq_f32(out0, in2);
          out0 = vaddq_f32(out0, tmp2);
          out0 = vaddq_f32(out0, tmp3);
          out0 = vaddq_f32(out0, in4);
          out0 = vaddq_f32(out0, tmp4);
          out0 = vaddq_f32(out0, tmp5);

          vst1q_f32(output_ptr, vmulq_f32(out0, v_coef));
        }
        // the empty loop recomputes where the vectorized loop stopped;
        // the scalar loop then finishes the remaining columns
        int m;
        for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
        }
        for (int j = m; j < output_width - 1; j++) {
          output_data[i * output_width + j] =
              input_data[(i - 1) * input_width + j - 1] +
              input_data[(i - 1) * input_width + j] +
              input_data[(i - 1) * input_width + j + 1] +
              input_data[(i)*input_width + j - 1] +
              input_data[(i)*input_width + j] +
              input_data[(i)*input_width + j + 1] +
              input_data[(i + 1) * input_width + j - 1] +
              input_data[(i + 1) * input_width + j] +
              input_data[(i + 1) * input_width + j + 1];
          output_data[i * output_width + j] =
              output_data[i * output_width + j] * coef;
        }
      }

      // four corner points: each window covers four valid elements
      output_data[0] =
          input_data[0] + input_data[1] + input_data[l] + input_data[l + 1];
      output_data[l - 1] = input_data[l - 2] + input_data[l - 1] +
                           input_data[2 * l - 2] + input_data[2 * l - 1];
      output_data[(l - 1) * l] =
          input_data[(l - 2) * l] + input_data[(l - 2) * l + 1] +
          input_data[(l - 1) * l] + input_data[(l - 1) * l + 1];
      output_data[l * l - 1] = input_data[(l - 2) * (l + 1)] +
                               input_data[(l - 2) * (l + 1) + 1] +
                               input_data[l * l - 2] + input_data[l * l - 1];
      output_data[0] = output_data[0] * coef2;
      output_data[l - 1] = output_data[l - 1] * coef2;
      output_data[(l - 1) * l] = output_data[(l - 1) * l] * coef2;
      output_data[l * l - 1] = output_data[l * l - 1] * coef2;

      // left & right columns (excluding corners): six valid elements
      for (int i = 1; i < l - 1; ++i) {
        output_data[i * l] = input_data[i * l - l] + input_data[i * l - l + 1] +
                             input_data[i * l] + input_data[i * l + 1] +
                             input_data[i * l + l] + input_data[i * l + l + 1];
        output_data[i * l + l - 1] =
            input_data[i * l + l - 1 - l - 1] + input_data[i * l + l - 1 - l] +
            input_data[i * l + l - 1 - 1] + input_data[i * l + l - 1] +
            input_data[i * l + l - 1 + l - 1] + input_data[i * l + l - 1 + l];
        output_data[i * l] = output_data[i * l] * coef1;
        output_data[i * l + l - 1] = output_data[i * l + l - 1] * coef1;
      }

      // top row (excluding corners): six valid elements
      int m;
      for (m = 1; m < output_width - 4; m += 4) {
        float *output_ptr = output_data + m;
        float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
        in0 = vld1q_f32(input_data + m - 1);
        in1 = vld1q_f32(input_data + m + 3);
        in2 = vld1q_f32(input_data + input_width + m - 1);
        in3 = vld1q_f32(input_data + input_width + m + 3);
        tmp0 = vextq_f32(in0, in1, 1);
        tmp1 = vextq_f32(in0, in1, 2);
        tmp2 = vextq_f32(in2, in3, 1);
        tmp3 = vextq_f32(in2, in3, 2);
        out0 = in0;
        out0 = vaddq_f32(out0, tmp0);
        out0 = vaddq_f32(out0, tmp1);
        out0 = vaddq_f32(out0, in2);
        out0 = vaddq_f32(out0, tmp2);
        out0 = vaddq_f32(out0, tmp3);

        vst1q_f32(output_ptr, vmulq_f32(out0, v_coef1));
      }
      for (m = 1; (m + 3) < output_width - 1; m += 4) {
      }
      for (int j = m; j < output_width - 1; j++) {
        output_data[j] = input_data[j - 1] + input_data[j] + input_data[j + 1] +
                         input_data[input_width + j - 1] +
                         input_data[input_width + j] +
                         input_data[input_width + j + 1];
        output_data[j] = output_data[j] * coef1;
      }

      // bottom row (excluding corners): six valid elements
      for (m = 1; m < output_width - 4; m += 4) {
        float *output_ptr =
            output_data + (output_height - 1) * output_width + m;
        float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
        in0 = vld1q_f32(input_data + (output_height - 2) * input_width + m - 1);
        in1 = vld1q_f32(input_data + (output_height - 2) * input_width + m + 3);
        in2 = vld1q_f32(input_data + (output_height - 1) * input_width + m - 1);
        in3 = vld1q_f32(input_data + (output_height - 1) * input_width + m + 3);
        tmp0 = vextq_f32(in0, in1, 1);
        tmp1 = vextq_f32(in0, in1, 2);
        tmp2 = vextq_f32(in2, in3, 1);
        tmp3 = vextq_f32(in2, in3, 2);
        out0 = in0;
        out0 = vaddq_f32(out0, tmp0);
        out0 = vaddq_f32(out0, tmp1);
        out0 = vaddq_f32(out0, in2);
        out0 = vaddq_f32(out0, tmp2);
        out0 = vaddq_f32(out0, tmp3);

        vst1q_f32(output_ptr, vmulq_f32(out0, v_coef1));
      }
      for (m = 1; (m + 3) < output_width - 1; m = m + 4) {
      }
      for (int j = m; j < output_width - 1; j++) {
        output_data[(output_height - 1) * input_width + j] =
            input_data[(output_height - 2) * input_width + j - 1] +
            input_data[(output_height - 2) * input_width + j] +
            input_data[(output_height - 2) * input_width + j + 1] +
            input_data[(output_height - 1) * input_width + j - 1] +
            input_data[(output_height - 1) * input_width + j] +
            input_data[(output_height - 1) * input_width + j + 1];
        output_data[(output_height - 1) * output_width + j] =
            output_data[(output_height - 1) * output_width + j] * coef1;
      }
    }
  }
// The previous implementation is preserved below, commented out:
// const int batch_size = input->dims()[0];
//
// const int h_in = input->dims()[2];
//
// const int w_in = input->dims()[3];
//
// const int output_channels = output->dims()[1];
//
// const int h_out = output->dims()[2];
// const int w_out = output->dims()[3];
// const int outputdata_channel_stride = h_out * w_out;
// const int inputdata_channel_stride = h_in * w_in;
// const int input_batch_stride = output_channels * inputdata_channel_stride;
// const int output_batch_stride = output_channels *
// outputdata_channel_stride; float *out_data = output->data<float>(); const
// float *input_data = input->data<float>();
//
// const float coef = 1.0 / 9.0;
// for (int k = 0; k < batch_size; ++k) {
//#pragma omp parallel for
// for (int c = 0; c < output_channels; ++c) {
// const float *input_seg = input_data + c * inputdata_channel_stride;
// float *output_seg = out_data + c * outputdata_channel_stride;
// // four corner point
// output_seg[0] = (input_seg[0] + input_seg[1] + input_seg[w_in] +
// input_seg[w_in + 1]) *
// coef;
// output_seg[w_out - 1] =
// (input_seg[w_in - 2] + input_seg[w_in - 1] + input_seg[w_in * 2 -
// 2] +
// input_seg[2 * w_in - 1]) *
// coef;
// output_seg[(h_out - 1) * w_out] =
// (input_seg[(h_in - 2) * w_in] + input_seg[(h_in - 2) * w_in + 1] +
// input_seg[(h_in - 1) * w_in] + input_seg[(h_in - 1) * w_in + 1])
// *
// coef;
// output_seg[h_out * w_out - 1] =
// (input_seg[h_in * w_in - 1] + input_seg[h_in * w_in - 2] +
// input_seg[(h_in - 1) * w_in - 1] +
// input_seg[(h_in - 1) * w_in - 2]) *
// coef;
// // left side & right side
// for (int i = 1; i < h_in - 1; ++i) {
// output_seg[i * w_out] =
// (input_seg[i * w_in - w_in] + input_seg[i * w_in - w_in + 1] +
// input_seg[i * w_in] + input_seg[i * w_in + 1] +
// input_seg[i * w_in + w_in] + input_seg[i * w_in + w_in + 1]) *
// coef;
// output_seg[i * w_out + w_out - 1] =
// (input_seg[i * w_in - w_in + w_in - 2] +
// input_seg[i * w_in - w_in + 1 + w_in - 2] +
// input_seg[i * w_in + w_in - 2] +
// input_seg[i * w_in + 1 + w_in - 2] +
// input_seg[i * w_in + w_in + w_in - 2] +
// input_seg[i * w_in + w_in + 1 + w_in - 2]) *
// coef;
// }
// // top 1 row & bottom 1 row
// const float *input_tmp = input_seg;
//
// float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2,
// tmp3, tmp4, tmp5, sum, out0;
// float32x4_t v_coef = vdupq_n_f32(coef);
// in0 = vld1q_f32(input_tmp);
// in2 = vld1q_f32(input_tmp + w_in);
// const float *input_tmp_end = input_tmp + (h_in - 2) * w_in;
// in4 = vld1q_f32(input_tmp_end);
// in6 = vld1q_f32(input_tmp_end + w_in);
// int c_mid = w_out - 2;
// auto output_ptr = output_seg + 1;
// for (; c_mid > 3; c_mid -= 4) {
// in1 = vld1q_f32(input_tmp + 4);
// in3 = vld1q_f32(input_tmp + w_in + 4);
//
// tmp0 = vextq_f32(in0, in1, 1);
// tmp1 = vextq_f32(in0, in1, 2);
//
// tmp2 = vextq_f32(in2, in3, 1);
// tmp3 = vextq_f32(in2, in3, 2);
//
// sum = vaddq_f32(in0, tmp0);
// sum = vaddq_f32(sum, tmp1);
// sum = vaddq_f32(sum, in2);
// sum = vaddq_f32(sum, tmp2);
// sum = vaddq_f32(sum, tmp3);
//
// vst1q_f32(output_ptr, vmulq_f32(sum, v_coef));
//
// in5 = vld1q_f32(input_tmp_end + 4);
// in7 = vld1q_f32(input_tmp_end + w_in + 4);
//
// tmp0 = vextq_f32(in4, in5, 1);
// tmp1 = vextq_f32(in4, in5, 2);
// tmp2 = vextq_f32(in6, in7, 1);
// tmp3 = vextq_f32(in6, in7, 2);
//
// sum = vaddq_f32(in0, tmp0);
// sum = vaddq_f32(sum, tmp1);
// sum = vaddq_f32(sum, in2);
// sum = vaddq_f32(sum, tmp2);
// sum = vaddq_f32(sum, tmp3);
//
// vst1q_f32(output_ptr + (h_out - 1) * w_out, vmulq_f32(sum, v_coef));
//
// // can optimize to each 8 stride.
// input_tmp += 4;
// input_tmp_end += 4;
// output_ptr += 4;
// in0 = in1;
// in2 = in3;
// in4 = in5;
// in6 = in7;
// }
// // top right remain
// float32x4_t pad0 = vdupq_n_f32(input_seg[w_in - 1]);
// float32x4_t pad1 = vdupq_n_f32(input_seg[2 * w_in - 1]);
//
// tmp0 = vextq_f32(in0, pad0, 1);
// tmp1 = vextq_f32(in0, pad0, 2);
// tmp2 = vextq_f32(in2, pad1, 2);
// tmp3 = vextq_f32(in2, pad1, 2);
//
// sum = vaddq_f32(in0, tmp0);
// sum = vaddq_f32(sum, tmp1);
// sum = vaddq_f32(sum, in2);
// sum = vaddq_f32(sum, tmp2);
// sum = vaddq_f32(sum, tmp3);
// out0 = vmulq_f32(sum, v_coef);
//
// for (int i = 0; i < c_mid; ++i) {
// if (i == 0) {
// vst1q_lane_f32(output_ptr + i, out0, 0);
// }
// if (i == 1) {
// vst1q_lane_f32(output_ptr + i, out0, 1);
// }
// if (i == 2) {
// vst1q_lane_f32(output_ptr + i, out0, 2);
// }
// }
//
// // bottom_right remain
// float32x4_t pad2 = vdupq_n_f32(input_seg[(h_in - 1) * w_in - 1]);
// float32x4_t pad3 = vdupq_n_f32(input_seg[h_in * w_in - 1]);
//
// tmp0 = vextq_f32(in4, pad2, 1);
// tmp1 = vextq_f32(in4, pad2, 2);
// tmp2 = vextq_f32(in6, pad3, 2);
// tmp3 = vextq_f32(in6, pad3, 2);
//
// sum = vaddq_f32(in4, tmp0);
// sum = vaddq_f32(sum, tmp1);
// sum = vaddq_f32(sum, in6);
// sum = vaddq_f32(sum, tmp2);
// sum = vaddq_f32(sum, tmp3);
// out0 = vmulq_f32(sum, v_coef);
//
// for (int i = 0; i < c_mid; ++i) {
// if (i == 0) {
// vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 0);
// }
// if (i == 1) {
// vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 1);
// }
// if (i == 2) {
// vst1q_lane_f32(output_ptr + (h_out - 1) * w_out + i, out0, 2);
// }
// }
// // mid
// for (int j = 0; j < h_out - 2; ++j) {
// output_ptr = output_seg + w_out * (j + 1) + 1;
// input_tmp = input_seg + j * w_in;
//
// in0 = vld1q_f32(input_tmp);
// in2 = vld1q_f32(input_tmp + w_in);
// in4 = vld1q_f32(input_tmp + 2 * w_in);
// c_mid = w_out - 2;
// for (; c_mid > 3; c_mid -= 4) {
// in1 = vld1q_f32(input_tmp + 4);
// in3 = vld1q_f32(input_tmp + w_in + 4);
// in5 = vld1q_f32(input_tmp + 2 * w_in + 4);
//
// tmp0 = vextq_f32(in0, in1, 1);
// tmp1 = vextq_f32(in0, in1, 2);
// tmp2 = vextq_f32(in2, in3, 1);
// tmp3 = vextq_f32(in2, in3, 2);
// tmp4 = vextq_f32(in4, in5, 1);
// tmp5 = vextq_f32(in4, in5, 2);
//
// sum = vaddq_f32(in0, tmp0);
// sum = vaddq_f32(sum, tmp1);
// sum = vaddq_f32(sum, in2);
// sum = vaddq_f32(sum, tmp2);
// sum = vaddq_f32(sum, tmp3);
// sum = vaddq_f32(sum, in4);
// sum = vaddq_f32(sum, tmp4);
// sum = vaddq_f32(sum, tmp5);
//
// out0 = vmulq_f32(sum, v_coef);
// vst1q_f32(output_ptr, out0);
// output_ptr += 4;
// input_tmp += 4;
// in0 = in1;
// in2 = in3;
// in4 = in5;
// }
// // mid remain
// float32x4_t pad0 = vdupq_n_f32(input_seg[(j + 1) * w_in - 1]);
// float32x4_t pad1 = vdupq_n_f32(input_seg[(j + 2) * w_in - 1]);
// float32x4_t pad2 = vdupq_n_f32(input_seg[(j + 2) * w_in - 1]);
//
// tmp0 = vextq_f32(in0, pad0, 1);
// tmp1 = vextq_f32(in0, pad0, 2);
// tmp2 = vextq_f32(in2, pad1, 1);
// tmp3 = vextq_f32(in2, pad1, 2);
// tmp4 = vextq_f32(in4, pad2, 1);
// tmp5 = vextq_f32(in4, pad2, 2);
//
// sum = vaddq_f32(in0, tmp0);
// sum = vaddq_f32(sum, tmp1);
// sum = vaddq_f32(sum, in2);
// sum = vaddq_f32(sum, tmp2);
// sum = vaddq_f32(sum, tmp3);
// sum = vaddq_f32(sum, in4);
// sum = vaddq_f32(sum, tmp4);
// sum = vaddq_f32(sum, tmp5);
// out0 = vmulq_f32(sum, v_coef);
//
// for (int i = 0; i < c_mid; ++i) {
// if (i == 0) {
// vst1q_lane_f32(output_ptr + i, out0, 0);
// }
// if (i == 1) {
// vst1q_lane_f32(output_ptr + i, out0, 1);
// }
// if (i == 2) {
// vst1q_lane_f32(output_ptr + i, out0, 2);
// }
// }
// }
// // input_data += inputdata_channel_stride;
// // out_data += outputdata_channel_stride;
// }
// input_data += input_batch_stride;
// out_data += output_batch_stride;
// }
#endif
}
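Design note on the vector loops above: each row of a 3x3 window contributes three overlapping taps, and rather than issuing three unaligned loads per row, the kernel loads two adjacent quads and derives the shifted views with vextq_f32. A sketch of the idiom under the same assumption of an ARM NEON target; RowSum3 is a hypothetical helper, not part of the patch:

#include <arm_neon.h>

// Horizontal 3-tap sums for four consecutive outputs:
// result[i] = row[i] + row[i + 1] + row[i + 2].
// Requires row[0..7] to be readable; the kernel guarantees this by
// stopping its vector loops early and finishing with scalar code.
static inline float32x4_t RowSum3(const float *row) {
  float32x4_t a = vld1q_f32(row);       // r0 r1 r2 r3
  float32x4_t b = vld1q_f32(row + 4);   // r4 r5 r6 r7
  float32x4_t s1 = vextq_f32(a, b, 1);  // r1 r2 r3 r4
  float32x4_t s2 = vextq_f32(a, b, 2);  // r2 r3 r4 r5
  return vaddq_f32(vaddq_f32(a, s1), s2);
}

Summing such row sums over the two or three input rows a window covers, then multiplying by the per-region coefficient, reproduces the structure of the loops above.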
...
@@ -662,6 +839,7 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input,
        wstart = max(wstart, 0);
        hend = min(hend, input_height);
        wend = min(wend, input_width);

        const float *pos1 = input_seg + hstart * input_width + wstart;
        const float *pos2 = input_seg + (hstart + 1) * input_width + wstart;
        const float *pos3 = input_seg + (hstart + 2) * input_width + wstart;
...
@@ -674,7 +852,8 @@ void Pool3x3Avg(vector<int> strides, vector<int> paddings, const Tensor *input,
            sum += input_seg[h * input_width + w];
          }
        }
-       output_seg[ph * output_width + pw] = sum / 9.0;
+       output_seg[ph * output_width + pw] =
+           sum / ((hend - hstart) * (wend - wstart) * 1.0);
      } else {
#if __aarch64__
#else
...
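A quick check of the divisor change above: with padding 1, the window for a corner output clamps to hstart = wstart = 0 and hend = wend = 2, so the new divisor is (2 - 0) * (2 - 0) = 4. The removed line divided that same 4-element sum by 9.0, scaling corner outputs to 4/9 of the correct average (and edge outputs to 6/9); this is the pool3x3 bug the pull request repairs.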