From 3582b9a028dddff1bdda7e165b2bc321edba00ed Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Fri, 8 Mar 2019 11:44:39 +0800 Subject: [PATCH] Add depthwise conv5x5 armv8 implementation --- .../convolution/conv_add_bn_relu_kernel.cpp | 2 - .../arm/convolution/conv_add_kernel.cpp | 2 - .../arm/convolution/conv_add_relu_kernel.cpp | 2 - .../convolution/conv_bn_add_relu_kernel.cpp | 2 - .../arm/convolution/conv_bn_relu_kernel.cpp | 2 - .../kernel/arm/convolution/conv_common.cpp | 2 - .../kernel/arm/convolution/conv_kernel.cpp | 2 - .../kernel/central-arm-func/conv_arm_func.h | 3 +- src/operators/math/depthwise_conv5x5.cpp | 373 +++++++++++++++++- .../math/winograd/winograd_transform_f6k3.cpp | 34 +- 10 files changed, 390 insertions(+), 34 deletions(-) diff --git a/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp index 91e901b89f..ae67147ffd 100644 --- a/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp @@ -73,13 +73,11 @@ void ConvAddBNReluKernel::Compute( math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; -#ifndef __aarch64__ case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: DepthwiseConv5x5(param); math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; -#endif // __aarch64__ case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); math::ScaleAddChannelWise(param.Output(), param.NewScale(), diff --git a/src/operators/kernel/arm/convolution/conv_add_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_kernel.cpp index 9de6e333e7..76c2200df3 100644 --- a/src/operators/kernel/arm/convolution/conv_add_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_add_kernel.cpp @@ -43,13 +43,11 @@ void ConvAddKernel::Compute(const FusionConvAddParam ¶m) { math::AddChannelWise(param.Output(), param.Bias(), param.Output()); break; -#ifndef __aarch64__ case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: DepthwiseConv5x5(param); math::AddChannelWise(param.Output(), param.Bias(), param.Output()); break; -#endif // __aarch64__ case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); math::AddChannelWise(param.Output(), param.Bias(), diff --git a/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp index 3f8a85e74b..e0387f6444 100644 --- a/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp @@ -41,12 +41,10 @@ void ConvAddReluKernel::Compute( param.Paddings(), param.Output()); math::AddChannelWise(param.Output(), param.Bias(), param.Output()); break; -#ifndef __aarch64__ case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: DepthwiseConv5x5(param); math::AddChannelWise(param.Output(), param.Bias(), param.Output()); break; -#endif // __aarch64__ case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); math::AddChannelWise(param.Output(), param.Bias(), param.Output()); diff --git a/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp index a46c66c275..f591833887 100644 --- a/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp @@ -73,13 +73,11 @@ void ConvBNAddReluKernel::Compute( math::ScaleAddChannelWise(param.Output(), 
param.NewScale(), param.NewBias(), param.Output()); break; -#ifndef __aarch64__ case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: DepthwiseConv5x5(param); math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; -#endif // __aarch64__ case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); math::ScaleAddChannelWise(param.Output(), param.NewScale(), diff --git a/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp index b49120e740..352df8a389 100644 --- a/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp @@ -72,13 +72,11 @@ void ConvBNReluKernel::Compute( math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; -#ifndef __aarch64__ case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: DepthwiseConv5x5(param); math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; -#endif // __aarch64__ case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); math::ScaleAddChannelWise(param.Output(), param.NewScale(), diff --git a/src/operators/kernel/arm/convolution/conv_common.cpp b/src/operators/kernel/arm/convolution/conv_common.cpp index 8981715351..8db3b36cf4 100644 --- a/src/operators/kernel/arm/convolution/conv_common.cpp +++ b/src/operators/kernel/arm/convolution/conv_common.cpp @@ -49,11 +49,9 @@ void InitBaseConvKernel(ConvParam *param) { } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] && param->Strides()[0] == 2) { param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3S2_FLOAT; -#ifndef __aarch64__ } else if (depth5x5 && param->Strides()[0] == param->Strides()[1] && param->Strides()[0] == 1) { param->ExecMode() = ConvParam::EXEC_DEPTHWISE5x5_FLOAT; -#endif } else if (conv3x3 && !depth3x3 && param->Strides()[0] == param->Strides()[1] && param->Dilations()[0] == param->Dilations()[1] && diff --git a/src/operators/kernel/arm/convolution/conv_kernel.cpp b/src/operators/kernel/arm/convolution/conv_kernel.cpp index 97c153fa28..6771b88d4b 100644 --- a/src/operators/kernel/arm/convolution/conv_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_kernel.cpp @@ -51,11 +51,9 @@ void ConvKernel::Compute(const ConvParam ¶m) { math::DepthwiseConv3x3S2(*param.Input(), *param.Filter(), param.Paddings(), param.Output()); break; -#ifndef __aarch64__ case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: DepthwiseConv5x5(param); break; -#endif // __aarch64__ case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); break; diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index 7170c3ff4d..bf7e3b0abc 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -164,6 +164,7 @@ inline void WinogradConv3x3(const ConvParam ¶m) { } #ifndef __aarch64__ +// int8 DepthwiseConv3x3 template inline void DepthwiseConv3x3(const ConvParam ¶m) { const Tensor *input = param.Input(); @@ -188,6 +189,7 @@ inline void DepthwiseConv3x3(const ConvParam ¶m) { } } } +#endif // __aarch64__ template inline void DepthwiseConv5x5(const ConvParam ¶m) { @@ -210,7 +212,6 @@ inline void DepthwiseConv5x5(const ConvParam ¶m) { GemmConv(param); } } -#endif // __aarch64__ template void ConvAddReluBasic(const ParamType ¶m) { diff --git a/src/operators/math/depthwise_conv5x5.cpp 
b/src/operators/math/depthwise_conv5x5.cpp index 792a98659e..2c0185b236 100644 --- a/src/operators/math/depthwise_conv5x5.cpp +++ b/src/operators/math/depthwise_conv5x5.cpp @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#if defined(__ARM_NEON__) && !defined(__aarch64__) +#if defined(__ARM_NEON__) || defined(__ARM_NEON) #include "operators/math/depthwise_conv5x5.h" #include @@ -243,7 +243,224 @@ void DepthwiseConv5x5S1(const framework::Tensor &input, output_ptr0 += valid_w_start; output_ptr1 += valid_w_start; } - // valid + // valid +// #if __aarch64__ +#if 0 + float32x4_t _q14, _q15; + for (int loop = 0; loop = output_w_tiles; ++loop) { + float32x4_t _q7 = vld1q_f32(input_ptr0); + float32x4_t _q8 = vld1q_f32(input_ptr0 + 4); + float32x4_t _q9 = vld1q_f32(input_ptr1); + float32x4_t _q10 = vld1q_f32(input_ptr1 + 4); + float32x4_t _q11 = vld1q_f32(input_ptr2); + float32x4_t _q12 = vld1q_f32(input_ptr2 + 4); + + _q14 = vmulq_lane_f32(_q7, vget_low_f32(_ker[5]), 0); + float32x4_t _q13 = vextq_f32(_q7, _q8, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 0); + _q13 = vextq_f32(_q7, _q8, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 1); + _q13 = vextq_f32(_q7, _q8, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[0]), 0); + _q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[0]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[5]), 1); + _q15 = vmulq_lane_f32(_q9, vget_low_f32(_ker[5]), 0); + _q13 = vextq_f32(_q9, _q10, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[0]), 0); + _q13 = vextq_f32(_q9, _q10, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 1); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[0]), 1); + _q13 = vextq_f32(_q9, _q10, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[1]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[0]), 0); + _q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[1]), 1); + _q15 = vmlaq_lane_f32(_q15, _q10, vget_high_f32(_ker[0]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q11, vget_high_f32(_ker[5]), 0); + _q15 = vmlaq_lane_f32(_q15, _q11, vget_low_f32(_ker[5]), 1); + _q13 = vextq_f32(_q11, _q12, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[1]), 0); + _q13 = vextq_f32(_q11, _q12, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 1); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[1]), 1); + _q13 = vextq_f32(_q11, _q12, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[2]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[1]), 0); + _q14 = vmlaq_lane_f32(_q14, _q12, vget_high_f32(_ker[2]), 1); + _q15 = vmlaq_lane_f32(_q15, _q12, vget_high_f32(_ker[1]), 1); + + _q7 = vld1q_f32(input_ptr3); + _q8 = vld1q_f32(input_ptr3 + 4); + _q9 = vld1q_f32(input_ptr4); + _q10 = vld1q_f32(input_ptr4 + 4); + _q11 = vld1q_f32(input_ptr5); + _q12 = vld1q_f32(input_ptr5 + 4); + + _q14 = vmlaq_lane_f32(_q14, _q7, vget_high_f32(_ker[5]), 1); + _q15 = vmlaq_lane_f32(_q15, _q7, vget_high_f32(_ker[5]), 0); + _q13 = vextq_f32(_q7, _q8, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[2]), 0); + _q13 = vextq_f32(_q7, _q8, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 1); + _q15 = 
vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[2]), 1); + _q13 = vextq_f32(_q7, _q8, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[3]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[2]), 0); + _q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[3]), 1); + _q15 = vmlaq_lane_f32(_q15, _q8, vget_high_f32(_ker[2]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[6]), 0); + _q15 = vmlaq_lane_f32(_q15, _q9, vget_high_f32(_ker[5]), 1); + _q13 = vextq_f32(_q9, _q10, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[3]), 0); + _q13 = vextq_f32(_q9, _q10, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 1); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[3]), 1); + _q13 = vextq_f32(_q9, _q10, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[4]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[3]), 0); + _q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[4]), 1); + _q15 = vmlaq_lane_f32(_q15, _q10, vget_high_f32(_ker[3]), 1); + + _q15 = vmlaq_lane_f32(_q15, _q11, vget_low_f32(_ker[6]), 0); + _q13 = vextq_f32(_q11, _q12, 1); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[4]), 0); + _q13 = vextq_f32(_q11, _q12, 2); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[4]), 1); + _q13 = vextq_f32(_q11, _q12, 3); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[4]), 0); + _q15 = vmlaq_lane_f32(_q15, _q12, vget_high_f32(_ker[4]), 1); + + vst1q_f32(output_ptr0, _q14); + vst1q_f32(output_ptr1, _q15); + + input_ptr0 += 4; + input_ptr1 += 4; + input_ptr2 += 4; + input_ptr3 += 4; + input_ptr4 += 4; + input_ptr5 += 4; + output_ptr0 += 4; + output_ptr1 += 4; + } + // remain w + if (output_w_remain > 0) { + float32x4_t _q7 = vld1q_f32(input_ptr0); + float32x4_t _q8 = vld1q_f32(input_ptr0 + 4); + float32x4_t _q9 = vld1q_f32(input_ptr1); + float32x4_t _q10 = vld1q_f32(input_ptr1 + 4); + float32x4_t _q11 = vld1q_f32(input_ptr2); + float32x4_t _q12 = vld1q_f32(input_ptr2 + 4); + + _q14 = vmulq_lane_f32(_q7, vget_low_f32(_ker[5]), 0); + float32x4_t _q13 = vextq_f32(_q7, _q8, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 0); + _q13 = vextq_f32(_q7, _q8, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 1); + _q13 = vextq_f32(_q7, _q8, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[0]), 0); + _q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[0]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[5]), 1); + _q15 = vmulq_lane_f32(_q9, vget_low_f32(_ker[5]), 0); + _q13 = vextq_f32(_q9, _q10, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[0]), 0); + _q13 = vextq_f32(_q9, _q10, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 1); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[0]), 1); + _q13 = vextq_f32(_q9, _q10, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[1]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[0]), 0); + _q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[1]), 1); + _q15 = vmlaq_lane_f32(_q15, _q10, vget_high_f32(_ker[0]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q11, vget_high_f32(_ker[5]), 0); + _q15 = vmlaq_lane_f32(_q15, _q11, vget_low_f32(_ker[5]), 1); + _q13 = vextq_f32(_q11, _q12, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[1]), 0); + _q13 = vextq_f32(_q11, _q12, 2); + _q14 = vmlaq_lane_f32(_q14, 
_q13, vget_low_f32(_ker[2]), 1); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[1]), 1); + _q13 = vextq_f32(_q11, _q12, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[2]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[1]), 0); + _q14 = vmlaq_lane_f32(_q14, _q12, vget_high_f32(_ker[2]), 1); + _q15 = vmlaq_lane_f32(_q15, _q12, vget_high_f32(_ker[1]), 1); + + _q7 = vld1q_f32(input_ptr3); + _q8 = vld1q_f32(input_ptr3 + 4); + _q9 = vld1q_f32(input_ptr4); + _q10 = vld1q_f32(input_ptr4 + 4); + _q11 = vld1q_f32(input_ptr5); + _q12 = vld1q_f32(input_ptr5 + 4); + + _q14 = vmlaq_lane_f32(_q14, _q7, vget_high_f32(_ker[5]), 1); + _q15 = vmlaq_lane_f32(_q15, _q7, vget_high_f32(_ker[5]), 0); + _q13 = vextq_f32(_q7, _q8, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[2]), 0); + _q13 = vextq_f32(_q7, _q8, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 1); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[2]), 1); + _q13 = vextq_f32(_q7, _q8, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[3]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[2]), 0); + _q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[3]), 1); + _q15 = vmlaq_lane_f32(_q15, _q8, vget_high_f32(_ker[2]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[6]), 0); + _q15 = vmlaq_lane_f32(_q15, _q9, vget_high_f32(_ker[5]), 1); + _q13 = vextq_f32(_q9, _q10, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[3]), 0); + _q13 = vextq_f32(_q9, _q10, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 1); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[3]), 1); + _q13 = vextq_f32(_q9, _q10, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[4]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[3]), 0); + _q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[4]), 1); + _q15 = vmlaq_lane_f32(_q15, _q10, vget_high_f32(_ker[3]), 1); + + _q15 = vmlaq_lane_f32(_q15, _q11, vget_low_f32(_ker[6]), 0); + _q13 = vextq_f32(_q11, _q12, 1); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[4]), 0); + _q13 = vextq_f32(_q11, _q12, 2); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[4]), 1); + _q13 = vextq_f32(_q11, _q12, 3); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[4]), 0); + _q15 = vmlaq_lane_f32(_q15, _q12, vget_high_f32(_ker[4]), 1); + + switch (output_w_remain) { + case 3: + vst1q_lane_f32(output_ptr0 + 2, _q14, 2); + vst1q_lane_f32(output_ptr1 + 2, _q15, 2); + case 2: + vst1_f32(output_ptr0, vget_low_f32(_q14)); + vst1_f32(output_ptr1, vget_low_f32(_q15)); + break; + case 1: + vst1q_lane_f32(output_ptr0, _q14, 0); + vst1q_lane_f32(output_ptr1, _q15, 0); + break; + } + input_ptr0 += output_w_remain; + input_ptr1 += output_w_remain; + input_ptr2 += output_w_remain; + input_ptr3 += output_w_remain; + input_ptr4 += output_w_remain; + input_ptr5 += output_w_remain; + output_ptr0 += output_w_remain; + output_ptr1 += output_w_remain; + } +#else int loop = output_w_tiles; asm volatile( "cmp %[loop], #0 \n" @@ -443,6 +660,7 @@ void DepthwiseConv5x5S1(const framework::Tensor &input, [kr4] "w"(_ker[4]), [ker0] "w"(_ker[5]), [ker1] "w"(_ker[6]) : "cc", "memory", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); +#endif // __aarch64__ // pad right if (padding_w) { float32x4_t row0 = vld1q_f32(input_ptr0); @@ -540,7 +758,155 @@ void DepthwiseConv5x5S1(const framework::Tensor 
&input, } output_ptr0 += valid_w_start; } - // valid + // valid +// #if __aarch64__ +#if 0 + float32x4_t _q14; + for (int loop = 0; loop = output_w_tiles; ++loop) { + float32x4_t _q7 = vld1q_f32(input_ptr0); + float32x4_t _q8 = vld1q_f32(input_ptr0 + 4); + float32x4_t _q9 = vld1q_f32(input_ptr1); + float32x4_t _q10 = vld1q_f32(input_ptr1 + 4); + float32x4_t _q11 = vld1q_f32(input_ptr2); + float32x4_t _q12 = vld1q_f32(input_ptr2 + 4); + + _q14 = vmulq_lane_f32(_q7, vget_low_f32(_ker[5]), 0); + float32x4_t _q13 = vextq_f32(_q7, _q8, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 0); + _q13 = vextq_f32(_q7, _q8, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 1); + _q13 = vextq_f32(_q7, _q8, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[0]), 0); + _q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[0]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[5]), 1); + _q13 = vextq_f32(_q9, _q10, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 0); + _q13 = vextq_f32(_q9, _q10, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 1); + _q13 = vextq_f32(_q9, _q10, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[1]), 0); + _q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[1]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q11, vget_high_f32(_ker[5]), 0); + _q13 = vextq_f32(_q11, _q12, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 0); + _q13 = vextq_f32(_q11, _q12, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 1); + _q13 = vextq_f32(_q11, _q12, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[2]), 0); + _q14 = vmlaq_lane_f32(_q14, _q12, vget_high_f32(_ker[2]), 1); + + _q7 = vld1q_f32(input_ptr3); + _q8 = vld1q_f32(input_ptr3 + 4); + _q9 = vld1q_f32(input_ptr4); + _q10 = vld1q_f32(input_ptr4 + 4); + + _q14 = vmlaq_lane_f32(_q14, _q7, vget_high_f32(_ker[5]), 1); + _q13 = vextq_f32(_q7, _q8, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 0); + _q13 = vextq_f32(_q7, _q8, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 1); + _q13 = vextq_f32(_q7, _q8, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[3]), 0); + _q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[3]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[6]), 0); + _q13 = vextq_f32(_q9, _q10, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 0); + _q13 = vextq_f32(_q9, _q10, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 1); + _q13 = vextq_f32(_q9, _q10, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[4]), 0); + _q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[4]), 1); + + vst1q_f32(output_ptr0, _q14); + + input_ptr0 += 4; + input_ptr1 += 4; + input_ptr2 += 4; + input_ptr3 += 4; + input_ptr4 += 4; + output_ptr0 += 4; + } + // remain w + if (output_w_remain > 0) { + float32x4_t _q7 = vld1q_f32(input_ptr0); + float32x4_t _q8 = vld1q_f32(input_ptr0 + 4); + float32x4_t _q9 = vld1q_f32(input_ptr1); + float32x4_t _q10 = vld1q_f32(input_ptr1 + 4); + float32x4_t _q11 = vld1q_f32(input_ptr2); + float32x4_t _q12 = vld1q_f32(input_ptr2 + 4); + + _q14 = vmulq_lane_f32(_q7, vget_low_f32(_ker[5]), 0); + float32x4_t _q13 = vextq_f32(_q7, _q8, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 0); + _q13 = vextq_f32(_q7, _q8, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 1); + _q13 = vextq_f32(_q7, _q8, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[0]), 0); + _q14 = vmlaq_lane_f32(_q14, _q8, 
vget_high_f32(_ker[0]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[5]), 1); + _q13 = vextq_f32(_q9, _q10, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 0); + _q13 = vextq_f32(_q9, _q10, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 1); + _q13 = vextq_f32(_q9, _q10, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[1]), 0); + _q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[1]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q11, vget_high_f32(_ker[5]), 0); + _q13 = vextq_f32(_q11, _q12, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 0); + _q13 = vextq_f32(_q11, _q12, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 1); + _q13 = vextq_f32(_q11, _q12, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[2]), 0); + _q14 = vmlaq_lane_f32(_q14, _q12, vget_high_f32(_ker[2]), 1); + + _q7 = vld1q_f32(input_ptr3); + _q8 = vld1q_f32(input_ptr3 + 4); + _q9 = vld1q_f32(input_ptr4); + _q10 = vld1q_f32(input_ptr4 + 4); + + _q14 = vmlaq_lane_f32(_q14, _q7, vget_high_f32(_ker[5]), 1); + _q13 = vextq_f32(_q7, _q8, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 0); + _q13 = vextq_f32(_q7, _q8, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 1); + _q13 = vextq_f32(_q7, _q8, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[3]), 0); + _q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[3]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[6]), 0); + _q13 = vextq_f32(_q9, _q10, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 0); + _q13 = vextq_f32(_q9, _q10, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 1); + _q13 = vextq_f32(_q9, _q10, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[4]), 0); + _q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[4]), 1); + + switch (output_w_remain) { + case 3: + vst1q_lane_f32(output_ptr0 + 2, _q14, 2); + case 2: + vst1_f32(output_ptr0, vget_low_f32(_q14)); + break; + case 1: + vst1q_lane_f32(output_ptr0, _q14, 0); + break; + } + + input_ptr0 += output_w_remain; + input_ptr1 += output_w_remain; + input_ptr2 += output_w_remain; + input_ptr3 += output_w_remain; + input_ptr4 += output_w_remain; + output_ptr0 += output_w_remain; + } +#else int loop = output_w_tiles; asm volatile( "cmp %[loop], #0 \n" @@ -676,6 +1042,7 @@ void DepthwiseConv5x5S1(const framework::Tensor &input, [kr4] "w"(_ker[4]), [ker0] "w"(_ker[5]), [ker1] "w"(_ker[6]) : "cc", "memory", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); +#endif // __aarch64__ // pad right if (padding_w) { float32x4_t row0 = vld1q_f32(input_ptr0); diff --git a/src/operators/math/winograd/winograd_transform_f6k3.cpp b/src/operators/math/winograd/winograd_transform_f6k3.cpp index 11366c93a7..7972196656 100644 --- a/src/operators/math/winograd/winograd_transform_f6k3.cpp +++ b/src/operators/math/winograd/winograd_transform_f6k3.cpp @@ -544,12 +544,12 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input, for (int l = 0; l < 2; ++l) { float32x4x2_t _q23, _q45, _q67, _q89; _q23.val[0] = vld1q_f32(ptr0); - _q23.val[0] = vld1q_f32(ptr0 + 4); - _q45.val[1] = vld1q_f32(ptr0 + 8); + _q23.val[1] = vld1q_f32(ptr0 + 4); + _q45.val[0] = vld1q_f32(ptr0 + 8); _q45.val[1] = vld1q_f32(ptr0 + 12); _q67.val[0] = vld1q_f32(ptr1); - _q67.val[0] = vld1q_f32(ptr1 + 4); - _q89.val[1] = vld1q_f32(ptr1 + 8); + _q67.val[1] = vld1q_f32(ptr1 + 4); + _q89.val[0] = vld1q_f32(ptr1 + 8); _q89.val[1] = vld1q_f32(ptr1 + 12); _q23 = vtrnq_f32(_q23.val[0], 
_q23.val[1]); _q45 = vtrnq_f32(_q45.val[0], _q45.val[1]); @@ -1167,12 +1167,12 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, float32x4_t _q0 = vld1q_f32(transform_matrix); float32x4x2_t _q23, _q45, _q67, _q89; _q23.val[0] = vld1q_f32(at_m_ptr0); - _q23.val[0] = vld1q_f32(at_m_ptr0 + 4); - _q45.val[1] = vld1q_f32(at_m_ptr0 + 8); + _q23.val[1] = vld1q_f32(at_m_ptr0 + 4); + _q45.val[0] = vld1q_f32(at_m_ptr0 + 8); _q45.val[1] = vld1q_f32(at_m_ptr0 + 12); _q67.val[0] = vld1q_f32(at_m_ptr1); - _q67.val[0] = vld1q_f32(at_m_ptr1 + 4); - _q89.val[1] = vld1q_f32(at_m_ptr1 + 8); + _q67.val[1] = vld1q_f32(at_m_ptr1 + 4); + _q89.val[0] = vld1q_f32(at_m_ptr1 + 8); _q89.val[1] = vld1q_f32(at_m_ptr1 + 12); _q23 = vtrnq_f32(_q23.val[0], _q23.val[1]); _q45 = vtrnq_f32(_q45.val[0], _q45.val[1]); @@ -1231,6 +1231,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, _q1 = vaddq_f32(_q9, _q7); _q1 = vmlaq_lane_f32(_q1, _q10, vget_high_f32(_q0), 1); _q2 = vaddq_f32(_q12, _q8); + _q2 = vaddq_f32(_q2, _q14); _q2 = vmlaq_lane_f32(_q2, _q6, vget_high_f32(_q0), 1); _q23 = vtrnq_f32(_q1, _q2); vst1_f32(out_ptr0 + 4, vget_low_f32(_q23.val[0])); @@ -1287,8 +1288,8 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, _d20 = vadd_f32(_d20, _d15); _d20 = vmla_lane_f32(_d20, _d16, vget_high_f32(_q0), 1); _d18d20 = vtrn_f32(_d18, _d20); - vst1_f32(out_ptr4 + 4, _d18); - vst1_f32(out_ptr5 + 4, _d20); + vst1_f32(out_ptr4 + 4, _d18d20.val[0]); + vst1_f32(out_ptr5 + 4, _d18d20.val[1]); #else asm volatile( "vld1.32 {d0-d1}, [%[tm_ptr]] \n" @@ -1428,12 +1429,12 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, float32x4_t _q0 = vld1q_f32(transform_matrix); float32x4x2_t _q23, _q45, _q67, _q89; _q23.val[0] = vld1q_f32(at_m_ptr0); - _q23.val[0] = vld1q_f32(at_m_ptr0 + 4); - _q45.val[1] = vld1q_f32(at_m_ptr0 + 8); + _q23.val[1] = vld1q_f32(at_m_ptr0 + 4); + _q45.val[0] = vld1q_f32(at_m_ptr0 + 8); _q45.val[1] = vld1q_f32(at_m_ptr0 + 12); _q67.val[0] = vld1q_f32(at_m_ptr1); - _q67.val[0] = vld1q_f32(at_m_ptr1 + 4); - _q89.val[1] = vld1q_f32(at_m_ptr1 + 8); + _q67.val[1] = vld1q_f32(at_m_ptr1 + 4); + _q89.val[0] = vld1q_f32(at_m_ptr1 + 8); _q89.val[1] = vld1q_f32(at_m_ptr1 + 12); _q23 = vtrnq_f32(_q23.val[0], _q23.val[1]); _q45 = vtrnq_f32(_q45.val[0], _q45.val[1]); @@ -1489,6 +1490,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, _q1 = vaddq_f32(_q9, _q7); _q1 = vmlaq_lane_f32(_q1, _q10, vget_high_f32(_q0), 1); _q2 = vaddq_f32(_q12, _q8); + _q2 = vaddq_f32(_q2, _q14); _q2 = vmlaq_lane_f32(_q2, _q6, vget_high_f32(_q0), 1); _q23 = vtrnq_f32(_q1, _q2); vst1_f32(out_ptr0 + 4, vget_low_f32(_q23.val[0])); @@ -1545,8 +1547,8 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, _d20 = vadd_f32(_d20, _d15); _d20 = vmla_lane_f32(_d20, _d16, vget_high_f32(_q0), 1); _d18d20 = vtrn_f32(_d18, _d20); - vst1_f32(out_ptr4 + 4, _d18); - vst1_f32(out_ptr5 + 4, _d20); + vst1_f32(out_ptr4 + 4, _d18d20.val[0]); + vst1_f32(out_ptr5 + 4, _d18d20.val[1]); #else asm volatile( "vld1.32 {d0-d1}, [%[tm_ptr]] \n" -- GitLab
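For reference, the valid-region kernels added above compute a plain stride-1 5x5 depthwise convolution: each output pixel is the dot product of a 5x5 input window with that channel's 5x5 filter. Below is a minimal scalar sketch of that computation for checking the vectorized paths; the function and variable names are illustrative only (not from the patch), and padding, bias and activation fusion are omitted.

// Scalar reference for the stride-1 5x5 depthwise convolution that
// DepthwiseConv5x5S1's valid region computes. Names are illustrative,
// not taken from the patch; padding and fused bias/ReLU are omitted.
void DepthwiseConv5x5S1Reference(const float *input, const float *filter,
                                 float *output, int channels, int in_h,
                                 int in_w) {
  const int out_h = in_h - 4;  // valid output height for a 5x5 window
  const int out_w = in_w - 4;  // valid output width
  for (int c = 0; c < channels; ++c) {
    const float *in_c = input + c * in_h * in_w;
    const float *ker = filter + c * 25;  // one 5x5 kernel per channel
    float *out_c = output + c * out_h * out_w;
    for (int h = 0; h < out_h; ++h) {
      for (int w = 0; w < out_w; ++w) {
        float acc = 0.f;
        for (int kh = 0; kh < 5; ++kh) {
          for (int kw = 0; kw < 5; ++kw) {
            acc += in_c[(h + kh) * in_w + (w + kw)] * ker[kh * 5 + kw];
          }
        }
        out_c[h * out_w + w] = acc;
      }
    }
  }
}

The intrinsics path in the patch vectorizes the inner width loop four outputs at a time: vextq_f32 shifts the loaded input window by one element per filter tap, vmlaq_lane_f32 accumulates one tap per step, and two output rows are produced per iteration so the six loaded input rows are reused.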