From 7dca8ab1f6ab74208255c970689de3ea1746b377 Mon Sep 17 00:00:00 2001 From: eclipsess Date: Mon, 9 Jul 2018 16:54:43 +0800 Subject: [PATCH] dw3x3s2v2 and dw3x3s2bnreluv2 --- .../central-arm-func/conv_add_arm_func.h | 26 +- .../central-arm-func/conv_add_bn_relu_func.h | 9 +- .../depthwise_conv_arm_func.h | 8 +- src/operators/math/depthwise_conv_3x3.cpp | 439 ++++++++++++++++++ src/operators/math/depthwise_conv_3x3.h | 5 + 5 files changed, 479 insertions(+), 8 deletions(-) diff --git a/src/operators/kernel/central-arm-func/conv_add_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_arm_func.h index ed6dc46a90..41ea0d09e7 100644 --- a/src/operators/kernel/central-arm-func/conv_add_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_arm_func.h @@ -124,9 +124,29 @@ void ConvAddCompute(const FusionConvAddParam ¶m) { } else if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3) { - math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), - param.Filter(), param.Bias(), param.Output(), true); + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { + // Tensor in,out,filter; + // auto inptr = in.mutable_data({1,2,10,10}); + // auto filterptr = filter.mutable_data({2,1,3,3}); + // auto outputptr= out.mutable_data({1,2,5,5}); + // for(int i = 0; i < in.numel(); ++i) + // { + // inptr[i] = i; + // } + // for (int i = 0; i < filter.numel(); ++i) + // { + // filterptr[i] = i; + // } + // math::DepthwiseConv3x3(param.Input(), param.Strides(), + // param.Paddings(), + // param.Filter(), param.Bias(), param.Output(), + // false); + // math::DepthwiseConv3x3(&in, param.Strides(), param.Paddings(), + // &filter, param.Bias(), &out, false); + + math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), param.Output(), + *param.Bias(), true); + } else { ConvAddBasic(param); } diff --git a/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h b/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h index f184a59a3a..e8aed3fd7d 100644 --- a/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_bn_relu_func.h @@ -138,9 +138,12 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam ¶m) { param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { - math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(), - param.Output(), param.NewScale(), - param.NewBias(), 1); + // math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(), + // param.Output(), param.NewScale(), + // param.NewBias(), 1); + math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(), + param.Output(), param.NewScale(), + param.NewBias(), true); } else { ConvAddBNReluBasic(param); } diff --git a/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h b/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h index 885f2051f6..60b09df597 100644 --- a/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h @@ -37,8 +37,12 @@ void DepthwiseConvCompute(const ConvParam ¶m) { param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { - math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), - param.Filter(), &Bias, param.Output(), false); + // math::DepthwiseConv3x3(param.Input(), param.Strides(), + // param.Paddings(), + // param.Filter(), &Bias, param.Output(), false); + math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), param.Output(), + Bias, false); + } else { ConvBasic(param); } diff --git a/src/operators/math/depthwise_conv_3x3.cpp b/src/operators/math/depthwise_conv_3x3.cpp index d1faa1fd64..33ad76b96a 100644 --- a/src/operators/math/depthwise_conv_3x3.cpp +++ b/src/operators/math/depthwise_conv_3x3.cpp @@ -14,6 +14,7 @@ limitations under the License. */ #include "operators/math/depthwise_conv_3x3.h" #include #include +#define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) namespace paddle_mobile { namespace operators { @@ -1010,6 +1011,444 @@ void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter, output_data += output_batch_stride; } } + +void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, + Tensor *output, Tensor bias, bool if_bias) { + const float *input_data = input->data(); + const float *filter_data = filter->data(); + float *output_data = output->data(); + const float *bias_data = bias.data(); + + const int in_h = static_cast(input->dims()[2]); + const int in_w = static_cast(input->dims()[3]); + const int out_h = static_cast(output->dims()[2]); + const int out_w = static_cast(output->dims()[3]); + const int out_l = out_h; + const int in_l = in_h; + const int inhxw = in_h * in_w; + const int outhxw = out_h * out_w; + const int if_pad = in_l - 1 == (out_l - 1) * 2 ? 1 : 0; + const int batch_size = static_cast(input->dims()[0]); + const int c = static_cast(input->dims()[1]); + const float *input_row_ptr; + float *output_row_ptr; + + const int w_times = (out_w - 2) / 3; + + float32x4_t vbias = vdupq_n_f32(0.0); + + float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1], + input_buff_top[w_times + 1]; + float32x4_t elewise_res0, elewise_res1, elewise_res2, res3; + int out2in_mid; + float32x4_t zero = vdupq_n_f32(0.0); + for (int b = batch_size; b > 0; --b) { + const float *filter_data_tmp = filter_data; + for (int j = 0; j < c; ++j) { + auto output_data_tmp = output_data + j * out_h * out_w; + auto input_data_tmp = input_data + j * in_h * in_w; + auto input_const = input_data_tmp; + + if (if_bias) { + vbias = vdupq_n_f32(bias_data[j]); + } + + float w00 = filter_data_tmp[0]; + float w01 = filter_data_tmp[1]; + float w02 = filter_data_tmp[2]; + float w10 = filter_data_tmp[3]; + float w11 = filter_data_tmp[4]; + float w12 = filter_data_tmp[5]; + float w20 = filter_data_tmp[6]; + float w21 = filter_data_tmp[7]; + float w22 = filter_data_tmp[8]; + + int h_mid = 0; + + for (; h_mid < out_h - 1; h_mid++) { + input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; + output_row_ptr = output_data_tmp + 1 + h_mid * out_w; + + for (int w4 = 0; w4 < w_times + 1; w4++) { + if (h_mid == 0) { + elewise_res1 = zero; + elewise_res0 = zero; + elewise_res2 = zero; + } else { + elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); + elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); + elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); + } + input_buff_mid = vld2q_f32(input_row_ptr); + input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); + + elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); + elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); + elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); + + elewise_res1 = + vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); + elewise_res0 = + vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); + elewise_res2 = + vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); + + res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), + vaddq_f32(elewise_res0, elewise_res1)); + res3 = vaddq_f32(res3, vbias); + vst1q_f32(output_row_ptr, res3); + + input_row_ptr += 6; + output_row_ptr += 3; + } + } + clock(); + + input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; + output_row_ptr = output_data_tmp + 1 + h_mid * out_w; + + for (int w4 = 0; w4 < w_times + 1; w4++) { + elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); + elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); + elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); + + input_buff_mid = vld2q_f32(input_row_ptr); + input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); + + elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); + elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); + elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); + + if (!if_pad) { + elewise_res1 = + vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); + elewise_res0 = + vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); + elewise_res2 = + vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); + } + res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), + vaddq_f32(elewise_res0, elewise_res1)); + res3 = vaddq_f32(res3, vbias); + + if ((w4 != w_times)) { + vst1q_f32(output_row_ptr, res3); + } else { + if (out_l - 2 - w_times * 3 == 1) { + vst1q_lane_f32(output_row_ptr, res3, 0); + } else if (out_l - 2 - w_times * 3 == 2) { + vst1q_lane_f32(output_row_ptr, res3, 0); + vst1q_lane_f32(output_row_ptr + 1, res3, 1); + } + } + input_row_ptr += 6; + output_row_ptr += 3; + } + + output_data_tmp[0] = input_const[0] * w11 + input_const[1] * w12 + + input_const[in_l] * w21 + + input_const[in_l + 1] * w22; + + out2in_mid = (out_l - 1) * 2; + output_data_tmp[out_l - 1] = + w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + + w20 * input_const[out2in_mid + in_w - 1] + + w21 * input_const[out2in_mid + in_w] + + (1 - if_pad) * (w12 * input_const[out2in_mid + 1] + + w22 * input_const[out2in_mid + in_w + 1]); + + out2in_mid = (out_l - 1) * 2 * in_w; + + output_data_tmp[out_l * (out_l - 1)] = + w01 * input_const[out2in_mid - in_w] + + w02 * input_const[out2in_mid - in_w + 1] + + w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] + + (1 - if_pad) * (w21 * input_const[out2in_mid + in_w] + + w22 * input_const[out2in_mid + in_w + 1]); + out2in_mid = (out_l - 1) * 2 * in_w + (out_l - 1) * 2; + + output_data_tmp[out_l * out_l - 1] = + w00 * input_const[out2in_mid - in_w - 1] + + w01 * input_const[out2in_mid - in_w] + + w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + + (1 - if_pad) * (w20 * input_const[out2in_mid + in_w - 1] + + w21 * input_const[out2in_mid + in_w] + + w02 * input_const[out2in_mid - in_w + 1] + + w12 * input_const[out2in_mid + 1] + + w22 * input_const[out2in_mid + in_w + 1]); + if (if_bias) { + output_data_tmp[0] += bias_data[j]; + output_data_tmp[out_l - 1] += bias_data[j]; + output_data_tmp[out_l * (out_l - 1)] += bias_data[j]; + output_data_tmp[out_l * out_l - 1] += bias_data[j]; + } + for (int i = 1; i < out_h - 1; i++) { + out2in_mid = i * 2 * in_w; + output_data_tmp[i * out_l] = w01 * input_const[out2in_mid - in_w] + + w02 * input_const[out2in_mid - in_w + 1] + + w11 * input_const[out2in_mid] + + w12 * input_const[out2in_mid + 1] + + w21 * input_const[out2in_mid + in_w] + + w22 * input_const[out2in_mid + in_w + 1]; + + out2in_mid = i * 2 * in_w + (out_l - 1) * 2; + output_data_tmp[i * out_l + out_l - 1] = + w00 * input_const[out2in_mid - in_w - 1] + + w01 * input_const[out2in_mid - in_w] + + w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + + w20 * input_const[out2in_mid + in_w - 1] + + w21 * input_const[out2in_mid + in_w] + + (1 - if_pad) * (w02 * input_const[out2in_mid - in_w + 1] + + w12 * input_const[out2in_mid + 1] + + w22 * input_const[out2in_mid + in_w + 1]); + if (if_bias) { + output_data_tmp[i * out_l] += bias_data[j]; + output_data_tmp[i * out_l + out_l - 1] += bias_data[j]; + } + } + filter_data_tmp += 9; + } + input_data += inhxw * c; + output_data += outhxw * c; + } +} + +void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, + Tensor *output, const Tensor *new_scale, + const Tensor *new_bias, bool if_relu) { + const float *input_data = input->data(); + const float *filter_data = filter->data(); + float *output_data = output->data(); + const float *newscale_data = new_scale->data(); + const float *newbias_data = new_bias->data(); + + float32x4_t vnewbias = vdupq_n_f32(0.0); + float32x4_t vnewscale = vdupq_n_f32(1.0); + + const int in_h = static_cast(input->dims()[2]); + const int in_w = static_cast(input->dims()[3]); + const int out_h = static_cast(output->dims()[2]); + const int out_w = static_cast(output->dims()[3]); + const int out_l = out_h; + const int in_l = in_h; + const int inhxw = in_h * in_w; + const int outhxw = out_h * out_w; + const int if_pad = in_l - 1 == (out_l - 1) * 2 ? 1 : 0; + const int batch_size = static_cast(input->dims()[0]); + const int c = static_cast(input->dims()[1]); + const float *input_row_ptr; + float *output_row_ptr; + + const int w_times = (out_w - 2) / 3; + + float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1], + input_buff_top[w_times + 1]; + float32x4_t elewise_res0, elewise_res1, elewise_res2, res3; + int out2in_mid; + float32x4_t zero = vdupq_n_f32(0.0); + for (int b = batch_size; b > 0; --b) { + const float *filter_data_tmp = filter_data; + for (int j = 0; j < c; ++j) { + auto output_data_tmp = output_data + j * out_h * out_w; + auto input_data_tmp = input_data + j * in_h * in_w; + auto input_const = input_data_tmp; + + vnewbias = vdupq_n_f32(newbias_data[j]); + vnewscale = vdupq_n_f32(newscale_data[j]); + + float w00 = filter_data_tmp[0]; + float w01 = filter_data_tmp[1]; + float w02 = filter_data_tmp[2]; + float w10 = filter_data_tmp[3]; + float w11 = filter_data_tmp[4]; + float w12 = filter_data_tmp[5]; + float w20 = filter_data_tmp[6]; + float w21 = filter_data_tmp[7]; + float w22 = filter_data_tmp[8]; + + int h_mid = 0; + + for (; h_mid < out_h - 1; h_mid++) { + input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; + output_row_ptr = output_data_tmp + 1 + h_mid * out_w; + + for (int w4 = 0; w4 < w_times + 1; w4++) { + if (h_mid == 0) { + elewise_res1 = zero; + elewise_res0 = zero; + elewise_res2 = zero; + } else { + elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); + elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); + elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); + } + input_buff_mid = vld2q_f32(input_row_ptr); + input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); + + elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); + elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); + elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); + + elewise_res1 = + vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); + elewise_res0 = + vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); + elewise_res2 = + vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); + + res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), + vaddq_f32(elewise_res0, elewise_res1)); + res3 = vmlaq_f32(vnewbias, vnewscale, res3); + + if (if_relu) { + res3 = vmaxq_f32(res3, zero); + } + vst1q_f32(output_row_ptr, res3); + + input_row_ptr += 6; + output_row_ptr += 3; + } + } + clock(); + + input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; + output_row_ptr = output_data_tmp + 1 + h_mid * out_w; + + for (int w4 = 0; w4 < w_times + 1; w4++) { + elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); + elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); + elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); + + input_buff_mid = vld2q_f32(input_row_ptr); + input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); + + elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); + elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); + elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); + + if (!if_pad) { + elewise_res1 = + vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); + elewise_res0 = + vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); + elewise_res2 = + vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); + } + res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), + vaddq_f32(elewise_res0, elewise_res1)); + res3 = vmlaq_f32(vnewbias, vnewscale, res3); + + if (if_relu) { + res3 = vmaxq_f32(res3, zero); + } + if ((w4 != w_times)) { + vst1q_f32(output_row_ptr, res3); + } else { + if (out_l - 2 - w_times * 3 == 1) { + vst1q_lane_f32(output_row_ptr, res3, 0); + } else if (out_l - 2 - w_times * 3 == 2) { + vst1q_lane_f32(output_row_ptr, res3, 0); + vst1q_lane_f32(output_row_ptr + 1, res3, 1); + } + } + input_row_ptr += 6; + output_row_ptr += 3; + } + + output_data_tmp[0] = input_const[0] * w11 + input_const[1] * w12 + + input_const[in_l] * w21 + + input_const[in_l + 1] * w22; + + out2in_mid = (out_l - 1) * 2; + output_data_tmp[out_l - 1] = + w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + + w20 * input_const[out2in_mid + in_w - 1] + + w21 * input_const[out2in_mid + in_w] + + (1 - if_pad) * (w12 * input_const[out2in_mid + 1] + + w22 * input_const[out2in_mid + in_w + 1]); + + out2in_mid = (out_l - 1) * 2 * in_w; + + output_data_tmp[out_l * (out_l - 1)] = + w01 * input_const[out2in_mid - in_w] + + w02 * input_const[out2in_mid - in_w + 1] + + w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] + + (1 - if_pad) * (w21 * input_const[out2in_mid + in_w] + + w22 * input_const[out2in_mid + in_w + 1]); + out2in_mid = (out_l - 1) * 2 * in_w + (out_l - 1) * 2; + + output_data_tmp[out_l * out_l - 1] = + w00 * input_const[out2in_mid - in_w - 1] + + w01 * input_const[out2in_mid - in_w] + + w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + + (1 - if_pad) * (w20 * input_const[out2in_mid + in_w - 1] + + w21 * input_const[out2in_mid + in_w] + + w02 * input_const[out2in_mid - in_w + 1] + + w12 * input_const[out2in_mid + 1] + + w22 * input_const[out2in_mid + in_w + 1]); + output_data_tmp[0] = + output_data_tmp[0] * newscale_data[j] + newbias_data[j]; + output_data_tmp[out_l - 1] = + output_data_tmp[out_l - 1] * newscale_data[j] + newbias_data[j]; + output_data_tmp[out_l * (out_l - 1)] = + output_data_tmp[out_l * (out_l - 1)] * newscale_data[j] + + newbias_data[j]; + output_data_tmp[out_l * out_l - 1] = + output_data_tmp[out_l * out_l - 1] * newscale_data[j] + + newbias_data[j]; + if (if_relu) { + output_data_tmp[0] = output_data_tmp[0] < 0 ? 0 : output_data_tmp[0]; + output_data_tmp[out_l - 1] = + output_data_tmp[out_l - 1] < 0 ? 0 : output_data_tmp[out_l - 1]; + output_data_tmp[out_l * (out_l - 1)] = + output_data_tmp[out_l * (out_l - 1)] < 0 + ? 0 + : output_data_tmp[out_l * (out_l - 1)]; + output_data_tmp[out_l * out_l - 1] = + output_data_tmp[out_l * out_l - 1] < 0 + ? 0 + : output_data_tmp[out_l * out_l - 1]; + } + for (int i = 1; i < out_h - 1; i++) { + out2in_mid = i * 2 * in_w; + output_data_tmp[i * out_l] = w01 * input_const[out2in_mid - in_w] + + w02 * input_const[out2in_mid - in_w + 1] + + w11 * input_const[out2in_mid] + + w12 * input_const[out2in_mid + 1] + + w21 * input_const[out2in_mid + in_w] + + w22 * input_const[out2in_mid + in_w + 1]; + + out2in_mid = i * 2 * in_w + (out_l - 1) * 2; + output_data_tmp[i * out_l + out_l - 1] = + w00 * input_const[out2in_mid - in_w - 1] + + w01 * input_const[out2in_mid - in_w] + + w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + + w20 * input_const[out2in_mid + in_w - 1] + + w21 * input_const[out2in_mid + in_w] + + (1 - if_pad) * (w02 * input_const[out2in_mid - in_w + 1] + + w12 * input_const[out2in_mid + 1] + + w22 * input_const[out2in_mid + in_w + 1]); + output_data_tmp[i * out_l] = + output_data_tmp[i * out_l] * newscale_data[j] + newbias_data[j]; + output_data_tmp[i * out_l + out_l - 1] = + output_data_tmp[i * out_l + out_l - 1] * newscale_data[j] + + newbias_data[j]; + if (if_relu) { + output_data_tmp[i * out_l] = + output_data_tmp[i * out_l] < 0 ? 0 : output_data_tmp[i * out_l]; + output_data_tmp[i * out_l + out_l - 1] = + output_data_tmp[i * out_l + out_l - 1] < 0 + ? 0 + : output_data_tmp[i * out_l + out_l - 1]; + } + } + filter_data_tmp += 9; + } + input_data += inhxw * c; + output_data += outhxw * c; + } +} + } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/depthwise_conv_3x3.h b/src/operators/math/depthwise_conv_3x3.h index b5103a53ad..60e979648f 100644 --- a/src/operators/math/depthwise_conv_3x3.h +++ b/src/operators/math/depthwise_conv_3x3.h @@ -38,6 +38,11 @@ void DepthwiseConvAddBNRelu3x3s1p1(const Tensor *input, const Tensor *filter, void DepthwiseConvAddBNRelu3x3s2p1(const Tensor *input, const Tensor *filter, Tensor *output, const Tensor *new_scale, const Tensor *new_bias, bool if_relu); +void DepthwiseConv3x3s2p1v2(const Tensor *input, const Tensor *filter, + Tensor *output, Tensor bias, bool if_bias); +void DepthwiseConvAddBNRelu3x3s2p1v2(const Tensor *input, const Tensor *filter, + Tensor *output, const Tensor *new_scale, + const Tensor *new_bias, bool if_relu); } // namespace math } // namespace operators } // namespace paddle_mobile -- GitLab