未验证 提交 3620852a 编写于 作者: S smilejames 提交者: GitHub

Merge pull request #764 from yangfei963158659/develop

Optimize multithreading of the 3x3 s2 depthwise_conv: parallelize over channels instead of output rows
...@@ -1466,9 +1466,6 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, ...@@ -1466,9 +1466,6 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
const Tensor *new_bias, bool if_relu) { const Tensor *new_bias, bool if_relu) {
#if __ARM_NEON #if __ARM_NEON
#ifdef _OPENMP #ifdef _OPENMP
const float *input_data = input->data<float>();
const float *filter_data = filter->data<float>();
float *output_data = output->data<float>();
const float *newscale_data = new_scale->data<float>(); const float *newscale_data = new_scale->data<float>();
const float *newbias_data = new_bias->data<float>(); const float *newbias_data = new_bias->data<float>();
...@@ -1482,14 +1479,15 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, ...@@ -1482,14 +1479,15 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
const int inhxw = input_height * input_width; const int inhxw = input_height * input_width;
const int outhxw = output_height * output_width; const int outhxw = output_height * output_width;
float32x4_t vnewbias = vdupq_n_f32(0.0);
float32x4_t vnewscale = vdupq_n_f32(1.0);
float32x4_t zero = vdupq_n_f32(0.0); float32x4_t zero = vdupq_n_f32(0.0);
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
filter_data = filter->data<float>(); #pragma omp parallel for
for (int c = 0; c < input_channel; c++) { for (int c = 0; c < input_channel; c++) {
vnewbias = vdupq_n_f32(newbias_data[c]); const float *filter_data = filter->data<float>() + c * 9;
vnewscale = vdupq_n_f32(newscale_data[c]); const float *input_data = input->data<float>() + c * inhxw;
float *output_data = output->data<float>() + c * outhxw;
float32x4_t vnewbias = vdupq_n_f32(newbias_data[c]);
float32x4_t vnewscale = vdupq_n_f32(newscale_data[c]);
float w00 = filter_data[0]; float w00 = filter_data[0];
float w01 = filter_data[1]; float w01 = filter_data[1];
...@@ -1527,7 +1525,9 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, ...@@ -1527,7 +1525,9 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
if (if_relu) { if (if_relu) {
out0 = vmaxq_f32(out0, zero); out0 = vmaxq_f32(out0, zero);
} }
vst1q_f32(output_ptr, out0); vst1q_lane_f32(output_ptr, out0, 0);
vst1q_lane_f32(output_ptr + 1, out0, 1);
vst1q_lane_f32(output_ptr + 2, out0, 2);
} }
for (m = 1; m < output_width - 2; m += 3) { for (m = 1; m < output_width - 2; m += 3) {
} }
...@@ -1543,8 +1543,6 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, ...@@ -1543,8 +1543,6 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
} }
} }
#pragma omp parallel for
for (int i = 1; i < output_height; i += 1) { for (int i = 1; i < output_height; i += 1) {
for (int m = 1; m < output_width - 2; m += 3) { for (int m = 1; m < output_width - 2; m += 3) {
float *output_ptr = output_data + i * output_width + m; float *output_ptr = output_data + i * output_width + m;
...@@ -1583,7 +1581,9 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, ...@@ -1583,7 +1581,9 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
if (if_relu) { if (if_relu) {
out0 = vmaxq_f32(out0, zero); out0 = vmaxq_f32(out0, zero);
} }
vst1q_f32(output_ptr, out0); vst1q_lane_f32(output_ptr, out0, 0);
vst1q_lane_f32(output_ptr + 1, out0, 1);
vst1q_lane_f32(output_ptr + 2, out0, 2);
} }
int m; int m;
for (m = 1; m < output_width - 2; m += 3) { for (m = 1; m < output_width - 2; m += 3) {
...@@ -1635,10 +1635,6 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, ...@@ -1635,10 +1635,6 @@ void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter,
: output_data[i * output_width]; : output_data[i * output_width];
} }
} }
input_data = input_data + inhxw;
output_data = output_data + outhxw;
filter_data = filter_data + 9;
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册