提交 2fd6bd46 编写于 作者: S smilejames 提交者: GitHub

Merge pull request #944 from yangfei963158659/develop

Fix multithreading bug in depthwise_conv3x3 (stride 2)
...@@ -1465,180 +1465,187 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, ...@@ -1465,180 +1465,187 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
Tensor *output, const Tensor *new_scale, Tensor *output, const Tensor *new_scale,
const Tensor *new_bias, bool if_relu) { const Tensor *new_bias, bool if_relu) {
#if __ARM_NEON #if __ARM_NEON
#ifdef _OPENMP //#ifdef _OPENMP
const float *newscale_data = new_scale->data<float>(); // const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>(); // const float *newbias_data = new_bias->data<float>();
//
const int batch_size = static_cast<int>(input->dims()[0]); // const int batch_size = static_cast<int>(input->dims()[0]);
const int input_channel = static_cast<int>(input->dims()[1]); // const int input_channel = static_cast<int>(input->dims()[1]);
//
const int input_height = static_cast<int>(input->dims()[2]); // const int input_height = static_cast<int>(input->dims()[2]);
const int input_width = static_cast<int>(input->dims()[3]); // const int input_width = static_cast<int>(input->dims()[3]);
const int output_height = static_cast<int>(output->dims()[2]); // const int output_height = static_cast<int>(output->dims()[2]);
const int output_width = static_cast<int>(output->dims()[3]); // const int output_width = static_cast<int>(output->dims()[3]);
const int inhxw = input_height * input_width; // const int inhxw = input_height * input_width;
const int outhxw = output_height * output_width; // const int outhxw = output_height * output_width;
//
float32x4_t zero = vdupq_n_f32(0.0); // float32x4_t zero = vdupq_n_f32(0.0);
for (int b = 0; b < batch_size; b++) { // for (int b = 0; b < batch_size; b++) {
#pragma omp parallel for // #pragma omp parallel for
for (int c = 0; c < input_channel; c++) { // for (int c = 0; c < input_channel; c++) {
const float *filter_data = filter->data<float>() + c * 9; // const float *filter_data = filter->data<float>() + c * 9;
const float *input_data = input->data<float>() + c * inhxw; // const float *input_data = input->data<float>() + c * inhxw;
float *output_data = output->data<float>() + c * outhxw; // float *output_data = output->data<float>() + c * outhxw;
float32x4_t vnewbias = vdupq_n_f32(newbias_data[c]); // float32x4_t vnewbias = vdupq_n_f32(newbias_data[c]);
float32x4_t vnewscale = vdupq_n_f32(newscale_data[c]); // float32x4_t vnewscale = vdupq_n_f32(newscale_data[c]);
//
float w00 = filter_data[0]; // float w00 = filter_data[0];
float w01 = filter_data[1]; // float w01 = filter_data[1];
float w02 = filter_data[2]; // float w02 = filter_data[2];
float w10 = filter_data[3]; // float w10 = filter_data[3];
float w11 = filter_data[4]; // float w11 = filter_data[4];
float w12 = filter_data[5]; // float w12 = filter_data[5];
float w20 = filter_data[6]; // float w20 = filter_data[6];
float w21 = filter_data[7]; // float w21 = filter_data[7];
float w22 = filter_data[8]; // float w22 = filter_data[8];
//
int m; // int m;
for (m = 1; m < output_width - 2; m = m + 3) { // for (m = 1; m < output_width - 2; m = m + 3) {
float *output_ptr = output_data + m; // float *output_ptr = output_data + m;
float32x4x2_t input_buff_mid{}, input_buff_bottom{}; // float32x4x2_t input_buff_mid{}, input_buff_bottom{};
float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0; // float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0;
input_buff_mid = vld2q_f32(input_data + (2 * m - 1)); // input_buff_mid = vld2q_f32(input_data + (2 * m - 1));
input_buff_bottom = vld2q_f32(input_data + input_width + (2 * m - 1)); // input_buff_bottom = vld2q_f32(input_data + input_width + (2 * m -
// 1));
in0 = input_buff_mid.val[0]; //
tmp0 = input_buff_mid.val[1]; // in0 = input_buff_mid.val[0];
tmp1 = vextq_f32(in0, zero, 1); // tmp0 = input_buff_mid.val[1];
// tmp1 = vextq_f32(in0, zero, 1);
in2 = input_buff_bottom.val[0]; //
tmp2 = input_buff_bottom.val[1]; // in2 = input_buff_bottom.val[0];
tmp3 = vextq_f32(in2, zero, 1); // tmp2 = input_buff_bottom.val[1];
// tmp3 = vextq_f32(in2, zero, 1);
out0 = vmulq_n_f32(in0, w10); //
out0 = vmlaq_n_f32(out0, tmp0, w11); // out0 = vmulq_n_f32(in0, w10);
out0 = vmlaq_n_f32(out0, tmp1, w12); // out0 = vmlaq_n_f32(out0, tmp0, w11);
out0 = vmlaq_n_f32(out0, in2, w20); // out0 = vmlaq_n_f32(out0, tmp1, w12);
out0 = vmlaq_n_f32(out0, tmp2, w21); // out0 = vmlaq_n_f32(out0, in2, w20);
out0 = vmlaq_n_f32(out0, tmp3, w22); // out0 = vmlaq_n_f32(out0, tmp2, w21);
out0 = vmlaq_f32(vnewbias, vnewscale, out0); // out0 = vmlaq_n_f32(out0, tmp3, w22);
if (if_relu) { // out0 = vmlaq_f32(vnewbias, vnewscale, out0);
out0 = vmaxq_f32(out0, zero); // if (if_relu) {
} // out0 = vmaxq_f32(out0, zero);
vst1q_lane_f32(output_ptr, out0, 0); // }
vst1q_lane_f32(output_ptr + 1, out0, 1); // vst1q_lane_f32(output_ptr, out0, 0);
vst1q_lane_f32(output_ptr + 2, out0, 2); // vst1q_lane_f32(output_ptr + 1, out0, 1);
} // vst1q_lane_f32(output_ptr + 2, out0, 2);
for (m = 1; m < output_width - 2; m += 3) { // }
} // for (m = 1; m < output_width - 2; m += 3) {
for (int j = m; j < output_width; j++) { // }
output_data[j] = input_data[2 * j - 1] * w10 + input_data[2 * j] * w11 + // for (int j = m; j < output_width; j++) {
input_data[2 * j + 1] * w12 + // output_data[j] = input_data[2 * j - 1] * w10 + input_data[2 * j] *
input_data[2 * j - 1 + input_width] * w20 + // w11 +
input_data[2 * j + input_width] * w21 + // input_data[2 * j + 1] * w12 +
input_data[2 * j + 1 + input_width] * w22; // input_data[2 * j - 1 + input_width] * w20 +
output_data[j] = newscale_data[c] * output_data[j] + newbias_data[c]; // input_data[2 * j + input_width] * w21 +
if (if_relu) { // input_data[2 * j + 1 + input_width] * w22;
output_data[j] = output_data[j] < 0 ? 0 : output_data[j]; // output_data[j] = newscale_data[c] * output_data[j] +
} // newbias_data[c]; if (if_relu) {
} // output_data[j] = output_data[j] < 0 ? 0 : output_data[j];
// }
for (int i = 1; i < output_height; i += 1) { // }
for (int m = 1; m < output_width - 2; m += 3) { //
float *output_ptr = output_data + i * output_width + m; // for (int i = 1; i < output_height; i += 1) {
float32x4x2_t input_buff_top{}, input_buff_mid{}, input_buff_bottom{}; // for (int m = 1; m < output_width - 2; m += 3) {
float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3, // float *output_ptr = output_data + i * output_width + m;
tmp4, tmp5, out0; // float32x4x2_t input_buff_top{}, input_buff_mid{},
input_buff_top = // input_buff_bottom{}; float32x4_t in0, in1, in2, in3, in4, in5,
vld2q_f32(input_data + (2 * i - 1) * input_width + (2 * m - 1)); // tmp0, tmp1, tmp2, tmp3,
input_buff_mid = // tmp4, tmp5, out0;
vld2q_f32(input_data + (2 * i) * input_width + (2 * m - 1)); // input_buff_top =
input_buff_bottom = // vld2q_f32(input_data + (2 * i - 1) * input_width + (2 * m -
vld2q_f32(input_data + (2 * i + 1) * input_width + (2 * m - 1)); // 1));
// input_buff_mid =
in0 = input_buff_top.val[0]; // vld2q_f32(input_data + (2 * i) * input_width + (2 * m - 1));
tmp0 = input_buff_top.val[1]; // input_buff_bottom =
tmp1 = vextq_f32(in0, zero, 1); // vld2q_f32(input_data + (2 * i + 1) * input_width + (2 * m -
// 1));
in2 = input_buff_mid.val[0]; //
tmp2 = input_buff_mid.val[1]; // in0 = input_buff_top.val[0];
tmp3 = vextq_f32(in2, zero, 1); // tmp0 = input_buff_top.val[1];
// tmp1 = vextq_f32(in0, zero, 1);
in4 = input_buff_bottom.val[0]; //
tmp4 = input_buff_bottom.val[1]; // in2 = input_buff_mid.val[0];
tmp5 = vextq_f32(in4, zero, 1); // tmp2 = input_buff_mid.val[1];
// tmp3 = vextq_f32(in2, zero, 1);
out0 = vmulq_n_f32(in0, w00); //
out0 = vmlaq_n_f32(out0, tmp0, w01); // in4 = input_buff_bottom.val[0];
out0 = vmlaq_n_f32(out0, tmp1, w02); // tmp4 = input_buff_bottom.val[1];
out0 = vmlaq_n_f32(out0, in2, w10); // tmp5 = vextq_f32(in4, zero, 1);
out0 = vmlaq_n_f32(out0, tmp2, w11); //
out0 = vmlaq_n_f32(out0, tmp3, w12); // out0 = vmulq_n_f32(in0, w00);
out0 = vmlaq_n_f32(out0, in4, w20); // out0 = vmlaq_n_f32(out0, tmp0, w01);
out0 = vmlaq_n_f32(out0, tmp4, w21); // out0 = vmlaq_n_f32(out0, tmp1, w02);
out0 = vmlaq_n_f32(out0, tmp5, w22); // out0 = vmlaq_n_f32(out0, in2, w10);
out0 = vmlaq_f32(vnewbias, vnewscale, out0); // out0 = vmlaq_n_f32(out0, tmp2, w11);
if (if_relu) { // out0 = vmlaq_n_f32(out0, tmp3, w12);
out0 = vmaxq_f32(out0, zero); // out0 = vmlaq_n_f32(out0, in4, w20);
} // out0 = vmlaq_n_f32(out0, tmp4, w21);
vst1q_lane_f32(output_ptr, out0, 0); // out0 = vmlaq_n_f32(out0, tmp5, w22);
vst1q_lane_f32(output_ptr + 1, out0, 1); // out0 = vmlaq_f32(vnewbias, vnewscale, out0);
vst1q_lane_f32(output_ptr + 2, out0, 2); // if (if_relu) {
} // out0 = vmaxq_f32(out0, zero);
int m; // }
for (m = 1; m < output_width - 2; m += 3) { // vst1q_lane_f32(output_ptr, out0, 0);
} // vst1q_lane_f32(output_ptr + 1, out0, 1);
for (int j = m; j < output_width; j++) { // vst1q_lane_f32(output_ptr + 2, out0, 2);
output_data[i * output_width + j] = // }
input_data[(2 * i - 1) * input_width + 2 * j - 1] * w00 + // int m;
input_data[(2 * i - 1) * input_width + 2 * j] * w01 + // for (m = 1; m < output_width - 2; m += 3) {
input_data[(2 * i - 1) * input_width + 2 * j + 1] * w02 + // }
input_data[(2 * i) * input_width + 2 * j - 1] * w10 + // for (int j = m; j < output_width; j++) {
input_data[(2 * i) * input_width + 2 * j] * w11 + // output_data[i * output_width + j] =
input_data[(2 * i) * input_width + 2 * j + 1] * w12 + // input_data[(2 * i - 1) * input_width + 2 * j - 1] * w00 +
input_data[(2 * i + 1) * input_width + 2 * j - 1] * w20 + // input_data[(2 * i - 1) * input_width + 2 * j] * w01 +
input_data[(2 * i + 1) * input_width + 2 * j] * w21 + // input_data[(2 * i - 1) * input_width + 2 * j + 1] * w02 +
input_data[(2 * i + 1) * input_width + 2 * j + 1] * w22; // input_data[(2 * i) * input_width + 2 * j - 1] * w10 +
output_data[i * output_width + j] = // input_data[(2 * i) * input_width + 2 * j] * w11 +
newscale_data[c] * output_data[i * output_width + j] + // input_data[(2 * i) * input_width + 2 * j + 1] * w12 +
newbias_data[c]; // input_data[(2 * i + 1) * input_width + 2 * j - 1] * w20 +
if (if_relu) { // input_data[(2 * i + 1) * input_width + 2 * j] * w21 +
output_data[i * output_width + j] = // input_data[(2 * i + 1) * input_width + 2 * j + 1] * w22;
output_data[i * output_width + j] < 0 // output_data[i * output_width + j] =
? 0 // newscale_data[c] * output_data[i * output_width + j] +
: output_data[i * output_width + j]; // newbias_data[c];
} // if (if_relu) {
} // output_data[i * output_width + j] =
} // output_data[i * output_width + j] < 0
output_data[0] = input_data[0] * w11 + input_data[1] * w12 + // ? 0
input_data[input_height] * w21 + // : output_data[i * output_width + j];
input_data[input_height + 1] * w22; // }
// }
output_data[0] = newscale_data[c] * output_data[0] + newbias_data[c]; // }
if (if_relu) { // output_data[0] = input_data[0] * w11 + input_data[1] * w12 +
output_data[0] = output_data[0] < 0 ? 0 : output_data[0]; // input_data[input_height] * w21 +
} // input_data[input_height + 1] * w22;
for (int i = 1; i < output_height; i++) { //
output_data[i * output_width] = // output_data[0] = newscale_data[c] * output_data[0] + newbias_data[c];
input_data[(2 * i - 1) * input_width] * w01 + // if (if_relu) {
input_data[(2 * i - 1) * input_width + 1] * w02 + // output_data[0] = output_data[0] < 0 ? 0 : output_data[0];
input_data[(2 * i) * input_width] * w11 + // }
input_data[(2 * i) * input_width + 1] * w12 + // for (int i = 1; i < output_height; i++) {
input_data[(2 * i + 1) * input_width] * w21 + // output_data[i * output_width] =
input_data[(2 * i + 1) * input_width + 1] * w22; // input_data[(2 * i - 1) * input_width] * w01 +
// input_data[(2 * i - 1) * input_width + 1] * w02 +
output_data[i * output_width] = // input_data[(2 * i) * input_width] * w11 +
newscale_data[c] * output_data[i * output_width] + newbias_data[c]; // input_data[(2 * i) * input_width + 1] * w12 +
if (if_relu) { // input_data[(2 * i + 1) * input_width] * w21 +
output_data[i * output_width] = output_data[i * output_width] < 0 // input_data[(2 * i + 1) * input_width + 1] * w22;
? 0 //
: output_data[i * output_width]; // output_data[i * output_width] =
} // newscale_data[c] * output_data[i * output_width] +
} // newbias_data[c];
} // if (if_relu) {
} // output_data[i * output_width] = output_data[i * output_width] < 0
// ? 0
#else // : output_data[i *
// output_width];
// }
// }
// }
// }
//
//#else
const float *input_data = input->data<float>(); const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>(); const float *filter_data = filter->data<float>();
...@@ -1646,9 +1653,6 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, ...@@ -1646,9 +1653,6 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
const float *newscale_data = new_scale->data<float>(); const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>(); const float *newbias_data = new_bias->data<float>();
float32x4_t vnewbias = vdupq_n_f32(0.0);
float32x4_t vnewscale = vdupq_n_f32(1.0);
const int in_h = static_cast<int>(input->dims()[2]); const int in_h = static_cast<int>(input->dims()[2]);
const int in_w = static_cast<int>(input->dims()[3]); const int in_w = static_cast<int>(input->dims()[3]);
const int out_h = static_cast<int>(output->dims()[2]); const int out_h = static_cast<int>(output->dims()[2]);
...@@ -1660,22 +1664,22 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, ...@@ -1660,22 +1664,22 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
const int if_pad = in_l - 1 == (out_l - 1) * 2 ? 1 : 0; const int if_pad = in_l - 1 == (out_l - 1) * 2 ? 1 : 0;
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
const int c = static_cast<int>(input->dims()[1]); const int c = static_cast<int>(input->dims()[1]);
const int w_times = (out_w - 2) / 3;
float32x4_t zero = vdupq_n_f32(0.0);
for (int b = batch_size; b > 0; --b) {
#pragma omp parallel for
for (int j = 0; j < c; j++) {
const float *input_row_ptr; const float *input_row_ptr;
float *output_row_ptr; float *output_row_ptr;
const int w_times = (out_w - 2) / 3;
float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1]; float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1];
float32x4_t elewise_res0, elewise_res1, elewise_res2, res3; float32x4_t elewise_res0, elewise_res1, elewise_res2, res3;
int out2in_mid; int out2in_mid;
float32x4_t zero = vdupq_n_f32(0.0); float32x4_t vnewbias = vdupq_n_f32(0.0);
for (int b = batch_size; b > 0; --b) { float32x4_t vnewscale = vdupq_n_f32(1.0);
const float *filter_data_tmp = filter_data;
for (int j = 0; j < c; ++j) {
auto output_data_tmp = output_data + j * out_h * out_w; auto output_data_tmp = output_data + j * out_h * out_w;
auto input_data_tmp = input_data + j * in_h * in_w; auto input_data_tmp = input_data + j * in_h * in_w;
auto input_const = input_data_tmp; auto input_const = input_data_tmp;
const float *filter_data_tmp = filter_data + 9 * j;
vnewbias = vdupq_n_f32(newbias_data[j]); vnewbias = vdupq_n_f32(newbias_data[j]);
vnewscale = vdupq_n_f32(newscale_data[j]); vnewscale = vdupq_n_f32(newscale_data[j]);
...@@ -1726,7 +1730,9 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, ...@@ -1726,7 +1730,9 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
if (if_relu) { if (if_relu) {
res3 = vmaxq_f32(res3, zero); res3 = vmaxq_f32(res3, zero);
} }
vst1q_f32(output_row_ptr, res3); vst1q_lane_f32(output_row_ptr, res3, 0);
vst1q_lane_f32(output_row_ptr + 1, res3, 1);
vst1q_lane_f32(output_row_ptr + 2, res3, 2);
input_row_ptr += 6; input_row_ptr += 6;
output_row_ptr += 3; output_row_ptr += 3;
...@@ -1765,7 +1771,9 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, ...@@ -1765,7 +1771,9 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
res3 = vmaxq_f32(res3, zero); res3 = vmaxq_f32(res3, zero);
} }
if ((w4 != w_times)) { if ((w4 != w_times)) {
vst1q_f32(output_row_ptr, res3); vst1q_lane_f32(output_row_ptr, res3, 0);
vst1q_lane_f32(output_row_ptr + 1, res3, 1);
vst1q_lane_f32(output_row_ptr + 2, res3, 2);
} else { } else {
if (out_l - 2 - w_times * 3 == 1) { if (out_l - 2 - w_times * 3 == 1) {
vst1q_lane_f32(output_row_ptr, res3, 0); vst1q_lane_f32(output_row_ptr, res3, 0);
...@@ -1865,12 +1873,11 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, ...@@ -1865,12 +1873,11 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
: output_data_tmp[i * out_l + out_l - 1]; : output_data_tmp[i * out_l + out_l - 1];
} }
} }
filter_data_tmp += 9;
} }
input_data += inhxw * c; input_data += inhxw * c;
output_data += outhxw * c; output_data += outhxw * c;
} }
#endif //#endif
#endif #endif
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册