From 643a62e1c4625aa04892d876ae885ae7f69cc724 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Thu, 7 Mar 2019 21:00:30 +0800 Subject: [PATCH] Add winograd implementation for arm64 --- .../convolution/conv_add_bn_relu_kernel.cpp | 2 +- .../arm/convolution/conv_add_kernel.cpp | 4 +- .../arm/convolution/conv_add_relu_kernel.cpp | 2 +- .../convolution/conv_bn_add_relu_kernel.cpp | 2 +- .../arm/convolution/conv_bn_relu_kernel.cpp | 2 +- .../kernel/arm/convolution/conv_common.cpp | 2 +- .../kernel/arm/convolution/conv_kernel.cpp | 2 +- .../math/winograd/winograd_transform_f6k3.cpp | 642 +++++++++++++++++- .../winograd_transform_f6k3_arm64.cpp | 413 ----------- 9 files changed, 647 insertions(+), 424 deletions(-) delete mode 100644 src/operators/math/winograd/winograd_transform_f6k3_arm64.cpp diff --git a/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp index 28cb2c3e40..91e901b89f 100644 --- a/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp @@ -79,12 +79,12 @@ void ConvAddBNReluKernel::Compute( math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; +#endif // __aarch64__ case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; -#endif // __aarch64__ case ConvParam::EXEC_GEMM_FLOAT: ConvBNReluBasic>(param); break; diff --git a/src/operators/kernel/arm/convolution/conv_add_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_kernel.cpp index b62fdf71f8..9de6e333e7 100644 --- a/src/operators/kernel/arm/convolution/conv_add_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_add_kernel.cpp @@ -11,11 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #ifdef FUSION_CONVADD_OP #include "operators/kernel/conv_add_kernel.h" #include "operators/kernel/arm/convolution/conv_common.h" #include "operators/kernel/central-arm-func/conv_add_arm_func.h" +#include "operators/kernel/central-arm-func/conv_arm_func.h" namespace paddle_mobile { namespace operators { @@ -47,12 +49,12 @@ void ConvAddKernel::Compute(const FusionConvAddParam ¶m) { math::AddChannelWise(param.Output(), param.Bias(), param.Output()); break; +#endif // __aarch64__ case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); math::AddChannelWise(param.Output(), param.Bias(), param.Output()); break; -#endif // __aarch64__ case ConvParam::EXEC_GEMM_FLOAT: ConvAddBasic(param); break; diff --git a/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp index 4060c56312..3f8a85e74b 100644 --- a/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp @@ -46,11 +46,11 @@ void ConvAddReluKernel::Compute( DepthwiseConv5x5(param); math::AddChannelWise(param.Output(), param.Bias(), param.Output()); break; +#endif // __aarch64__ case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); math::AddChannelWise(param.Output(), param.Bias(), param.Output()); break; -#endif // __aarch64__ case ConvParam::EXEC_GEMM_FLOAT: ConvAddReluBasic>(param); break; diff --git a/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp index 86a9eb2250..a46c66c275 100644 --- a/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp @@ -79,12 +79,12 @@ void ConvBNAddReluKernel::Compute( math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; +#endif // __aarch64__ case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; -#endif // __aarch64__ case ConvParam::EXEC_GEMM_FLOAT: ConvBNReluBasic>(param); break; diff --git a/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp index 0de0704884..b49120e740 100644 --- a/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp @@ -78,12 +78,12 @@ void ConvBNReluKernel::Compute( math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; +#endif // __aarch64__ case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; -#endif // __aarch64__ case ConvParam::EXEC_GEMM_FLOAT: ConvBNReluBasic>(param); break; diff --git a/src/operators/kernel/arm/convolution/conv_common.cpp b/src/operators/kernel/arm/convolution/conv_common.cpp index 3ff08eae0d..8981715351 100644 --- a/src/operators/kernel/arm/convolution/conv_common.cpp +++ b/src/operators/kernel/arm/convolution/conv_common.cpp @@ -53,6 +53,7 @@ void InitBaseConvKernel(ConvParam *param) { } else if (depth5x5 && param->Strides()[0] == param->Strides()[1] && param->Strides()[0] == 1) { param->ExecMode() = ConvParam::EXEC_DEPTHWISE5x5_FLOAT; +#endif } else if (conv3x3 && !depth3x3 && param->Strides()[0] == param->Strides()[1] && param->Dilations()[0] == param->Dilations()[1] && @@ -68,7 +69,6 @@ void InitBaseConvKernel(ConvParam *param) { param->transformed_filter_ = new framework::LoDTensor; operators::math::winograd_transform_weight<8, 3>( *param->Filter(), param->transformed_filter_); -#endif } else { param->ExecMode() = ConvParam::EXEC_GEMM_FLOAT; } diff --git a/src/operators/kernel/arm/convolution/conv_kernel.cpp b/src/operators/kernel/arm/convolution/conv_kernel.cpp index 59026be4d6..97c153fa28 100644 --- a/src/operators/kernel/arm/convolution/conv_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_kernel.cpp @@ -55,10 +55,10 @@ void ConvKernel::Compute(const ConvParam ¶m) { case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: DepthwiseConv5x5(param); break; +#endif // __aarch64__ case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); break; -#endif // __aarch64__ case ConvParam::EXEC_GEMM_FLOAT: GemmConv(param); break; diff --git a/src/operators/math/winograd/winograd_transform_f6k3.cpp b/src/operators/math/winograd/winograd_transform_f6k3.cpp index f95ead6244..11366c93a7 100644 --- a/src/operators/math/winograd/winograd_transform_f6k3.cpp +++ b/src/operators/math/winograd/winograd_transform_f6k3.cpp @@ -15,10 +15,10 @@ limitations under the License. */ // Inspired by https://arxiv.org/abs/1509.09308 and refered from nnpack and ncnn // project. +#if defined(__ARM_NEON) || defined(__ARM_NEON__) #ifdef CONV_OP -#ifndef __aarch64__ - +#include #include "operators/math/pad.h" #include "operators/math/winograd/winograd_transform.h" @@ -51,6 +51,10 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight, const float transform_matrix[8] = {2.f, -2.f / 9, 1.f / 90, 1.f / 180}; const float *inptr = weight.data(); + +#if __aarch64__ + int remain_start = 0; +#else int remain_start = out_channel & 0xFFFC; #pragma omp parallel for @@ -256,6 +260,7 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight, "q13", "r0"); } } +#endif // __aarch64__ // remain output channel #pragma omp parallel for @@ -358,6 +363,90 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input, const float *in2 = in1 + width; const float *in3 = in2 + width; float *d_bt_ptr = d_bt; +#if __aarch64__ + int steps = 4 * width; + float32x4_t _q0 = vld1q_f32(transform_matrix); + float32x4_t _q1 = vld1q_f32(transform_matrix + 4); + for (int l = 0; l < 2; ++l) { + float32x4x2_t _q23, _q45, _q67, _q89; + _q23.val[0] = vld1q_f32(in0); + _q45.val[0] = vld1q_f32(in0 + 4); + _q23.val[1] = vld1q_f32(in1); + _q45.val[1] = vld1q_f32(in1 + 4); + _q67.val[0] = vld1q_f32(in2); + _q89.val[0] = vld1q_f32(in2 + 4); + _q67.val[1] = vld1q_f32(in3); + _q89.val[1] = vld1q_f32(in3 + 4); + _q23 = vtrnq_f32(_q23.val[0], _q23.val[1]); + _q45 = vtrnq_f32(_q45.val[0], _q45.val[1]); + _q67 = vtrnq_f32(_q67.val[0], _q67.val[1]); + _q89 = vtrnq_f32(_q89.val[0], _q89.val[1]); + float32x4_t _q2 = vcombine_f32(vget_low_f32(_q23.val[0]), + vget_low_f32(_q67.val[0])); + float32x4_t _q4 = vcombine_f32(vget_low_f32(_q23.val[1]), + vget_low_f32(_q67.val[1])); + float32x4_t _q3 = vcombine_f32(vget_low_f32(_q45.val[0]), + vget_low_f32(_q89.val[0])); + float32x4_t _q5 = vcombine_f32(vget_low_f32(_q45.val[1]), + vget_low_f32(_q89.val[1])); + float32x4_t _q6 = vcombine_f32(vget_high_f32(_q23.val[0]), + vget_high_f32(_q67.val[0])); + float32x4_t _q8 = vcombine_f32(vget_high_f32(_q23.val[1]), + vget_high_f32(_q67.val[1])); + float32x4_t _q7 = vcombine_f32(vget_high_f32(_q45.val[0]), + vget_high_f32(_q89.val[0])); + float32x4_t _q9 = vcombine_f32(vget_high_f32(_q45.val[1]), + vget_high_f32(_q89.val[1])); + + float32x4_t _q10 = vsubq_f32(_q2, _q7); + float32x4_t _q11 = vsubq_f32(_q3, _q6); + _q10 = vmlaq_lane_f32(_q10, _q11, vget_low_f32(_q0), 0); + vst1q_f32(d_bt_ptr, _q10); + + _q10 = vaddq_f32(_q6, _q7); + _q11 = vaddq_f32(_q4, _q5); + _q10 = vmlaq_lane_f32(_q10, _q3, vget_high_f32(_q0), 0); + _q11 = vmlaq_lane_f32(_q11, _q8, vget_high_f32(_q0), 0); + float32x4_t _q12 = vaddq_f32(_q10, _q11); + float32x4_t _q13 = vsubq_f32(_q10, _q11); + vst1q_f32(d_bt_ptr + 4, _q12); + vst1q_f32(d_bt_ptr + 8, _q13); + + _q10 = vmulq_lane_f32(_q6, vget_high_f32(_q1), 1); + _q11 = vmulq_lane_f32(_q4, vget_high_f32(_q1), 0); + _q10 = vaddq_f32(_q10, _q7); + _q11 = vmlaq_lane_f32(_q11, _q5, vget_low_f32(_q1), 0); + _q10 = vmlaq_lane_f32(_q10, _q3, vget_low_f32(_q1), 1); + _q11 = vmlaq_lane_f32(_q11, _q8, vget_high_f32(_q0), 1); + _q12 = vaddq_f32(_q10, _q11); + _q13 = vsubq_f32(_q10, _q11); + vst1q_f32(d_bt_ptr + 12, _q12); + vst1q_f32(d_bt_ptr + 16, _q13); + + _q10 = vmulq_lane_f32(_q6, vget_low_f32(_q1), 0); + _q11 = vmulq_lane_f32(_q4, vget_low_f32(_q1), 0); + _q10 = vmlaq_lane_f32(_q10, _q3, vget_high_f32(_q0), 1); + _q11 = vmlaq_lane_f32(_q11, _q8, vget_high_f32(_q0), 1); + _q10 = vmlaq_lane_f32(_q10, _q7, vget_high_f32(_q1), 0); + _q11 = vmlaq_lane_f32(_q11, _q5, vget_high_f32(_q1), 0); + _q10 = vmulq_lane_f32(_q10, vget_low_f32(_q1), 0); + _q12 = vaddq_f32(_q10, _q11); + _q13 = vsubq_f32(_q10, _q11); + vst1q_f32(d_bt_ptr + 20, _q12); + vst1q_f32(d_bt_ptr + 24, _q13); + + _q10 = vsubq_f32(_q9, _q4); + _q11 = vsubq_f32(_q8, _q5); + _q10 = vmlaq_lane_f32(_q10, _q11, vget_low_f32(_q0), 0); + vst1q_f32(d_bt_ptr + 28, _q10); + + in0 += steps; + in1 += steps; + in2 += steps; + in3 += steps; + d_bt_ptr += 32; + } +#else int steps = 4 * width * sizeof(float); asm volatile( "vld1.32 {d0-d3}, [%[tm_ptr]] \n" @@ -434,7 +523,7 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input, : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "r0"); - +#endif // __aarch64__ float *ptr0 = d_bt; float *ptr1 = ptr0 + 32; int tile_indics = h * w_tiles + w; @@ -450,6 +539,120 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input, float *out5 = out4 + channel * 8; float *out6 = out5 + channel * 8; float *out7 = out6 + channel * 8; +#if __aarch64__ + steps = 8 * channel * 8; + for (int l = 0; l < 2; ++l) { + float32x4x2_t _q23, _q45, _q67, _q89; + _q23.val[0] = vld1q_f32(ptr0); + _q23.val[0] = vld1q_f32(ptr0 + 4); + _q45.val[1] = vld1q_f32(ptr0 + 8); + _q45.val[1] = vld1q_f32(ptr0 + 12); + _q67.val[0] = vld1q_f32(ptr1); + _q67.val[0] = vld1q_f32(ptr1 + 4); + _q89.val[1] = vld1q_f32(ptr1 + 8); + _q89.val[1] = vld1q_f32(ptr1 + 12); + _q23 = vtrnq_f32(_q23.val[0], _q23.val[1]); + _q45 = vtrnq_f32(_q45.val[0], _q45.val[1]); + _q67 = vtrnq_f32(_q67.val[0], _q67.val[1]); + _q89 = vtrnq_f32(_q89.val[0], _q89.val[1]); + float32x4_t _q2 = vcombine_f32(vget_low_f32(_q23.val[0]), + vget_low_f32(_q45.val[0])); + float32x4_t _q4 = vcombine_f32(vget_high_f32(_q23.val[0]), + vget_high_f32(_q45.val[0])); + float32x4_t _q3 = vcombine_f32(vget_low_f32(_q23.val[1]), + vget_low_f32(_q45.val[1])); + float32x4_t _q5 = vcombine_f32(vget_high_f32(_q23.val[1]), + vget_high_f32(_q45.val[1])); + float32x4_t _q6 = vcombine_f32(vget_low_f32(_q67.val[0]), + vget_low_f32(_q89.val[0])); + float32x4_t _q8 = vcombine_f32(vget_high_f32(_q67.val[0]), + vget_high_f32(_q89.val[0])); + float32x4_t _q7 = vcombine_f32(vget_low_f32(_q67.val[1]), + vget_low_f32(_q89.val[1])); + float32x4_t _q9 = vcombine_f32(vget_high_f32(_q67.val[1]), + vget_high_f32(_q89.val[1])); + + float32x4_t _q10 = vsubq_f32(_q2, _q8); + float32x4_t _q11 = vsubq_f32(_q6, _q4); + _q10 = vmlaq_lane_f32(_q10, _q11, vget_low_f32(_q0), 0); + vst1q_lane_f32(out0, _q10, 0); + vst1q_lane_f32(out0 + steps, _q10, 1); + vst1q_lane_f32(out0 + 2 * steps, _q10, 2); + vst1q_lane_f32(out0 + 3 * steps, _q10, 3); + + _q10 = vaddq_f32(_q4, _q8); + _q11 = vaddq_f32(_q3, _q7); + _q10 = vmlaq_lane_f32(_q10, _q6, vget_high_f32(_q0), 0); + _q11 = vmlaq_lane_f32(_q11, _q5, vget_high_f32(_q0), 0); + float32x4_t _q12 = vaddq_f32(_q10, _q11); + vst1q_lane_f32(out1, _q12, 0); + vst1q_lane_f32(out1 + steps, _q12, 1); + vst1q_lane_f32(out1 + 2 * steps, _q12, 2); + vst1q_lane_f32(out1 + 3 * steps, _q12, 3); + + _q12 = vsubq_f32(_q10, _q11); + vst1q_lane_f32(out2, _q12, 0); + vst1q_lane_f32(out2 + steps, _q12, 1); + vst1q_lane_f32(out2 + 2 * steps, _q12, 2); + vst1q_lane_f32(out2 + 3 * steps, _q12, 3); + + _q10 = vmulq_lane_f32(_q4, vget_high_f32(_q1), 1); + _q11 = vmulq_lane_f32(_q3, vget_high_f32(_q1), 0); + _q10 = vaddq_f32(_q10, _q8); + _q11 = vmlaq_lane_f32(_q11, _q7, vget_low_f32(_q1), 0); + _q10 = vmlaq_lane_f32(_q10, _q6, vget_low_f32(_q1), 1); + _q11 = vmlaq_lane_f32(_q11, _q5, vget_high_f32(_q0), 1); + _q12 = vaddq_f32(_q10, _q11); + vst1q_lane_f32(out3, _q12, 0); + vst1q_lane_f32(out3 + steps, _q12, 1); + vst1q_lane_f32(out3 + 2 * steps, _q12, 2); + vst1q_lane_f32(out3 + 3 * steps, _q12, 3); + + _q12 = vsubq_f32(_q10, _q11); + vst1q_lane_f32(out4, _q12, 0); + vst1q_lane_f32(out4 + steps, _q12, 1); + vst1q_lane_f32(out4 + 2 * steps, _q12, 2); + vst1q_lane_f32(out4 + 3 * steps, _q12, 3); + + _q10 = vmulq_lane_f32(_q4, vget_low_f32(_q1), 0); + _q11 = vmulq_lane_f32(_q3, vget_low_f32(_q1), 0); + _q10 = vmlaq_lane_f32(_q10, _q6, vget_high_f32(_q0), 1); + _q11 = vmlaq_lane_f32(_q11, _q5, vget_high_f32(_q0), 1); + _q10 = vmlaq_lane_f32(_q10, _q8, vget_high_f32(_q1), 0); + _q11 = vmlaq_lane_f32(_q11, _q7, vget_high_f32(_q1), 0); + _q10 = vmulq_lane_f32(_q10, vget_low_f32(_q1), 0); + _q12 = vaddq_f32(_q10, _q11); + vst1q_lane_f32(out5, _q12, 0); + vst1q_lane_f32(out5 + steps, _q12, 1); + vst1q_lane_f32(out5 + 2 * steps, _q12, 2); + vst1q_lane_f32(out5 + 3 * steps, _q12, 3); + + _q12 = vsubq_f32(_q10, _q11); + vst1q_lane_f32(out6, _q12, 0); + vst1q_lane_f32(out6 + steps, _q12, 1); + vst1q_lane_f32(out6 + 2 * steps, _q12, 2); + vst1q_lane_f32(out6 + 3 * steps, _q12, 3); + + _q10 = vsubq_f32(_q9, _q3); + _q11 = vsubq_f32(_q5, _q7); + _q10 = vmlaq_lane_f32(_q10, _q11, vget_low_f32(_q0), 0); + vst1q_lane_f32(out7, _q10, 0); + vst1q_lane_f32(out7 + steps, _q10, 1); + vst1q_lane_f32(out7 + 2 * steps, _q10, 2); + vst1q_lane_f32(out7 + 3 * steps, _q10, 3); + + ptr0 += 16; + ptr1 += 16; + out0 += 4 * steps; + out1 += 4 * steps; + out2 += 4 * steps; + out3 += 4 * steps; + out4 += 4 * steps; + out5 += 4 * steps; + out6 += 4 * steps; + out7 += 4 * steps; + } +#else steps = 8 * channel * 8 * sizeof(float); asm volatile( "mov r0, #2 \n" @@ -555,6 +758,7 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input, : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "r0"); +#endif // __aarch64__ } } } @@ -587,6 +791,71 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, const float *in_ptr = input_ptr + (j * 64 + k) * in_channel * 8; int inter_channel = in_channel >> 1; int remain_channel = in_channel & 0x1; +#if __aarch64__ + asm volatile( + "dup v8.4s, wzr \n" + "dup v9.4s, wzr \n" + "dup v10.4s, wzr \n" + "dup v11.4s, wzr \n" + "dup v12.4s, wzr \n" + "dup v13.4s, wzr \n" + "dup v14.4s, wzr \n" + "dup v15.4s, wzr \n" + + "cmp %[inter], #0 \n" + "ble loop_1c_%= \n" + // loop 2 channels + "loop_2c_%=: \n" + "ld1 {v0.4s, v1.4s}, [%[w_ptr]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[in_ptr]], #32 \n" + "ld1 {v4.4s, v5.4s}, [%[in_ptr]], #32 \n" + + "fmla v8.4s, v2.4s, v0.s[0] \n" + "fmla v9.4s, v3.4s, v0.s[0] \n" + "fmla v10.4s, v2.4s, v0.s[1] \n" + "fmla v11.4s, v3.4s, v0.s[1] \n" + "fmla v12.4s, v2.4s, v0.s[2] \n" + "fmla v13.4s, v3.4s, v0.s[2] \n" + "fmla v14.4s, v2.4s, v0.s[3] \n" + "fmla v15.4s, v3.4s, v0.s[3] \n" + + "fmla v8.4s, v4.4s, v1.s[0] \n" + "fmla v9.4s, v5.4s, v1.s[0] \n" + "fmla v10.4s, v4.4s, v1.s[1] \n" + "fmla v11.4s, v5.4s, v1.s[1] \n" + "fmla v12.4s, v4.4s, v1.s[2] \n" + "fmla v13.4s, v5.4s, v1.s[2] \n" + "fmla v14.4s, v4.4s, v1.s[3] \n" + "fmla v15.4s, v5.4s, v1.s[3] \n" + + "subs %[inter], %[inter], #1 \n" + "bne loop_2c_%= \n" + + // loop 1 channel + "loop_1c_%=: \n" + "cmp %[remain], #0 \n" + "ble store_res_%= \n" + + "ld1 {v0.4s, v1.4s}, [%[w_ptr]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[in_ptr]], #32 \n" + "fmla v8.4s, v2.4s, v0.s[0] \n" + "fmla v9.4s, v3.4s, v0.s[0] \n" + "fmla v10.4s, v2.4s, v0.s[1] \n" + "fmla v11.4s, v3.4s, v0.s[1] \n" + "fmla v12.4s, v2.4s, v0.s[2] \n" + "fmla v13.4s, v3.4s, v0.s[2] \n" + "fmla v14.4s, v2.4s, v0.s[3] \n" + "fmla v15.4s, v3.4s, v0.s[3] \n" + + "store_res_%=: \n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[uv_ptr]], #64 \n" + "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[uv_ptr]], #64 \n" + : [w_ptr] "+r"(w_ptr), [in_ptr] "+r"(in_ptr), [uv_ptr] "+r"(uv_ptr), + [inter] "+r"(inter_channel) + : [remain] "r"(remain_channel) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"); +#else asm volatile( "veor q8, q8, q8 \n" "veor q9, q9, q9 \n" @@ -651,6 +920,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, : [remain_channel] "r"(remain_channel) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ } } } @@ -686,6 +956,116 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, int tile_block = tile_indics >> 3; int block_indics = tile_indics & 0x7; const float *uv_ptr0 = uv_ptr + tile_block * 64 * 32 + block_indics; +#if __aarch64__ + float32x4_t _q0 = vld1q_f32(transform_matrix); + for (int l = 0; l < 2; ++l) { + float32x4_t _q1, _q2, _q3, _q4, _q5, _q6, _q7, _q8; + _q1 = vsetq_lane_f32(*uv_ptr0, _q1, 0); + uv_ptr0 += 32; + _q3 = vsetq_lane_f32(*uv_ptr0, _q3, 0); + uv_ptr0 += 32; + _q5 = vsetq_lane_f32(*uv_ptr0, _q5, 0); + uv_ptr0 += 32; + _q7 = vsetq_lane_f32(*uv_ptr0, _q7, 0); + uv_ptr0 += 32; + _q2 = vsetq_lane_f32(*uv_ptr0, _q2, 0); + uv_ptr0 += 32; + _q4 = vsetq_lane_f32(*uv_ptr0, _q4, 0); + uv_ptr0 += 32; + _q6 = vsetq_lane_f32(*uv_ptr0, _q6, 0); + uv_ptr0 += 32; + _q8 = vsetq_lane_f32(*uv_ptr0, _q8, 0); + uv_ptr0 += 32; + + _q1 = vsetq_lane_f32(*uv_ptr0, _q1, 1); + uv_ptr0 += 32; + _q3 = vsetq_lane_f32(*uv_ptr0, _q3, 1); + uv_ptr0 += 32; + _q5 = vsetq_lane_f32(*uv_ptr0, _q5, 1); + uv_ptr0 += 32; + _q7 = vsetq_lane_f32(*uv_ptr0, _q7, 1); + uv_ptr0 += 32; + _q2 = vsetq_lane_f32(*uv_ptr0, _q2, 1); + uv_ptr0 += 32; + _q4 = vsetq_lane_f32(*uv_ptr0, _q4, 1); + uv_ptr0 += 32; + _q6 = vsetq_lane_f32(*uv_ptr0, _q6, 1); + uv_ptr0 += 32; + _q8 = vsetq_lane_f32(*uv_ptr0, _q8, 1); + uv_ptr0 += 32; + + _q1 = vsetq_lane_f32(*uv_ptr0, _q1, 2); + uv_ptr0 += 32; + _q3 = vsetq_lane_f32(*uv_ptr0, _q3, 2); + uv_ptr0 += 32; + _q5 = vsetq_lane_f32(*uv_ptr0, _q5, 2); + uv_ptr0 += 32; + _q7 = vsetq_lane_f32(*uv_ptr0, _q7, 2); + uv_ptr0 += 32; + _q2 = vsetq_lane_f32(*uv_ptr0, _q2, 2); + uv_ptr0 += 32; + _q4 = vsetq_lane_f32(*uv_ptr0, _q4, 2); + uv_ptr0 += 32; + _q6 = vsetq_lane_f32(*uv_ptr0, _q6, 2); + uv_ptr0 += 32; + _q8 = vsetq_lane_f32(*uv_ptr0, _q8, 2); + uv_ptr0 += 32; + + _q1 = vsetq_lane_f32(*uv_ptr0, _q1, 3); + uv_ptr0 += 32; + _q3 = vsetq_lane_f32(*uv_ptr0, _q3, 3); + uv_ptr0 += 32; + _q5 = vsetq_lane_f32(*uv_ptr0, _q5, 3); + uv_ptr0 += 32; + _q7 = vsetq_lane_f32(*uv_ptr0, _q7, 3); + uv_ptr0 += 32; + _q2 = vsetq_lane_f32(*uv_ptr0, _q2, 3); + uv_ptr0 += 32; + _q4 = vsetq_lane_f32(*uv_ptr0, _q4, 3); + uv_ptr0 += 32; + _q6 = vsetq_lane_f32(*uv_ptr0, _q6, 3); + uv_ptr0 += 32; + _q8 = vsetq_lane_f32(*uv_ptr0, _q8, 3); + uv_ptr0 += 32; + + float32x4_t _q9 = vaddq_f32(_q3, _q5); + float32x4_t _q10 = vaddq_f32(_q7, _q2); + float32x4_t _q11 = vaddq_f32(_q4, _q6); + float32x4_t _q12 = vsubq_f32(_q3, _q5); + float32x4_t _q13 = vsubq_f32(_q7, _q2); + float32x4_t _q14 = vsubq_f32(_q4, _q6); + _q2 = vmulq_lane_f32(_q13, vget_low_f32(_q0), 0); + _q3 = vmulq_lane_f32(_q11, vget_low_f32(_q0), 0); + + float32x4_t _q15 = vaddq_f32(_q1, _q9); + _q15 = vaddq_f32(_q15, _q10); + _q15 = vmlaq_lane_f32(_q15, _q3, vget_high_f32(_q0), 1); + vst1q_f32(at_m_ptr, _q15); + + _q15 = vaddq_f32(_q12, _q2); + _q15 = vmlaq_lane_f32(_q15, _q14, vget_high_f32(_q0), 1); + vst1q_f32(at_m_ptr + 4, _q15); + + _q15 = vmlaq_lane_f32(_q9, _q10, vget_low_f32(_q0), 1); + _q15 = vmlaq_lane_f32(_q15, _q11, vget_high_f32(_q0), 0); + vst1q_f32(at_m_ptr + 8, _q15); + + _q15 = vmlaq_lane_f32(_q12, _q13, vget_high_f32(_q0), 0); + _q15 = vmlaq_lane_f32(_q15, _q14, vget_low_f32(_q0), 1); + vst1q_f32(at_m_ptr + 12, _q15); + + _q15 = vaddq_f32(_q9, _q3); + _q15 = vmlaq_lane_f32(_q15, _q10, vget_high_f32(_q0), 1); + vst1q_f32(at_m_ptr + 16, _q15); + + _q15 = vaddq_f32(_q12, _q8); + _q15 = vaddq_f32(_q15, _q14); + _q15 = vmlaq_lane_f32(_q15, _q2, vget_high_f32(_q0), 1); + vst1q_f32(at_m_ptr + 20, _q15); + + at_m_ptr += 24; + } +#else int steps = 32 * sizeof(float); asm volatile( "vld1.32 {d0-d1}, [%[tm_ptr]] \n" @@ -771,6 +1151,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); +#endif // __aarch64__ float *at_m_ptr0 = at_m; float *at_m_ptr1 = at_m + 24; @@ -782,6 +1163,133 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, float *out_ptr3 = output_tmp + 18; float *out_ptr4 = output_tmp + 24; float *out_ptr5 = output_tmp + 30; +#if __aarch64__ + float32x4_t _q0 = vld1q_f32(transform_matrix); + float32x4x2_t _q23, _q45, _q67, _q89; + _q23.val[0] = vld1q_f32(at_m_ptr0); + _q23.val[0] = vld1q_f32(at_m_ptr0 + 4); + _q45.val[1] = vld1q_f32(at_m_ptr0 + 8); + _q45.val[1] = vld1q_f32(at_m_ptr0 + 12); + _q67.val[0] = vld1q_f32(at_m_ptr1); + _q67.val[0] = vld1q_f32(at_m_ptr1 + 4); + _q89.val[1] = vld1q_f32(at_m_ptr1 + 8); + _q89.val[1] = vld1q_f32(at_m_ptr1 + 12); + _q23 = vtrnq_f32(_q23.val[0], _q23.val[1]); + _q45 = vtrnq_f32(_q45.val[0], _q45.val[1]); + _q67 = vtrnq_f32(_q67.val[0], _q67.val[1]); + _q89 = vtrnq_f32(_q89.val[0], _q89.val[1]); + float32x4_t _q1 = vcombine_f32(vget_low_f32(_q23.val[0]), + vget_low_f32(_q45.val[0])); + float32x4_t _q3 = vcombine_f32(vget_high_f32(_q23.val[0]), + vget_high_f32(_q45.val[0])); + float32x4_t _q2 = vcombine_f32(vget_low_f32(_q23.val[1]), + vget_low_f32(_q45.val[1])); + float32x4_t _q4 = vcombine_f32(vget_high_f32(_q23.val[1]), + vget_high_f32(_q45.val[1])); + float32x4_t _q5 = vcombine_f32(vget_low_f32(_q67.val[0]), + vget_low_f32(_q89.val[0])); + float32x4_t _q7 = vcombine_f32(vget_high_f32(_q67.val[0]), + vget_high_f32(_q89.val[0])); + float32x4_t _q6 = vcombine_f32(vget_low_f32(_q67.val[1]), + vget_low_f32(_q89.val[1])); + float32x4_t _q8 = vcombine_f32(vget_high_f32(_q67.val[1]), + vget_high_f32(_q89.val[1])); + + float32x4_t _q9 = vaddq_f32(_q2, _q3); + float32x4_t _q10 = vaddq_f32(_q4, _q5); + float32x4_t _q11 = vaddq_f32(_q6, _q7); + float32x4_t _q12 = vsubq_f32(_q2, _q3); + float32x4_t _q13 = vsubq_f32(_q4, _q5); + float32x4_t _q14 = vsubq_f32(_q6, _q7); + _q6 = vmulq_lane_f32(_q13, vget_low_f32(_q0), 0); + _q7 = vmulq_lane_f32(_q11, vget_low_f32(_q0), 0); + + _q1 = vaddq_f32(_q1, _q9); + _q1 = vaddq_f32(_q1, _q10); + _q1 = vmlaq_lane_f32(_q1, _q7, vget_high_f32(_q0), 1); + + _q2 = vaddq_f32(_q12, _q6); + _q2 = vmlaq_lane_f32(_q2, _q14, vget_high_f32(_q0), 1); + + _q3 = vmlaq_lane_f32(_q9, _q10, vget_low_f32(_q0), 1); + _q3 = vmlaq_lane_f32(_q3, _q11, vget_high_f32(_q0), 0); + + _q4 = vmlaq_lane_f32(_q12, _q13, vget_high_f32(_q0), 0); + _q4 = vmlaq_lane_f32(_q4, _q14, vget_low_f32(_q0), 1); + + _q23 = vtrnq_f32(_q1, _q2); + _q45 = vtrnq_f32(_q3, _q4); + vst1_f32(out_ptr0, vget_low_f32(_q23.val[0])); + vst1_f32(out_ptr0 + 2, vget_low_f32(_q45.val[0])); + vst1_f32(out_ptr1, vget_low_f32(_q23.val[1])); + vst1_f32(out_ptr1 + 2, vget_low_f32(_q45.val[1])); + vst1_f32(out_ptr2, vget_high_f32(_q23.val[0])); + vst1_f32(out_ptr2 + 2, vget_high_f32(_q45.val[0])); + vst1_f32(out_ptr3, vget_high_f32(_q23.val[1])); + vst1_f32(out_ptr3 + 2, vget_high_f32(_q45.val[1])); + + _q1 = vaddq_f32(_q9, _q7); + _q1 = vmlaq_lane_f32(_q1, _q10, vget_high_f32(_q0), 1); + _q2 = vaddq_f32(_q12, _q8); + _q2 = vmlaq_lane_f32(_q2, _q6, vget_high_f32(_q0), 1); + _q23 = vtrnq_f32(_q1, _q2); + vst1_f32(out_ptr0 + 4, vget_low_f32(_q23.val[0])); + vst1_f32(out_ptr1 + 4, vget_low_f32(_q23.val[1])); + vst1_f32(out_ptr2 + 4, vget_high_f32(_q23.val[0])); + vst1_f32(out_ptr3 + 4, vget_high_f32(_q23.val[1])); + + // remain 2 rows + _q1 = vld1q_f32(at_m_ptr0 + 16); + _q2 = vld1q_f32(at_m_ptr0 + 20); + _q3 = vld1q_f32(at_m_ptr1 + 16); + _q4 = vld1q_f32(at_m_ptr1 + 20); + _q23 = vtrnq_f32(_q1, _q2); + _q45 = vtrnq_f32(_q3, _q4); + + float32x2_t _d2 = vget_low_f32(_q23.val[0]); + float32x2_t _d3 = vget_high_f32(_q23.val[0]); + float32x2_t _d4 = vget_low_f32(_q23.val[1]); + float32x2_t _d5 = vget_high_f32(_q23.val[1]); + float32x2_t _d6 = vget_low_f32(_q45.val[0]); + float32x2_t _d7 = vget_high_f32(_q45.val[0]); + float32x2_t _d8 = vget_low_f32(_q45.val[1]); + float32x2_t _d9 = vget_high_f32(_q45.val[1]); + + float32x2_t _d10 = vadd_f32(_d4, _d3); + float32x2_t _d11 = vadd_f32(_d5, _d6); + float32x2_t _d12 = vadd_f32(_d8, _d7); + float32x2_t _d13 = vsub_f32(_d4, _d3); + float32x2_t _d14 = vsub_f32(_d5, _d6); + float32x2_t _d15 = vsub_f32(_d8, _d7); + float32x2_t _d16 = vmul_lane_f32(_d14, vget_low_f32(_q0), 0); + float32x2_t _d17 = vmul_lane_f32(_d12, vget_low_f32(_q0), 0); + + float32x2_t _d18 = vadd_f32(_d2, _d10); + float32x2_t _d20 = vadd_f32(_d13, _d16); + float32x2_t _d19 = vmla_lane_f32(_d10, _d11, vget_low_f32(_q0), 1); + float32x2_t _d21 = vmla_lane_f32(_d13, _d14, vget_high_f32(_q0), 0); + _d18 = vadd_f32(_d18, _d11); + _d18 = vmla_lane_f32(_d18, _d17, vget_high_f32(_q0), 1); + _d20 = vmla_lane_f32(_d20, _d15, vget_high_f32(_q0), 1); + _d19 = vmla_lane_f32(_d19, _d12, vget_high_f32(_q0), 0); + _d21 = vmla_lane_f32(_d21, _d15, vget_low_f32(_q0), 1); + + float32x2x2_t _d18d20 = vtrn_f32(_d18, _d20); + float32x2x2_t _d19d21 = vtrn_f32(_d19, _d21); + vst1_f32(out_ptr4, _d18d20.val[0]); + vst1_f32(out_ptr4 + 2, _d19d21.val[0]); + vst1_f32(out_ptr5, _d18d20.val[1]); + vst1_f32(out_ptr5 + 2, _d19d21.val[1]); + + _d18 = vadd_f32(_d10, _d17); + _d18 = vmla_lane_f32(_d18, _d11, vget_high_f32(_q0), 1); + _d20 = vadd_f32(_d13, _d9); + _d20 = vadd_f32(_d20, _d15); + _d20 = vmla_lane_f32(_d20, _d16, vget_high_f32(_q0), 1); + _d18d20 = vtrn_f32(_d18, _d20); + vst1_f32(out_ptr4 + 4, _d18); + vst1_f32(out_ptr5 + 4, _d20); +#else asm volatile( "vld1.32 {d0-d1}, [%[tm_ptr]] \n" // process 4 rows @@ -898,6 +1406,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, : [tm_ptr] "r"((float *)transform_matrix) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ size_t offset = (oc * out_h + 6 * tile_h) * out_w + 6 * tile_w; float *out_ptr = output_ptr + offset; int remain_row = out_h - 6 * tile_h; @@ -915,6 +1424,130 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, float *out_ptr3 = out_ptr2 + out_w; float *out_ptr4 = out_ptr3 + out_w; float *out_ptr5 = out_ptr4 + out_w; +#if __aarch64__ + float32x4_t _q0 = vld1q_f32(transform_matrix); + float32x4x2_t _q23, _q45, _q67, _q89; + _q23.val[0] = vld1q_f32(at_m_ptr0); + _q23.val[0] = vld1q_f32(at_m_ptr0 + 4); + _q45.val[1] = vld1q_f32(at_m_ptr0 + 8); + _q45.val[1] = vld1q_f32(at_m_ptr0 + 12); + _q67.val[0] = vld1q_f32(at_m_ptr1); + _q67.val[0] = vld1q_f32(at_m_ptr1 + 4); + _q89.val[1] = vld1q_f32(at_m_ptr1 + 8); + _q89.val[1] = vld1q_f32(at_m_ptr1 + 12); + _q23 = vtrnq_f32(_q23.val[0], _q23.val[1]); + _q45 = vtrnq_f32(_q45.val[0], _q45.val[1]); + _q67 = vtrnq_f32(_q67.val[0], _q67.val[1]); + _q89 = vtrnq_f32(_q89.val[0], _q89.val[1]); + float32x4_t _q1 = vcombine_f32(vget_low_f32(_q23.val[0]), + vget_low_f32(_q45.val[0])); + float32x4_t _q3 = vcombine_f32(vget_high_f32(_q23.val[0]), + vget_high_f32(_q45.val[0])); + float32x4_t _q2 = vcombine_f32(vget_low_f32(_q23.val[1]), + vget_low_f32(_q45.val[1])); + float32x4_t _q4 = vcombine_f32(vget_high_f32(_q23.val[1]), + vget_high_f32(_q45.val[1])); + float32x4_t _q5 = vcombine_f32(vget_low_f32(_q67.val[0]), + vget_low_f32(_q89.val[0])); + float32x4_t _q7 = vcombine_f32(vget_high_f32(_q67.val[0]), + vget_high_f32(_q89.val[0])); + float32x4_t _q6 = vcombine_f32(vget_low_f32(_q67.val[1]), + vget_low_f32(_q89.val[1])); + float32x4_t _q8 = vcombine_f32(vget_high_f32(_q67.val[1]), + vget_high_f32(_q89.val[1])); + + float32x4_t _q9 = vaddq_f32(_q2, _q3); + float32x4_t _q10 = vaddq_f32(_q4, _q5); + float32x4_t _q11 = vaddq_f32(_q6, _q7); + float32x4_t _q12 = vsubq_f32(_q2, _q3); + float32x4_t _q13 = vsubq_f32(_q4, _q5); + float32x4_t _q14 = vsubq_f32(_q6, _q7); + _q6 = vmulq_lane_f32(_q13, vget_low_f32(_q0), 0); + _q7 = vmulq_lane_f32(_q11, vget_low_f32(_q0), 0); + + _q1 = vaddq_f32(_q1, _q9); + _q1 = vaddq_f32(_q1, _q10); + _q1 = vmlaq_lane_f32(_q1, _q7, vget_high_f32(_q0), 1); + _q2 = vaddq_f32(_q12, _q6); + _q2 = vmlaq_lane_f32(_q2, _q14, vget_high_f32(_q0), 1); + _q3 = vmlaq_lane_f32(_q9, _q10, vget_low_f32(_q0), 1); + _q3 = vmlaq_lane_f32(_q3, _q11, vget_high_f32(_q0), 0); + _q4 = vmlaq_lane_f32(_q12, _q13, vget_high_f32(_q0), 0); + _q4 = vmlaq_lane_f32(_q4, _q14, vget_low_f32(_q0), 1); + + _q23 = vtrnq_f32(_q1, _q2); + _q45 = vtrnq_f32(_q3, _q4); + vst1_f32(out_ptr0, vget_low_f32(_q23.val[0])); + vst1_f32(out_ptr0 + 2, vget_low_f32(_q45.val[0])); + vst1_f32(out_ptr1, vget_low_f32(_q23.val[1])); + vst1_f32(out_ptr1 + 2, vget_low_f32(_q45.val[1])); + vst1_f32(out_ptr2, vget_high_f32(_q23.val[0])); + vst1_f32(out_ptr2 + 2, vget_high_f32(_q45.val[0])); + vst1_f32(out_ptr3, vget_high_f32(_q23.val[1])); + vst1_f32(out_ptr3 + 2, vget_high_f32(_q45.val[1])); + + _q1 = vaddq_f32(_q9, _q7); + _q1 = vmlaq_lane_f32(_q1, _q10, vget_high_f32(_q0), 1); + _q2 = vaddq_f32(_q12, _q8); + _q2 = vmlaq_lane_f32(_q2, _q6, vget_high_f32(_q0), 1); + _q23 = vtrnq_f32(_q1, _q2); + vst1_f32(out_ptr0 + 4, vget_low_f32(_q23.val[0])); + vst1_f32(out_ptr1 + 4, vget_low_f32(_q23.val[1])); + vst1_f32(out_ptr2 + 4, vget_high_f32(_q23.val[0])); + vst1_f32(out_ptr3 + 4, vget_high_f32(_q23.val[1])); + + // remain 2 rows + _q1 = vld1q_f32(at_m_ptr0 + 16); + _q2 = vld1q_f32(at_m_ptr0 + 20); + _q3 = vld1q_f32(at_m_ptr1 + 16); + _q4 = vld1q_f32(at_m_ptr1 + 20); + _q23 = vtrnq_f32(_q1, _q2); + _q45 = vtrnq_f32(_q3, _q4); + + float32x2_t _d2 = vget_low_f32(_q23.val[0]); + float32x2_t _d3 = vget_high_f32(_q23.val[0]); + float32x2_t _d4 = vget_low_f32(_q23.val[1]); + float32x2_t _d5 = vget_high_f32(_q23.val[1]); + float32x2_t _d6 = vget_low_f32(_q45.val[0]); + float32x2_t _d7 = vget_high_f32(_q45.val[0]); + float32x2_t _d8 = vget_low_f32(_q45.val[1]); + float32x2_t _d9 = vget_high_f32(_q45.val[1]); + + float32x2_t _d10 = vadd_f32(_d4, _d3); + float32x2_t _d11 = vadd_f32(_d5, _d6); + float32x2_t _d12 = vadd_f32(_d8, _d7); + float32x2_t _d13 = vsub_f32(_d4, _d3); + float32x2_t _d14 = vsub_f32(_d5, _d6); + float32x2_t _d15 = vsub_f32(_d8, _d7); + float32x2_t _d16 = vmul_lane_f32(_d14, vget_low_f32(_q0), 0); + float32x2_t _d17 = vmul_lane_f32(_d12, vget_low_f32(_q0), 0); + + float32x2_t _d18 = vadd_f32(_d2, _d10); + float32x2_t _d20 = vadd_f32(_d13, _d16); + float32x2_t _d19 = vmla_lane_f32(_d10, _d11, vget_low_f32(_q0), 1); + float32x2_t _d21 = vmla_lane_f32(_d13, _d14, vget_high_f32(_q0), 0); + _d18 = vadd_f32(_d18, _d11); + _d18 = vmla_lane_f32(_d18, _d17, vget_high_f32(_q0), 1); + _d20 = vmla_lane_f32(_d20, _d15, vget_high_f32(_q0), 1); + _d19 = vmla_lane_f32(_d19, _d12, vget_high_f32(_q0), 0); + _d21 = vmla_lane_f32(_d21, _d15, vget_low_f32(_q0), 1); + + float32x2x2_t _d18d20 = vtrn_f32(_d18, _d20); + float32x2x2_t _d19d21 = vtrn_f32(_d19, _d21); + vst1_f32(out_ptr4, _d18d20.val[0]); + vst1_f32(out_ptr4 + 2, _d19d21.val[0]); + vst1_f32(out_ptr5, _d18d20.val[1]); + vst1_f32(out_ptr5 + 2, _d19d21.val[1]); + + _d18 = vadd_f32(_d10, _d17); + _d18 = vmla_lane_f32(_d18, _d11, vget_high_f32(_q0), 1); + _d20 = vadd_f32(_d13, _d9); + _d20 = vadd_f32(_d20, _d15); + _d20 = vmla_lane_f32(_d20, _d16, vget_high_f32(_q0), 1); + _d18d20 = vtrn_f32(_d18, _d20); + vst1_f32(out_ptr4 + 4, _d18); + vst1_f32(out_ptr5 + 4, _d20); +#else asm volatile( "vld1.32 {d0-d1}, [%[tm_ptr]] \n" // process 4 rows @@ -1031,6 +1664,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, : [tm_ptr] "r"((float *)transform_matrix) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ } } } @@ -1041,5 +1675,5 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, } // namespace operators } // namespace paddle_mobile -#endif // __aarch64__ #endif // CONV_OP +#endif // __ARM_NEON__ diff --git a/src/operators/math/winograd/winograd_transform_f6k3_arm64.cpp b/src/operators/math/winograd/winograd_transform_f6k3_arm64.cpp deleted file mode 100644 index 5ef9c194f2..0000000000 --- a/src/operators/math/winograd/winograd_transform_f6k3_arm64.cpp +++ /dev/null @@ -1,413 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -// We refer https://github.com/andravin/wincnn to access the winograd transform -// matrixs - -#ifdef CONV_OP -#ifdef __aarch64__ - -#include "operators/math/winograd/winograd_transform.h" - -namespace paddle_mobile { -namespace operators { -namespace math { - -template <> -void winograd_transform_weight<8, 3>(const framework::Tensor &weight, - framework::Tensor *output) { - // weight shape is [out_channel, in_channel, kernel_h, kernel_w] - int out_channel = weight.dims()[0]; - int in_channel = weight.dims()[1]; - // reshape and alloc transformed weight - framework::DDim transformed_shape = - framework::make_ddim(std::vector{out_channel, in_channel, 64}); - float *outptr = output->mutable_data(transformed_shape); - const float *inptr = weight.data(); - for (int oc = 0; oc < out_channel; ++oc) { - for (int ic = 0; ic < in_channel; ++ic) { - size_t offset = oc * in_channel + ic; - float *kout = outptr + offset * 64; - const float *k = inptr + offset * 9; - - float gw[3][8]; - for (int i = 0; i < 3; ++i, k += 3) { - float g0 = k[0]; - float g1 = k[1]; - float g2 = k[2]; - float d0 = g0 + g2; - float d1 = g0 + 4 * g2; - float d2 = g2 + 4 * g0; - float d3 = 2 * g1; - gw[i][0] = g0; - gw[i][1] = -2.f / 9 * (d0 + g1); // -2.f/9 * (g0 + g1 + g2) - gw[i][2] = -2.f / 9 * (d0 - g1); // -2.f/9 * (g0 - g1 + g2) - gw[i][3] = 1.f / 90 * (d1 + d3); // 1.f/90 * (g0 + 2 * g1 + 4 * g2) - gw[i][4] = 1.f / 90 * (d1 - d3); // 1.f/90 * (g0 - 2 * g1 + 4 * g2) - gw[i][5] = 1.f / 180 * (d2 + d3); // 1.f/180 * (4 * g0 + 2 * g1 + g2) - gw[i][6] = 1.f / 180 * (d2 - d3); // 1.f/180 * (4 * g0 - 2 * g1 + g2) - gw[i][7] = g2; - } - for (int i = 0; i < 8; ++i, kout += 8) { - float g0 = gw[0][i]; - float g1 = gw[1][i]; - float g2 = gw[2][i]; - float d0 = g0 + g2; - float d1 = g0 + 4 * g2; - float d2 = g2 + 4 * g0; - float d3 = 2 * g1; - kout[0] = g0; - kout[1] = -2.f / 9 * (d0 + g1); // -2.f/9 * (k0 + k1 + k2) - kout[2] = -2.f / 9 * (d0 - g1); // -2.f/9 * (k0 - k1 + k2) - kout[3] = 1.f / 90 * (d1 + d3); // 1.f/90 * (k0 + 2 * k1 + 4 * k2) - kout[4] = 1.f / 90 * (d1 - d3); // 1.f/90 * (k0 - 2 * k1 + 4 * k2) - kout[5] = 1.f / 180 * (d2 + d3); // 8.f/45 * (4 * k0 + 2 * k1 + k2) - kout[6] = 1.f / 180 * (d2 - d3); // 8.f/45 * (4 * k0 - 2 * k1 + k2) - kout[7] = g2; - } - } - } -} - -template <> -void winograd_transform_input<8, 3>(const framework::Tensor &input, - framework::Tensor *output) { - // tile input to [c, roundup(h/6), roundup(w/6), 64] and do transformation - int channel = input.dims()[1]; - int height = input.dims()[2]; - int width = input.dims()[3]; - int h_tiles = (height + 3) / 6; // (height + 5 - 2) / 6 - int w_tiles = (width + 3) / 6; // (width + 5 - 2) / 6 - framework::DDim transformed_shape = - framework::make_ddim(std::vector{channel, h_tiles, w_tiles, 64}); - float *outptr = output->mutable_data(transformed_shape); - memset(outptr, 0, channel * h_tiles * w_tiles * 64 * sizeof(float)); - const float *inptr = input.data(); - // pack input to tiles - for (int c = 0; c < channel; ++c) { - int inter_h = (height - 2) / 6; - int inter_w = (width - 2) / 6; - int remain_h = height - (inter_h * 6); - int remain_w = width - (inter_w * 6); - const float *in0 = inptr + c * height * width; - const float *in1 = in0 + width; - const float *in2 = in1 + width; - const float *in3 = in2 + width; - const float *in4 = in3 + width; - const float *in5 = in4 + width; - const float *in6 = in5 + width; - const float *in7 = in6 + width; - float *out = outptr + c * h_tiles * w_tiles * 64; - - for (int h = 0; h < inter_h; ++h) { - for (int w = 0; w < inter_w; ++w) { - memcpy(out, in0, 8 * sizeof(float)); - memcpy(out + 8, in1, 8 * sizeof(float)); - memcpy(out + 16, in2, 8 * sizeof(float)); - memcpy(out + 24, in3, 8 * sizeof(float)); - memcpy(out + 32, in4, 8 * sizeof(float)); - memcpy(out + 40, in5, 8 * sizeof(float)); - memcpy(out + 48, in6, 8 * sizeof(float)); - memcpy(out + 56, in7, 8 * sizeof(float)); - in0 += 6; - in1 += 6; - in2 += 6; - in3 += 6; - in4 += 6; - in5 += 6; - in6 += 6; - in7 += 6; - out += 64; - } - // remain width - if (remain_w > 2) { - memcpy(out, in0, remain_w * sizeof(float)); - memcpy(out + 8, in1, remain_w * sizeof(float)); - memcpy(out + 16, in2, remain_w * sizeof(float)); - memcpy(out + 24, in3, remain_w * sizeof(float)); - memcpy(out + 32, in4, remain_w * sizeof(float)); - memcpy(out + 40, in5, remain_w * sizeof(float)); - memcpy(out + 48, in6, remain_w * sizeof(float)); - memcpy(out + 56, in7, remain_w * sizeof(float)); - out += 64; - } - in0 += 5 * width + remain_w; - in1 += 5 * width + remain_w; - in2 += 5 * width + remain_w; - in3 += 5 * width + remain_w; - in4 += 5 * width + remain_w; - in5 += 5 * width + remain_w; - in6 += 5 * width + remain_w; - in7 += 5 * width + remain_w; - } - // remain height - if (remain_h > 2) { - for (int w = 0; w < inter_w; ++w) { - for (int rh = 0; rh < remain_h; ++rh) { - memcpy(out + rh * 8, in0 + rh * width, 8 * sizeof(float)); - } - out += 64; - in0 += 6; - } - // remain width - if (remain_w > 2) { - for (int rh = 0; rh < remain_h; ++rh) { - memcpy(out + rh * 8, in0 + rh * width, remain_w * sizeof(float)); - } - } - } - } - // transform tiles, compute B_T * d(c, b) * B - for (int c = 0; c < channel; ++c) { - for (int tile = 0; tile < h_tiles * w_tiles; ++tile) { - float *out = outptr + (c * h_tiles * w_tiles + tile) * 64; - // compute B_T * d(c, b) - float bd[8][8]; - for (int i = 0; i < 8; ++i) { - float d0 = out[8 * i + 0]; - float d1 = out[8 * i + 1]; - float d2 = out[8 * i + 2]; - float d3 = out[8 * i + 3]; - float d4 = out[8 * i + 4]; - float d5 = out[8 * i + 5]; - float d6 = out[8 * i + 6]; - float d7 = out[8 * i + 7]; - - bd[i][0] = d0 - d6 + (d4 - d2) * 5.25; - float v1 = d2 - 4.25 * d4 + d6; - float v2 = d1 - 4.25 * d3 + d5; - // d1 + d2 - 4.25 * d3 - 4.25 * d4 + d5 + d6 - bd[i][1] = v1 + v2; - // -d1 + d2 + 4.25 * d3 - 4.25 * d4 - d5 + d6 - bd[i][2] = v1 - v2; - v1 = 0.25 * d2 - 1.25 * d4 + d6; - v2 = 0.5 * d1 - 2.5 * d3 + 2 * d5; - // 0.5 * d1 + 0.25 * d2 - 2.5 * d3 - 1.25 * d4 + 2 * d5 + d6 - bd[i][3] = v1 + v2; - // -0.5 * d1 + 0.25 * d2 + 2.5 * d3 - 1.25 * d4 - 2 * d5 + d6 - bd[i][4] = v1 - v2; - v1 = 4 * d2 - 5 * d4 + d6; - v2 = 2 * d1 - 2.5 * d3 + 0.5 * d5; - // 2 * d1 + 4 * d2 - 2.5 * d3 - 5 * d4 + 0.5 * d5 + d6 - bd[i][5] = v1 + v2; - // -2 * d1 + 4 * d2 + 2.5 * d3 - 5 * d4 - 0.5 * d5 + d6 - bd[i][6] = v1 - v2; - bd[i][7] = d7 - d1 + (d3 - d5) * 5.25; - } - // compute B_T * d(c, b) * B - for (int i = 0; i < 8; ++i, out += 8) { - float d0 = bd[0][i]; - float d1 = bd[1][i]; - float d2 = bd[2][i]; - float d3 = bd[3][i]; - float d4 = bd[4][i]; - float d5 = bd[5][i]; - float d6 = bd[6][i]; - float d7 = bd[7][i]; - - out[0] = d0 - d6 + (d4 - d2) * 5.25; - float v1 = d2 - 4.25 * d4 + d6; - float v2 = d1 - 4.25 * d3 + d5; - // d1 + d2 - 4.25 * d3 - 4.25 * d4 + d5 + d6 - out[1] = v1 + v2; - // -d1 + d2 + 4.25 * d3 - 4.25 * d4 - d5 + d6 - out[2] = v1 - v2; - v1 = 0.25 * d2 - 1.25 * d4 + d6; - v2 = 0.5 * d1 - 2.5 * d3 + 2 * d5; - // 0.5 * d1 + 0.25 * d2 - 2.5 * d3 - 1.25 * d4 + 2 * d5 + d6 - out[3] = v1 + v2; - // -0.5 * d1 + 0.25 * d2 + 2.5 * d3 - 1.25 * d4 - 2 * d5 + d6 - out[4] = v1 - v2; - v1 = 4 * d2 - 5 * d4 + d6; - v2 = 2 * d1 - 2.5 * d3 + 0.5 * d5; - // 2 * d1 + 4 * d2 - 2.5 * d3 - 5 * d4 + 0.5 * d5 + d6 - out[5] = v1 + v2; - // -2 * d1 + 4 * d2 + 2.5 * d3 - 5 * d4 - 0.5 * d5 + d6 - out[6] = v1 - v2; - out[7] = d7 - d1 + (d3 - d5) * 5.25; - } - } - } -} - -template <> -void winograd_transform_output<8, 3>(const framework::Tensor &input, - const framework::Tensor &weight, - framework::Tensor *output) { - // input shape is [in_channel, h_tiles, w_tiles, 64] - // weight shape is [out_channel, in_channel, 64] - int in_channel = input.dims()[0]; - int h_tiles = input.dims()[1]; - int w_tiles = input.dims()[2]; - int tiles = h_tiles * w_tiles; - int out_channel = weight.dims()[0]; - // compute U*V first - framework::Tensor output_m; - framework::DDim shape = - framework::make_ddim(std::vector{out_channel, tiles, 64}); - float *output_m_ptr = output_m.mutable_data(shape); - memset(output_m_ptr, 0, output_m.numel() * sizeof(float)); - const float *input_ptr = input.data(); - const float *weight_ptr = weight.data(); - for (int i = 0; i < out_channel; ++i) { - for (int j = 0; j < tiles; ++j) { - const float *w_ptr = weight_ptr + i * in_channel * 64; - const float *in_ptr = input_ptr + j * 64; - float *m_ptr = output_m_ptr + (i * tiles + j) * 64; - for (int c = 0; c < in_channel; ++c) { - for (int k = 0; k < 64; ++k) { - m_ptr[k] += w_ptr[k] * in_ptr[k]; - } - w_ptr += 64; - in_ptr += tiles * 64; - } - } - } - - for (int oc = 0; oc < out_channel; ++oc) { - for (int tile = 0; tile < tiles; ++tile) { - float *m = output_m_ptr + (oc * tiles + tile) * 64; - // compute A_T * m - float am[6][8]; - for (int i = 0; i < 8; ++i) { - float d0 = m[i * 8 + 0]; - float d1 = m[i * 8 + 1]; - float d2 = m[i * 8 + 2]; - float d3 = m[i * 8 + 3]; - float d4 = m[i * 8 + 4]; - float d5 = m[i * 8 + 5]; - float d6 = m[i * 8 + 6]; - float d7 = m[i * 8 + 7]; - float v0 = d1 + d2; - float v1 = d1 - d2; - float v2 = d3 + d4; - float v3 = d3 - d4; - float v4 = d5 + d6; - float v5 = d5 - d6; - - am[0][i] = d0 + v0 + v2 + 32 * v4; - am[1][i] = v1 + 2 * v3 + 16 * v5; - am[2][i] = v0 + 4 * v2 + 8 * v4; - am[3][i] = v1 + 8 * v3 + 4 * v5; - am[4][i] = v0 + 16 * v2 + 2 * v4; - am[5][i] = v1 + 32 * v3 + v5 + d7; - } - // compute A_T * m * A - for (int i = 0; i < 6; ++i, m += 8) { - float d0 = am[i][0]; - float d1 = am[i][1]; - float d2 = am[i][2]; - float d3 = am[i][3]; - float d4 = am[i][4]; - float d5 = am[i][5]; - float d6 = am[i][6]; - float d7 = am[i][7]; - float v0 = d1 + d2; - float v1 = d1 - d2; - float v2 = d3 + d4; - float v3 = d3 - d4; - float v4 = d5 + d6; - float v5 = d5 - d6; - - m[0] = d0 + v0 + v2 + 32 * v4; - m[1] = v1 + 2 * v3 + 16 * v5; - m[2] = v0 + 4 * v2 + 8 * v4; - m[3] = v1 + 8 * v3 + 4 * v5; - m[4] = v0 + 16 * v2 + 2 * v4; - m[5] = v1 + 32 * v3 + v5 + d7; - } - } - } - - int out_h = output->dims()[2]; - int out_w = output->dims()[3]; - float *output_ptr = output->mutable_data(); - // copy valid region to final output - for (int oc = 0; oc < out_channel; ++oc) { - int inter_h = out_h / 6; - int inter_w = out_w / 6; - int remain_h = out_h - inter_h * 6; - int remain_w = out_w - inter_w * 6; - - float *out_ptr0 = output_ptr + oc * out_h * out_w; - float *out_ptr1 = out_ptr0 + out_w; - float *out_ptr2 = out_ptr1 + out_w; - float *out_ptr3 = out_ptr2 + out_w; - float *out_ptr4 = out_ptr3 + out_w; - float *out_ptr5 = out_ptr4 + out_w; - const float *m_ptr = output_m_ptr + oc * tiles * 64; - for (int tile_h = 0; tile_h < inter_h; ++tile_h) { - for (int tile_w = 0; tile_w < inter_w; ++tile_w) { - const float *m = m_ptr + (tile_h * w_tiles + tile_w) * 64; - memcpy(out_ptr0, m, 6 * sizeof(float)); - memcpy(out_ptr1, m + 8, 6 * sizeof(float)); - memcpy(out_ptr2, m + 16, 6 * sizeof(float)); - memcpy(out_ptr3, m + 24, 6 * sizeof(float)); - memcpy(out_ptr4, m + 32, 6 * sizeof(float)); - memcpy(out_ptr5, m + 40, 6 * sizeof(float)); - out_ptr0 += 6; - out_ptr1 += 6; - out_ptr2 += 6; - out_ptr3 += 6; - out_ptr4 += 6; - out_ptr5 += 6; - } - // remain w - if (remain_w > 0) { - const float *m = m_ptr + (tile_h * w_tiles + inter_w) * 64; - memcpy(out_ptr0, m, remain_w * sizeof(float)); - memcpy(out_ptr1, m + 8, remain_w * sizeof(float)); - memcpy(out_ptr2, m + 16, remain_w * sizeof(float)); - memcpy(out_ptr3, m + 24, remain_w * sizeof(float)); - memcpy(out_ptr4, m + 32, remain_w * sizeof(float)); - memcpy(out_ptr5, m + 40, remain_w * sizeof(float)); - out_ptr0 += remain_w; - out_ptr1 += remain_w; - out_ptr2 += remain_w; - out_ptr3 += remain_w; - out_ptr4 += remain_w; - out_ptr5 += remain_w; - } - out_ptr0 += 5 * out_w; - out_ptr1 += 5 * out_w; - out_ptr2 += 5 * out_w; - out_ptr3 += 5 * out_w; - out_ptr4 += 5 * out_w; - out_ptr5 += 5 * out_w; - } - // remain h - if (remain_h > 0) { - for (int tile_w = 0; tile_w < inter_w; ++tile_w) { - const float *m = m_ptr + (inter_h * w_tiles + tile_w) * 64; - for (int rh = 0; rh < remain_h; ++rh) { - memcpy(out_ptr0 + rh * out_w, m + rh * 8, 6 * sizeof(float)); - } - out_ptr0 += 6; - } - if (remain_w > 0) { - const float *m = m_ptr + (inter_h * w_tiles + inter_w) * 64; - for (int rh = 0; rh < remain_h; ++rh) { - memcpy(out_ptr0 + rh * out_w, m + rh * 8, remain_w * sizeof(float)); - } - } - } - } -} - -} // namespace math -} // namespace operators -} // namespace paddle_mobile - -#endif // __aarch64__ -#endif // CONV_OP -- GitLab