From 3cb0cb18665bba7c783e7a3057e015e582e22aea Mon Sep 17 00:00:00 2001 From: zhaojiaying01 Date: Thu, 27 Dec 2018 21:44:57 +0800 Subject: [PATCH] add openmp in depthwise_conv3x3_s1p1 --- src/operators/math/depthwise_conv3x3.cpp | 34 +++++++++--------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/src/operators/math/depthwise_conv3x3.cpp b/src/operators/math/depthwise_conv3x3.cpp index 3187c6c141..90edc3111b 100644 --- a/src/operators/math/depthwise_conv3x3.cpp +++ b/src/operators/math/depthwise_conv3x3.cpp @@ -253,34 +253,29 @@ void DepthwiseConv3x3s1p1(const framework::Tensor *input, framework::Tensor *output, framework::Tensor *bias, bool if_bias, bool if_relu) { #if __ARM_NEON - const float *input_data = input->data(); - const float *filter_data = filter->data(); - float *output_data = output->mutable_data(); - const float *bias_data; - if (if_bias) { - bias_data = bias->data(); - } - - const int h = static_cast(input->dims()[2]); - const int w = static_cast(input->dims()[3]); - // const int l = h; + const float *bias_data = bias->data(); const int batch_size = static_cast(input->dims()[0]); const int c = static_cast(input->dims()[1]); + const int h = static_cast(input->dims()[2]); + const int w = static_cast(input->dims()[3]); const int hxw = h * w; - float32x4_t vbias = vdupq_n_f32(0.0); + // const int l = h; // leftTop, rightTop, leftBottom, rightBottom - int lt = 0; - int rt = w - 1; - int lb = (h - 1) * w; - int rb = h * w - 1; + const int lt = 0; + const int rt = w - 1; + const int lb = (h - 1) * w; + const int rb = h * w - 1; float32x4_t zero = vdupq_n_f32(0.0); for (int b = 0; b < batch_size; ++b) { - const float *filter_data_tmp = filter_data; - +#pragma omp parallel for for (int j = 0; j < c; ++j) { + const float *filter_data_tmp = filter->data() + j * 9; + const float *input_data = input->data() + j * hxw; + float *output_data = output->mutable_data() + j * hxw; + float32x4_t vbias; if (if_bias) { vbias = vdupq_n_f32(bias_data[j]); } @@ -552,9 +547,6 @@ void DepthwiseConv3x3s1p1(const framework::Tensor *input, } } } - output_data += hxw; - input_data += hxw; - filter_data_tmp += 9; } } #endif -- GitLab