diff --git a/src/operators/kernel/central-arm-func/conv_add_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_arm_func.h index d71bc235977236fbd0dd332df556ea4bd41eacf4..6e5b467c9285c9a752b201c253080990d413893d 100644 --- a/src/operators/kernel/central-arm-func/conv_add_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_arm_func.h @@ -124,7 +124,8 @@ void ConvAddCompute(const FusionConvAddParam &param) { } else if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 && + param.Input()->dims()[2] == param.Input()->dims()[3]) { // math::DepthwiseConv3x3(param.Input(), param.Strides(), // param.Paddings(), // param.Filter(), param.Bias(), diff --git a/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h index a7d14fbad1e4b72a8571d13898e55a6cad8bf9a8..06c63c4a8d62d886f25465048faf6c109df0eafd 100644 --- a/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h @@ -118,14 +118,16 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) { if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && + param.Input()->dims()[2] == param.Input()->dims()[3]) { math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), param.Output(), param.NewScale(), param.NewBias(), true); } else if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == 
param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 && + param.Input()->dims()[2] == param.Input()->dims()[3]) { // math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(), // param.Output(), param.NewScale(), // param.NewBias(), 1); diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index e7a8c7f52db327f3ff5871566c3557c484ba4d13..39ef81bd15cf6313dfde2ac16e5c5d5303393b7d 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -130,7 +130,8 @@ void ConvCompute(const ConvParam &param) { } else if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3) { + param.Filter()->dims()[2] == 3 && + param.Input()->dims()[2] == param.Input()->dims()[3]) { math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), param.Filter(), nullptr, param.Output(), false); } else { diff --git a/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h index 7c31eed19693d20084e25daa485a0553d5d795f2..186f77a4cee42fba9a80d11f20f2f6fa6e2132eb 100644 --- a/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h @@ -122,14 +122,16 @@ void ConvBNAddReluCompute(const FusionConvBNAddReluParam &param) { if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && + param.Input()->dims()[2] == param.Input()->dims()[3]) { 
math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), param.Output(), param.NewScale(), param.NewBias(), true); } else if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 && + param.Input()->dims()[2] == param.Input()->dims()[3]) { // math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(), // param.Output(), param.NewScale(), // param.NewBias(), 1); diff --git a/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h index c6300f96e1b999c45538417c7b513068697ad4dd..27fe0a8a014ff11f96017cad3acc7557cbde5583 100644 --- a/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_bn_relu_arm_func.h @@ -117,14 +117,16 @@ void ConvBNReluCompute(const FusionConvBNReluParam &param) { if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && + param.Input()->dims()[2] == param.Input()->dims()[3]) { math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), param.Output(), param.NewScale(), param.NewBias(), true); } else if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 && + param.Input()->dims()[2] == param.Input()->dims()[3]) { // math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(), // param.Output(), param.NewScale(), 
// param.NewBias(), 1); diff --git a/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h b/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h index 73170bdab922a46831334307aebc8af210ddfb73..d039786d1febe4a8c63df98f1732ae0f9de98474 100644 --- a/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/depthwise_conv_arm_func.h @@ -36,7 +36,8 @@ void DepthwiseConvCompute(const ConvParam &param) { } else if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 && + param.Input()->dims()[2] == param.Input()->dims()[3]) { // math::DepthwiseConv3x3(param.Input(), param.Strides(), // param.Paddings(), // param.Filter(), &Bias, param.Output(), false); diff --git a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h index b60bf9b4d6df9d85cc2fbe378a3904c2d13e5e60..a9b2668b7be4ceb8717621b14aff4e58c81053de 100644 --- a/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h +++ b/src/operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h @@ -115,14 +115,16 @@ void DWConvBNReluCompute(const FusionDWConvBNReluParam &param) { if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 1 && + param.Input()->dims()[2] == param.Input()->dims()[3]) { math::DepthwiseConvAddBNRelu3x3s1p1(param.Input(), param.Filter(), param.Output(), param.NewScale(), param.NewBias(), true); } else if (param.Groups() == param.Input()->dims()[1] && param.Input()->dims()[1] == 
param.Output()->dims()[1] && param.Filter()->dims()[2] == param.Filter()->dims()[3] && - param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) { + param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2 && + param.Input()->dims()[2] == param.Input()->dims()[3]) { // math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(), // param.Output(), param.NewScale(), // param.NewBias(), 1); diff --git a/src/operators/math/depthwise_conv_3x3.cpp b/src/operators/math/depthwise_conv_3x3.cpp index b6cf28a9ca665a1496ee8032f87c013137deade8..adaa6d2d9002892c7f563c1bee257a62a68592fb 100644 --- a/src/operators/math/depthwise_conv_3x3.cpp +++ b/src/operators/math/depthwise_conv_3x3.cpp @@ -257,8 +257,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, const int h = static_cast<int>(input->dims()[2]); const int w = static_cast<int>(input->dims()[3]); - const int l = h; - + // const int l = h; const int batch_size = static_cast<int>(input->dims()[0]); const int c = static_cast<int>(input->dims()[1]); const int hxw = h * w; @@ -271,7 +270,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, vbias = vdupq_n_f32(bias_data[j]); } - int l_mid = l - 2; // l=1->l_mid=-1,l=2->l_mid=0 + int w_mid = w - 2; // w=1->w_mid=-1,w=2->w_mid=0 float w00 = filter_data_tmp[0]; float w01 = filter_data_tmp[1]; float w02 = filter_data_tmp[2]; @@ -283,39 +282,38 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, float w22 = filter_data_tmp[8]; output_data[0] = w11 * input_data[0] + w12 * input_data[1] + - w21 * input_data[l] + w22 * input_data[l + 1]; - output_data[l - 1] = w10 * input_data[l - 2] + w11 * input_data[l - 1] + - w20 * input_data[2 * l - 2] + - w21 * input_data[2 * l - 1]; - output_data[(l - 1) * l] = - w01 * input_data[(l - 2) * l] + w02 * input_data[(l - 2) * l + 1] + - w11 * input_data[(l - 1) * l] + w12 * input_data[(l - 1) * l + 1]; - output_data[l * l - 1] = w00 * input_data[(l - 2) * (l + 1)] + - w01 * input_data[(l - 2) * (l + 
1) + 1] + - w10 * input_data[l * l - 2] + - w11 * input_data[l * l - 1]; + w21 * input_data[w] + w22 * input_data[w + 1]; + output_data[w - 1] = w10 * input_data[w - 2] + w11 * input_data[w - 1] + + w20 * input_data[2 * w - 2] + + w21 * input_data[2 * w - 1]; + output_data[(h - 1) * w] = + w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w + 1] + + w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1]; + output_data[h * w - 1] = + w00 * input_data[h * w - w - 2] + w01 * input_data[h * w - w - 1] + + w10 * input_data[h * w - 2] + w11 * input_data[h * w - 1]; if (if_bias) { output_data[0] += bias_data[j]; - output_data[l - 1] += bias_data[j]; - output_data[(l - 1) * l] += bias_data[j]; - output_data[l * l - 1] += bias_data[j]; + output_data[w - 1] += bias_data[j]; + output_data[(h - 1) * w] += bias_data[j]; + output_data[h * w - 1] += bias_data[j]; } - for (int i = 1; i < l - 1; ++i) { - output_data[i * l] = - w01 * input_data[i * l - l] + w02 * input_data[i * l - l + 1] + - w11 * input_data[i * l] + w12 * input_data[i * l + 1] + - w21 * input_data[i * l + l] + w22 * input_data[i * l + l + 1]; - - output_data[i * l + l - 1] = w00 * input_data[i * l + l - 1 - l - 1] + - w01 * input_data[i * l + l - 1 - l] + - w10 * input_data[i * l + l - 1 - 1] + - w11 * input_data[i * l + l - 1] + - w20 * input_data[i * l + l - 1 + l - 1] + - w21 * input_data[i * l + l - 1 + l]; + for (int i = 1; i < h - 1; ++i) { + output_data[i * w] = + w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1] + + w11 * input_data[i * w] + w12 * input_data[i * w + 1] + + w21 * input_data[i * w + w] + w22 * input_data[i * w + w + 1]; + + output_data[i * w + w - 1] = w00 * input_data[i * w + w - 1 - w - 1] + + w01 * input_data[i * w + w - 1 - w] + + w10 * input_data[i * w + w - 1 - 1] + + w11 * input_data[i * w + w - 1] + + w20 * input_data[i * w + w - 1 + w - 1] + + w21 * input_data[i * w + w - 1 + w]; if (if_bias) { - output_data[i * l] += bias_data[j]; - output_data[i * l 
+ l - 1] += bias_data[j]; + output_data[i * w] += bias_data[j]; + output_data[i * w + w - 1] += bias_data[j]; } } @@ -325,15 +323,15 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, out0; in0 = vld1q_f32(input_tmp); - in2 = vld1q_f32(input_tmp + l); - const float *input_tmp_end = input_tmp + (l - 2) * l; + in2 = vld1q_f32(input_tmp + w); + const float *input_tmp_end = input_tmp + (h - 2) * w; in4 = vld1q_f32(input_tmp_end); - in6 = vld1q_f32(input_tmp_end + l); - int c_mid = l_mid; + in6 = vld1q_f32(input_tmp_end + w); + int c_mid = w_mid; auto output_ptr = output_data + 1; for (; c_mid > 3; c_mid -= 4) { in1 = vld1q_f32(input_tmp + 4); - in3 = vld1q_f32(input_tmp + l + 4); + in3 = vld1q_f32(input_tmp + w + 4); tmp0 = vextq_f32(in0, in1, 1); tmp1 = vextq_f32(in0, in1, 2); @@ -352,7 +350,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, vst1q_f32(output_ptr, out0); in5 = vld1q_f32(input_tmp_end + 4); - in7 = vld1q_f32(input_tmp_end + l + 4); + in7 = vld1q_f32(input_tmp_end + w + 4); tmp0 = vextq_f32(in4, in5, 1); tmp1 = vextq_f32(in4, in5, 2); @@ -367,7 +365,7 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, out0 = vmlaq_n_f32(out0, tmp3, w12); out0 = vaddq_f32(out0, vbias); - vst1q_f32(output_ptr + (l - 1) * l, out0); + vst1q_f32(output_ptr + (h - 1) * w, out0); // can optimize to each 8 stride. 
input_tmp += 4; @@ -380,8 +378,8 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, } // top right pad - float32x4_t pad0 = vdupq_n_f32(input_data[l - 1]); - float32x4_t pad1 = vdupq_n_f32(input_data[2 * l - 1]); + float32x4_t pad0 = vdupq_n_f32(input_data[w - 1]); + float32x4_t pad1 = vdupq_n_f32(input_data[2 * w - 1]); tmp0 = vextq_f32(in0, pad0, 1); tmp1 = vextq_f32(in0, pad0, 2); @@ -409,8 +407,8 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, } // bottom right pad - float32x4_t pad2 = vdupq_n_f32(input_data[l * l - 1 - l]); - float32x4_t pad3 = vdupq_n_f32(input_data[l * l - 1]); + float32x4_t pad2 = vdupq_n_f32(input_data[h * w - 1 - w]); + float32x4_t pad3 = vdupq_n_f32(input_data[h * w - 1]); tmp0 = vextq_f32(in4, pad2, 1); tmp1 = vextq_f32(in4, pad2, 2); @@ -427,28 +425,28 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, for (int i = 0; i < c_mid; ++i) { if (i == 0) { - vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 0); + vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 0); } if (i == 1) { - vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 1); + vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 1); } if (i == 2) { - vst1q_lane_f32(output_ptr + (l - 1) * l + i, out0, 2); + vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 2); } } // mid - for (int i = 0; i < l - 2; ++i) { - auto output_ptr = output_data + (i + 1) * l + 1; - input_tmp = input_data + i * l; + for (int i = 0; i < h - 2; ++i) { + auto output_ptr = output_data + (i + 1) * w + 1; + input_tmp = input_data + i * w; auto in0_tmp = vld1q_f32(input_tmp); - auto in2_tmp = vld1q_f32(input_tmp + l); - auto in4_tmp = vld1q_f32(input_tmp + l + l); - c_mid = l_mid; + auto in2_tmp = vld1q_f32(input_tmp + w); + auto in4_tmp = vld1q_f32(input_tmp + w + w); + c_mid = w_mid; for (; c_mid > 3; c_mid -= 4) { auto in1_tmp = vld1q_f32(input_tmp + 4); - auto in3_tmp = vld1q_f32(input_tmp + l + 4); - auto in5_tmp = vld1q_f32(input_tmp + l + 
l + 4); + auto in3_tmp = vld1q_f32(input_tmp + w + 4); + auto in5_tmp = vld1q_f32(input_tmp + w + w + 4); tmp0 = vextq_f32(in0_tmp, in1_tmp, 1); tmp1 = vextq_f32(in0_tmp, in1_tmp, 2); @@ -477,9 +475,9 @@ void DepthwiseConv3x3s1p1(const Tensor *input, const Tensor *filter, in4_tmp = in5_tmp; } - float32x4_t pad0 = vdupq_n_f32(input_data[i * l + l - 1]); - float32x4_t pad1 = vdupq_n_f32(input_data[i * l + l - 1 + l]); - float32x4_t pad2 = vdupq_n_f32(input_data[i * l + l - 1 + l + l]); + float32x4_t pad0 = vdupq_n_f32(input_data[i * w + w - 1]); + float32x4_t pad1 = vdupq_n_f32(input_data[i * w + w - 1 + w]); + float32x4_t pad2 = vdupq_n_f32(input_data[i * w + w - 1 + w + w]); tmp0 = vextq_f32(in0_tmp, pad0, 1); tmp1 = vextq_f32(in0_tmp, pad0, 2);