diff --git a/src/operators/kernel/central-arm-func/conv_add_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_arm_func.h index 0e8ce8d640053eba5f92a5ee20776e85589981f9..6e5b467c9285c9a752b201c253080990d413893d 100644 --- a/src/operators/kernel/central-arm-func/conv_add_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_arm_func.h @@ -118,8 +118,7 @@ void ConvAddCompute(const FusionConvAddParam &param) { if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && - param.Input()->dims()[2] == param.Input()->dims()[3]) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), param.Bias(), true); } else if (param.Groups() == param.Input()->dims()[1] && diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index 25cb5ffff552052e8f58f19b1d65bad597807e19..39ef81bd15cf6313dfde2ac16e5c5d5303393b7d 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -124,8 +124,7 @@ void ConvCompute(const ConvParam &param) { if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && - param.Input()->dims()[2] == param.Input()->dims()[3]) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), nullptr, false); } else if (param.Groups() == param.Input()->dims()[1] && diff --git a/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h b/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h index 
c8e969b854710f3b01eec6ebb4c38dd08d25fade..d039786d1febe4a8c63df98f1732ae0f9de98474 100644 --- a/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h @@ -30,8 +30,7 @@ void DepthwiseConvCompute(const ConvParam &param) { Bias.mutable_data<float>({param.Groups()}); if (param.Groups() == param.Input()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && - param.Input()->dims()[2] == param.Input()->dims()[3]) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), &Bias, false); } else if (param.Groups() == param.Input()->dims()[1] && diff --git a/src/operators/math/depthwise_conv_3x3.cpp b/src/operators/math/depthwise_conv_3x3.cpp index b6cf28a9ca665a1496ee8032f87c013137deade8..adaa6d2d9002892c7f563c1bee257a62a68592fb 100644 --- a/src/operators/math/depthwise_conv_3x3.cpp +++ b/src/operators/math/depthwise_conv_3x3.cpp @@ -257,8 +257,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, const int h = static_cast<int>(input->dims()[2]); const int w = static_cast<int>(input->dims()[3]); - const int l = h; - + // const int l = h; const int batch_size = static_cast<int>(input->dims()[0]); const int c = static_cast<int>(input->dims()[1]); const int hxw = h * w; @@ -271,7 +270,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, vbias = vdupq_n_f32(bias_data[j]); } - int l_mid = l - 2; // l=1->l_mid=-1,l=2->l_mid=0 + int w_mid = w - 2; // w=1->w_mid=-1,w=2->w_mid=0 float w00 = filter_data_tmp[0]; float w01 = filter_data_tmp[1]; float w02 = filter_data_tmp[2]; @@ -283,39 +282,38 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, float w22 = filter_data_tmp[8]; output_data[0] = w11 * input_data[0] + w12 * input_data[1] + - w21 * input_data[l] + w22 * input_data[l + 1]; - output_data[l - 1] = w10 * 
input_data[l - 2] + w11 * input_data[l - 1] + - w20 * input_data[2 * l - 2] + - w21 * input_data[2 * l - 1]; - output_data[(l - 1) * l] = - w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] + - w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1]; - output_data[l * l - 1] = w00 * input_data[(l - 2) * (l + 1)] + - w01 * input_data[(l - 2) * (l + 1) + 1] + - w10 * input_data[l * l - 2] + - w11 * input_data[l * l - 1]; + w21 * input_data[w] + w22 * input_data[w + 1]; + output_data[w - 1] = w10 * input_data[w - 2] + w11 * input_data[w - 1] + + w20 * input_data[2 * w - 2] + + w21 * input_data[2 * w - 1]; + output_data[(h - 1) * w] = + w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w + 1] + + w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1]; + output_data[h * w - 1] = + w00 * input_data[h * w - w - 2] + w01 * input_data[h * w - w - 1] + + w10 * input_data[h * w - 2] + w11 * input_data[h * w - 1]; if (if_bias) { output_data[0] += bias_data[j]; - output_data[l - 1] += bias_data[j]; - output_data[(l - 1) * l] += bias_data[j]; - output_data[l * l - 1] += bias_data[j]; + output_data[w - 1] += bias_data[j]; + output_data[(h - 1) * w] += bias_data[j]; + output_data[h * w - 1] += bias_data[j]; } - for (int i = 1; i < l - 1; ++i) { - output_data[i * l] = - w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] + - w11 * input_data[i * l] + w12 * input_data[i * l + 1] + - w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1]; - - output_data[i * l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] + - w01 * input_data[i * l + l - 1 - l] + - w10 * input_data[i * l + l - 1 - 1] + - w11 * input_data[i * l + l - 1] + - w20 * input_data[i * l + l - 1 + l - 1] + - w21 * input_data[i * l + l - 1 + l]; + for (int i = 1; i < h - 1; ++i) { + output_data[i * w] = + w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1] + + w11 * input_data[i * w] + w12 * input_data[i * w + 1] + + w21 * input_data[i * w + w] + w22 
* input_data[i * w + w + 1]; + + output_data[i * w + w - 1] = w00 * input_data[i * w + w - 1 - w - 1] + + w01 * input_data[i * w + w - 1 - w] + + w10 * input_data[i * w + w - 1 - 1] + + w11 * input_data[i * w + w - 1] + + w20 * input_data[i * w + w - 1 + w - 1] + + w21 * input_data[i * w + w - 1 + w]; if (if_bias) { - output_data[i * l] += bias_data[j]; - output_data[i * l + l - 1] += bias_data[j]; + output_data[i * w] += bias_data[j]; + output_data[i * w + w - 1] += bias_data[j]; } } @@ -325,15 +323,15 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, out0; in0 = vld1q_f32(input_tmp); - in2 = vld1q_f32(input_tmp + l); - const float *input_tmp_end = input_tmp + (l - 2) * l; + in2 = vld1q_f32(input_tmp + w); + const float *input_tmp_end = input_tmp + (h - 2) * w; in4 = vld1q_f32(input_tmp_end); - in6 = vld1q_f32(input_tmp_end + l); - int c_mid = l_mid; + in6 = vld1q_f32(input_tmp_end + w); + int c_mid = w_mid; auto output_ptr = output_data + 1; for (; c_mid > 3; c_mid -= 4) { in1 = vld1q_f32(input_tmp + 4); - in3 = vld1q_f32(input_tmp + l + 4); + in3 = vld1q_f32(input_tmp + w + 4); tmp0 = vextq_f32(in0, in1, 1); tmp1 = vextq_f32(in0, in1, 2); @@ -352,7 +350,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, vst1q_f32(output_ptr, out0); in5 = vld1q_f32(input_tmp_end + 4); - in7 = vld1q_f32(input_tmp_end + l + 4); + in7 = vld1q_f32(input_tmp_end + w + 4); tmp0 = vextq_f32(in4, in5, 1); tmp1 = vextq_f32(in4, in5, 2); @@ -367,7 +365,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, out0 = vmlaq_n_f32(out0, tmp3, w12); out0 = vaddq_f32(out0, vbias); - vst1q_f32(output_ptr + (l - 1) * l, out0); + vst1q_f32(output_ptr + (h - 1) * w, out0); // can optimize to each 8 stride. 
input_tmp += 4; @@ -380,8 +378,8 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, } // top right pad - float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]); - float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]); + float32x4_t pad0 = vdupq_n_f32(input_data[w - 1]); + float32x4_t pad1 = vdupq_n_f32(input_data[2 * w - 1]); tmp0 = vextq_f32(in0, pad0, 1); tmp1 = vextq_f32(in0, pad0, 2); @@ -409,8 +407,8 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, } // bottom right pad - float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]); - float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]); + float32x4_t pad2 = vdupq_n_f32(input_data[h * w - 1 - w]); + float32x4_t pad3 = vdupq_n_f32(input_data[h * w - 1]); tmp0 = vextq_f32(in4, pad2, 1); tmp1 = vextq_f32(in4, pad2, 2); @@ -427,28 +425,28 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, for (int i = 0; i < c_mid; ++i) { if (i == 0) { - vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0); + vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 0); } if (i == 1) { - vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1); + vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 1); } if (i == 2) { - vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2); + vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 2); } } // mid - for (int i = 0; i < l - 2; ++i) { - auto output_ptr = output_data + (i + 1) * l + 1; - input_tmp = input_data + i * l; + for (int i = 0; i < h - 2; ++i) { + auto output_ptr = output_data + (i + 1) * w + 1; + input_tmp = input_data + i * w; auto in0_tmp = vld1q_f32(input_tmp); - auto in2_tmp = vld1q_f32(input_tmp + l); - auto in4_tmp = vld1q_f32(input_tmp + l + l); - c_mid = l_mid; + auto in2_tmp = vld1q_f32(input_tmp + w); + auto in4_tmp = vld1q_f32(input_tmp + w + w); + c_mid = w_mid; for (; c_mid > 3; c_mid -= 4) { auto in1_tmp = vld1q_f32(input_tmp + 4); - auto in3_tmp = vld1q_f32(input_tmp + l + 4); - auto in5_tmp = vld1q_f32(input_tmp + l + 
l + 4); + auto in3_tmp = vld1q_f32(input_tmp + w + 4); + auto in5_tmp = vld1q_f32(input_tmp + w + w + 4); tmp0 = vextq_f32(in0_tmp, in1_tmp, 1); tmp1 = vextq_f32(in0_tmp, in1_tmp, 2); @@ -477,9 +475,9 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, in4_tmp = in5_tmp; } - float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]); - float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]); - float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]); + float32x4_t pad0 = vdupq_n_f32(input_data[i * w + w - 1]); + float32x4_t pad1 = vdupq_n_f32(input_data[i * w + w - 1 + w]); + float32x4_t pad2 = vdupq_n_f32(input_data[i * w + w - 1 + w + w]); tmp0 = vextq_f32(in0_tmp, pad0, 1); tmp1 = vextq_f32(in0_tmp, pad0, 2);