diff --git a/src/operators/math/depthwise_conv_3x3.cpp b/src/operators/math/depthwise_conv_3x3.cpp index 1ca3797882807ae5f12b16483d90e359da6dfb99..c93278a661f72152debcef7066bdd751bccc5b4e 100644 --- a/src/operators/math/depthwise_conv_3x3.cpp +++ b/src/operators/math/depthwise_conv_3x3.cpp @@ -699,7 +699,7 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, : output_data[(output_height - 1) * output_width + j]; } } - #pragma omp parallel for +#pragma omp parallel for for (int i = 1; i < output_height - 1; i++) { for (int m = 1; (m + 3) < output_width - 1; m = m + 4) { float *output_ptr = output_data + i * output_width + m; @@ -1466,6 +1466,7 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, Tensor *output, const Tensor *new_scale, const Tensor *new_bias, bool if_relu) { #if __ARM_NEON +#ifdef _OPENMP const float *input_data = input->data(); const float *filter_data = filter->data(); float *output_data = output->data(); @@ -1642,251 +1643,239 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, } } - // const float *input_data = input->data(); - // const float *filter_data = filter->data(); - // float *output_data = output->data(); - // const float *newscale_data = new_scale->data(); - // const float *newbias_data = new_bias->data(); - // - // float32x4_t vnewbias = vdupq_n_f32(0.0); - // float32x4_t vnewscale = vdupq_n_f32(1.0); - // - // const int in_h = static_cast(input->dims()[2]); - // const int in_w = static_cast(input->dims()[3]); - // const int out_h = static_cast(output->dims()[2]); - // const int out_w = static_cast(output->dims()[3]); - // const int out_l = out_h; - // const int in_l = in_h; - // const int inhxw = in_h * in_w; - // const int outhxw = out_h * out_w; - // const int if_pad = in_l - 1 == (out_l - 1) * 2 ? 1 : 0; - // const int batch_size = static_cast(input->dims()[0]); - // const int c = static_cast(input->dims()[1]); - // const float *input_row_ptr; - // float *output_row_ptr; - // - // const int w_times = (out_w - 2) / 3; - // - // float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1]; - // float32x4_t elewise_res0, elewise_res1, elewise_res2, res3; - // int out2in_mid; - // float32x4_t zero = vdupq_n_f32(0.0); - // for (int b = batch_size; b > 0; --b) { - // const float *filter_data_tmp = filter_data; - // for (int j = 0; j < c; ++j) { - // auto output_data_tmp = output_data + j * out_h * out_w; - // auto input_data_tmp = input_data + j * in_h * in_w; - // auto input_const = input_data_tmp; - // - // vnewbias = vdupq_n_f32(newbias_data[j]); - // vnewscale = vdupq_n_f32(newscale_data[j]); - // - // float w00 = filter_data_tmp[0]; - // float w01 = filter_data_tmp[1]; - // float w02 = filter_data_tmp[2]; - // float w10 = filter_data_tmp[3]; - // float w11 = filter_data_tmp[4]; - // float w12 = filter_data_tmp[5]; - // float w20 = filter_data_tmp[6]; - // float w21 = filter_data_tmp[7]; - // float w22 = filter_data_tmp[8]; - // - // int h_mid = 0; - // - // for (; h_mid < out_h - 1; h_mid++) { - // input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; - // output_row_ptr = output_data_tmp + 1 + h_mid * out_w; - // - // for (int w4 = 0; w4 < w_times + 1; w4++) { - // if (h_mid == 0) { - // elewise_res1 = zero; - // elewise_res0 = zero; - // elewise_res2 = zero; - // } else { - // elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); - // elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); - // elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); - // } - // input_buff_mid = vld2q_f32(input_row_ptr); - // input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); - // - // elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], - // w11); elewise_res0 = vmlaq_n_f32(elewise_res0, - // input_buff_mid.val[0], w10); elewise_res2 = - // vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); - // - // elewise_res1 = - // vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], - // w21); - // elewise_res0 = - // vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], - // w20); - // elewise_res2 = - // vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], - // w22); - // - // res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), - // vaddq_f32(elewise_res0, elewise_res1)); - // res3 = vmlaq_f32(vnewbias, vnewscale, res3); - // - // if (if_relu) { - // res3 = vmaxq_f32(res3, zero); - // } - // vst1q_f32(output_row_ptr, res3); - // - // input_row_ptr += 6; - // output_row_ptr += 3; - // } - // } - // clock(); - // - // input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; - // output_row_ptr = output_data_tmp + 1 + h_mid * out_w; - // - // for (int w4 = 0; w4 < w_times + 1; w4++) { - // elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); - // elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); - // elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); - // - // input_buff_mid = vld2q_f32(input_row_ptr); - // input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); - // - // elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], - // w11); elewise_res0 = vmlaq_n_f32(elewise_res0, - // input_buff_mid.val[0], w10); elewise_res2 = - // vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); - // - // if (!if_pad) { - // elewise_res1 = - // vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], - // w21); - // elewise_res0 = - // vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], - // w20); - // elewise_res2 = - // vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], - // w22); - // } - // res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), - // vaddq_f32(elewise_res0, elewise_res1)); - // res3 = vmlaq_f32(vnewbias, vnewscale, res3); - // - // if (if_relu) { - // res3 = vmaxq_f32(res3, zero); - // } - // if ((w4 != w_times)) { - // vst1q_f32(output_row_ptr, res3); - // } else { - // if (out_l - 2 - w_times * 3 == 1) { - // vst1q_lane_f32(output_row_ptr, res3, 0); - // } else if (out_l - 2 - w_times * 3 == 2) { - // vst1q_lane_f32(output_row_ptr, res3, 0); - // vst1q_lane_f32(output_row_ptr + 1, res3, 1); - // } - // } - // input_row_ptr += 6; - // output_row_ptr += 3; - // } - // - // output_data_tmp[0] = input_const[0] * w11 + input_const[1] * w12 + - // input_const[in_l] * w21 + - // input_const[in_l + 1] * w22; - // - // out2in_mid = (out_l - 1) * 2; - // output_data_tmp[out_l - 1] = - // w10 * input_const[out2in_mid - 1] + w11 * - // input_const[out2in_mid] + w20 * input_const[out2in_mid + in_w - - // 1] + w21 * input_const[out2in_mid + in_w] + (1 - if_pad) * (w12 - // * input_const[out2in_mid + 1] + - // w22 * input_const[out2in_mid + in_w + 1]); - // - // out2in_mid = (out_l - 1) * 2 * in_w; - // - // output_data_tmp[out_l * (out_l - 1)] = - // w01 * input_const[out2in_mid - in_w] + - // w02 * input_const[out2in_mid - in_w + 1] + - // w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + - // 1] + (1 - if_pad) * (w21 * input_const[out2in_mid + in_w] + - // w22 * input_const[out2in_mid + in_w + 1]); - // out2in_mid = (out_l - 1) * 2 * in_w + (out_l - 1) * 2; - // - // output_data_tmp[out_l * out_l - 1] = - // w00 * input_const[out2in_mid - in_w - 1] + - // w01 * input_const[out2in_mid - in_w] + - // w10 * input_const[out2in_mid - 1] + w11 * - // input_const[out2in_mid] + (1 - if_pad) * (w20 * - // input_const[out2in_mid + in_w - 1] + - // w21 * input_const[out2in_mid + in_w] + - // w02 * input_const[out2in_mid - in_w + 1] + - // w12 * input_const[out2in_mid + 1] + - // w22 * input_const[out2in_mid + in_w + 1]); - // output_data_tmp[0] = - // output_data_tmp[0] * newscale_data[j] + newbias_data[j]; - // output_data_tmp[out_l - 1] = - // output_data_tmp[out_l - 1] * newscale_data[j] + newbias_data[j]; - // output_data_tmp[out_l * (out_l - 1)] = - // output_data_tmp[out_l * (out_l - 1)] * newscale_data[j] + - // newbias_data[j]; - // output_data_tmp[out_l * out_l - 1] = - // output_data_tmp[out_l * out_l - 1] * newscale_data[j] + - // newbias_data[j]; - // if (if_relu) { - // output_data_tmp[0] = output_data_tmp[0] < 0 ? 0 : - // output_data_tmp[0]; output_data_tmp[out_l - 1] = - // output_data_tmp[out_l - 1] < 0 ? 0 : output_data_tmp[out_l - - // 1]; - // output_data_tmp[out_l * (out_l - 1)] = - // output_data_tmp[out_l * (out_l - 1)] < 0 - // ? 0 - // : output_data_tmp[out_l * (out_l - 1)]; - // output_data_tmp[out_l * out_l - 1] = - // output_data_tmp[out_l * out_l - 1] < 0 - // ? 0 - // : output_data_tmp[out_l * out_l - 1]; - // } - // for (int i = 1; i < out_h - 1; i++) { - // out2in_mid = i * 2 * in_w; - // output_data_tmp[i * out_l] = w01 * input_const[out2in_mid - in_w] - // + - // w02 * input_const[out2in_mid - in_w + - // 1] + w11 * input_const[out2in_mid] + - // w12 * input_const[out2in_mid + 1] + - // w21 * input_const[out2in_mid + in_w] - // + w22 * input_const[out2in_mid + in_w - // + 1]; - // - // out2in_mid = i * 2 * in_w + (out_l - 1) * 2; - // output_data_tmp[i * out_l + out_l - 1] = - // w00 * input_const[out2in_mid - in_w - 1] + - // w01 * input_const[out2in_mid - in_w] + - // w10 * input_const[out2in_mid - 1] + w11 * - // input_const[out2in_mid] + w20 * input_const[out2in_mid + in_w - // - 1] + w21 * input_const[out2in_mid + in_w] + (1 - if_pad) * - // (w02 * input_const[out2in_mid - in_w + 1] + - // w12 * input_const[out2in_mid + 1] + - // w22 * input_const[out2in_mid + in_w + 1]); - // output_data_tmp[i * out_l] = - // output_data_tmp[i * out_l] * newscale_data[j] + - // newbias_data[j]; - // output_data_tmp[i * out_l + out_l - 1] = - // output_data_tmp[i * out_l + out_l - 1] * newscale_data[j] + - // newbias_data[j]; - // if (if_relu) { - // output_data_tmp[i * out_l] = - // output_data_tmp[i * out_l] < 0 ? 0 : output_data_tmp[i * - // out_l]; - // output_data_tmp[i * out_l + out_l - 1] = - // output_data_tmp[i * out_l + out_l - 1] < 0 - // ? 0 - // : output_data_tmp[i * out_l + out_l - 1]; - // } - // } - // filter_data_tmp += 9; - // } - // input_data += inhxw * c; - // output_data += outhxw * c; - // } +#else + + const float *input_data = input->data(); + const float *filter_data = filter->data(); + float *output_data = output->data(); + const float *newscale_data = new_scale->data(); + const float *newbias_data = new_bias->data(); + + float32x4_t vnewbias = vdupq_n_f32(0.0); + float32x4_t vnewscale = vdupq_n_f32(1.0); + + const int in_h = static_cast(input->dims()[2]); + const int in_w = static_cast(input->dims()[3]); + const int out_h = static_cast(output->dims()[2]); + const int out_w = static_cast(output->dims()[3]); + const int out_l = out_h; + const int in_l = in_h; + const int inhxw = in_h * in_w; + const int outhxw = out_h * out_w; + const int if_pad = in_l - 1 == (out_l - 1) * 2 ? 1 : 0; + const int batch_size = static_cast(input->dims()[0]); + const int c = static_cast(input->dims()[1]); + const float *input_row_ptr; + float *output_row_ptr; + + const int w_times = (out_w - 2) / 3; + + float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1]; + float32x4_t elewise_res0, elewise_res1, elewise_res2, res3; + int out2in_mid; + float32x4_t zero = vdupq_n_f32(0.0); + for (int b = batch_size; b > 0; --b) { + const float *filter_data_tmp = filter_data; + for (int j = 0; j < c; ++j) { + auto output_data_tmp = output_data + j * out_h * out_w; + auto input_data_tmp = input_data + j * in_h * in_w; + auto input_const = input_data_tmp; + + vnewbias = vdupq_n_f32(newbias_data[j]); + vnewscale = vdupq_n_f32(newscale_data[j]); + + float w00 = filter_data_tmp[0]; + float w01 = filter_data_tmp[1]; + float w02 = filter_data_tmp[2]; + float w10 = filter_data_tmp[3]; + float w11 = filter_data_tmp[4]; + float w12 = filter_data_tmp[5]; + float w20 = filter_data_tmp[6]; + float w21 = filter_data_tmp[7]; + float w22 = filter_data_tmp[8]; + + int h_mid = 0; + + for (; h_mid < out_h - 1; h_mid++) { + input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; + output_row_ptr = output_data_tmp + 1 + h_mid * out_w; + + for (int w4 = 0; w4 < w_times + 1; w4++) { + if (h_mid == 0) { + elewise_res1 = zero; + elewise_res0 = zero; + elewise_res2 = zero; + } else { + elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); + elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); + elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); + } + input_buff_mid = vld2q_f32(input_row_ptr); + input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); + + elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); + elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); + elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); + + elewise_res1 = + vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); + elewise_res0 = + vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); + elewise_res2 = + vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); + + res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), + vaddq_f32(elewise_res0, elewise_res1)); + res3 = vmlaq_f32(vnewbias, vnewscale, res3); + + if (if_relu) { + res3 = vmaxq_f32(res3, zero); + } + vst1q_f32(output_row_ptr, res3); + + input_row_ptr += 6; + output_row_ptr += 3; + } + } + clock(); + + input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; + output_row_ptr = output_data_tmp + 1 + h_mid * out_w; + + for (int w4 = 0; w4 < w_times + 1; w4++) { + elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); + elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); + elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); + + input_buff_mid = vld2q_f32(input_row_ptr); + input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); + + elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); + elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); + elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); + + if (!if_pad) { + elewise_res1 = + vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); + elewise_res0 = + vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); + elewise_res2 = + vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); + } + res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), + vaddq_f32(elewise_res0, elewise_res1)); + res3 = vmlaq_f32(vnewbias, vnewscale, res3); + + if (if_relu) { + res3 = vmaxq_f32(res3, zero); + } + if ((w4 != w_times)) { + vst1q_f32(output_row_ptr, res3); + } else { + if (out_l - 2 - w_times * 3 == 1) { + vst1q_lane_f32(output_row_ptr, res3, 0); + } else if (out_l - 2 - w_times * 3 == 2) { + vst1q_lane_f32(output_row_ptr, res3, 0); + vst1q_lane_f32(output_row_ptr + 1, res3, 1); + } + } + input_row_ptr += 6; + output_row_ptr += 3; + } + + output_data_tmp[0] = input_const[0] * w11 + input_const[1] * w12 + + input_const[in_l] * w21 + + input_const[in_l + 1] * w22; + + out2in_mid = (out_l - 1) * 2; + output_data_tmp[out_l - 1] = + w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + + w20 * input_const[out2in_mid + in_w - 1] + + w21 * input_const[out2in_mid + in_w] + + (1 - if_pad) * (w12 * input_const[out2in_mid + 1] + + w22 * input_const[out2in_mid + in_w + 1]); + + out2in_mid = (out_l - 1) * 2 * in_w; + + output_data_tmp[out_l * (out_l - 1)] = + w01 * input_const[out2in_mid - in_w] + + w02 * input_const[out2in_mid - in_w + 1] + + w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] + + (1 - if_pad) * (w21 * input_const[out2in_mid + in_w] + + w22 * input_const[out2in_mid + in_w + 1]); + out2in_mid = (out_l - 1) * 2 * in_w + (out_l - 1) * 2; + + output_data_tmp[out_l * out_l - 1] = + w00 * input_const[out2in_mid - in_w - 1] + + w01 * input_const[out2in_mid - in_w] + + w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + + (1 - if_pad) * (w20 * input_const[out2in_mid + in_w - 1] + + w21 * input_const[out2in_mid + in_w] + + w02 * input_const[out2in_mid - in_w + 1] + + w12 * input_const[out2in_mid + 1] + + w22 * input_const[out2in_mid + in_w + 1]); + output_data_tmp[0] = + output_data_tmp[0] * newscale_data[j] + newbias_data[j]; + output_data_tmp[out_l - 1] = + output_data_tmp[out_l - 1] * newscale_data[j] + newbias_data[j]; + output_data_tmp[out_l * (out_l - 1)] = + output_data_tmp[out_l * (out_l - 1)] * newscale_data[j] + + newbias_data[j]; + output_data_tmp[out_l * out_l - 1] = + output_data_tmp[out_l * out_l - 1] * newscale_data[j] + + newbias_data[j]; + if (if_relu) { + output_data_tmp[0] = output_data_tmp[0] < 0 ? 0 : output_data_tmp[0]; + output_data_tmp[out_l - 1] = + output_data_tmp[out_l - 1] < 0 ? 0 : output_data_tmp[out_l - 1]; + output_data_tmp[out_l * (out_l - 1)] = + output_data_tmp[out_l * (out_l - 1)] < 0 + ? 0 + : output_data_tmp[out_l * (out_l - 1)]; + output_data_tmp[out_l * out_l - 1] = + output_data_tmp[out_l * out_l - 1] < 0 + ? 0 + : output_data_tmp[out_l * out_l - 1]; + } + for (int i = 1; i < out_h - 1; i++) { + out2in_mid = i * 2 * in_w; + output_data_tmp[i * out_l] = w01 * input_const[out2in_mid - in_w] + + w02 * input_const[out2in_mid - in_w + 1] + + w11 * input_const[out2in_mid] + + w12 * input_const[out2in_mid + 1] + + w21 * input_const[out2in_mid + in_w] + + w22 * input_const[out2in_mid + in_w + 1]; + out2in_mid = i * 2 * in_w + (out_l - 1) * 2; + output_data_tmp[i * out_l + out_l - 1] = + w00 * input_const[out2in_mid - in_w - 1] + + w01 * input_const[out2in_mid - in_w] + + w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + + w20 * input_const[out2in_mid + in_w - 1] + + w21 * input_const[out2in_mid + in_w] + + (1 - if_pad) * (w02 * input_const[out2in_mid - in_w + 1] + + w12 * input_const[out2in_mid + 1] + + w22 * input_const[out2in_mid + in_w + 1]); + output_data_tmp[i * out_l] = + output_data_tmp[i * out_l] * newscale_data[j] + newbias_data[j]; + output_data_tmp[i * out_l + out_l - 1] = + output_data_tmp[i * out_l + out_l - 1] * newscale_data[j] + + newbias_data[j]; + if (if_relu) { + output_data_tmp[i * out_l] = + output_data_tmp[i * out_l] < 0 ? 0 : output_data_tmp[i * out_l]; + output_data_tmp[i * out_l + out_l - 1] = + output_data_tmp[i * out_l + out_l - 1] < 0 + ? 0 + : output_data_tmp[i * out_l + out_l - 1]; + } + } + filter_data_tmp += 9; + } + input_data += inhxw * c; + output_data += outhxw * c; + } +#endif #endif }