diff --git a/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp index 28cb2c3e409327bdc3c38a000afa62c6dd163d6c..91e901b89ffa9fd6288eb9abf7bfac9d1a14a2fb 100644 --- a/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp @@ -79,12 +79,12 @@ void ConvAddBNReluKernel::Compute( math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; +#endif // __aarch64__ case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; -#endif // __aarch64__ case ConvParam::EXEC_GEMM_FLOAT: ConvBNReluBasic>(param); break; diff --git a/src/operators/kernel/arm/convolution/conv_add_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_kernel.cpp index b62fdf71f80780b8b73a961c4ddeaf05396d4c9e..9de6e333e7b7ec77acd068bdfc4b01136bfc80ec 100644 --- a/src/operators/kernel/arm/convolution/conv_add_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_add_kernel.cpp @@ -11,11 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #ifdef FUSION_CONVADD_OP #include "operators/kernel/conv_add_kernel.h" #include "operators/kernel/arm/convolution/conv_common.h" #include "operators/kernel/central-arm-func/conv_add_arm_func.h" +#include "operators/kernel/central-arm-func/conv_arm_func.h" namespace paddle_mobile { namespace operators { @@ -47,12 +49,12 @@ void ConvAddKernel::Compute(const FusionConvAddParam ¶m) { math::AddChannelWise(param.Output(), param.Bias(), param.Output()); break; +#endif // __aarch64__ case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); math::AddChannelWise(param.Output(), param.Bias(), param.Output()); break; -#endif // __aarch64__ case ConvParam::EXEC_GEMM_FLOAT: ConvAddBasic(param); break; diff --git a/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp index 4060c56312aecddce58dc408a8f64893202c01e4..3f8a85e74b8670e66c20bd2f9fae277a7f1203ff 100644 --- a/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp @@ -46,11 +46,11 @@ void ConvAddReluKernel::Compute( DepthwiseConv5x5(param); math::AddChannelWise(param.Output(), param.Bias(), param.Output()); break; +#endif // __aarch64__ case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); math::AddChannelWise(param.Output(), param.Bias(), param.Output()); break; -#endif // __aarch64__ case ConvParam::EXEC_GEMM_FLOAT: ConvAddReluBasic>(param); break; diff --git a/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp index 86a9eb22507f52f2fe6a0a14e93b13cbd4414d48..a46c66c275a4e90b68537ecaa22bd8c655e92c0b 100644 --- a/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp @@ -79,12 +79,12 @@ void ConvBNAddReluKernel::Compute( math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; +#endif // __aarch64__ case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; -#endif // __aarch64__ case ConvParam::EXEC_GEMM_FLOAT: ConvBNReluBasic>(param); break; diff --git a/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp index 0de0704884e95af63337fe28cf49ebcd18f5c4ae..b49120e7408e263c2eeb31d980abecfca943f30e 100644 --- a/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp @@ -78,12 +78,12 @@ void ConvBNReluKernel::Compute( math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; +#endif // __aarch64__ case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); math::ScaleAddChannelWise(param.Output(), param.NewScale(), param.NewBias(), param.Output()); break; -#endif // __aarch64__ case ConvParam::EXEC_GEMM_FLOAT: ConvBNReluBasic>(param); break; diff --git a/src/operators/kernel/arm/convolution/conv_common.cpp b/src/operators/kernel/arm/convolution/conv_common.cpp index 3ff08eae0d236e94bcf6277ec0e1cc6750d2e6e5..89817153519ba40b60da1254d1631040e9df2819 100644 --- a/src/operators/kernel/arm/convolution/conv_common.cpp +++ b/src/operators/kernel/arm/convolution/conv_common.cpp @@ -53,6 +53,7 @@ void InitBaseConvKernel(ConvParam *param) { } else if (depth5x5 && param->Strides()[0] == param->Strides()[1] && param->Strides()[0] == 1) { param->ExecMode() = ConvParam::EXEC_DEPTHWISE5x5_FLOAT; +#endif } else if (conv3x3 && !depth3x3 && param->Strides()[0] == param->Strides()[1] && param->Dilations()[0] == param->Dilations()[1] && @@ -68,7 +69,6 @@ void InitBaseConvKernel(ConvParam *param) { param->transformed_filter_ = new framework::LoDTensor; operators::math::winograd_transform_weight<8, 3>( *param->Filter(), param->transformed_filter_); -#endif } else { param->ExecMode() = ConvParam::EXEC_GEMM_FLOAT; } diff --git a/src/operators/kernel/arm/convolution/conv_kernel.cpp b/src/operators/kernel/arm/convolution/conv_kernel.cpp index 59026be4d6ed87de3653ba536d890a2a9b11f3bf..97c153fa2859491ebb5754c3f4a6a6820c6f8964 100644 --- a/src/operators/kernel/arm/convolution/conv_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_kernel.cpp @@ -55,10 +55,10 @@ void ConvKernel::Compute(const ConvParam ¶m) { case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: DepthwiseConv5x5(param); break; +#endif // __aarch64__ case ConvParam::EXEC_WINOGRAD3X3_FLOAT: WinogradConv3x3<8, 3>(param); break; -#endif // __aarch64__ case ConvParam::EXEC_GEMM_FLOAT: GemmConv(param); break; diff --git a/src/operators/math/winograd/winograd_transform_f6k3.cpp b/src/operators/math/winograd/winograd_transform_f6k3.cpp index f95ead62445861566a1591df3c883085fc3eb16e..11366c93a737e96e2693c9157355c066d4125d70 100644 --- a/src/operators/math/winograd/winograd_transform_f6k3.cpp +++ b/src/operators/math/winograd/winograd_transform_f6k3.cpp @@ -15,10 +15,10 @@ limitations under the License. */ // Inspired by https://arxiv.org/abs/1509.09308 and refered from nnpack and ncnn // project. +#if defined(__ARM_NEON) || defined(__ARM_NEON__) #ifdef CONV_OP -#ifndef __aarch64__ - +#include #include "operators/math/pad.h" #include "operators/math/winograd/winograd_transform.h" @@ -51,6 +51,10 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight, const float transform_matrix[8] = {2.f, -2.f / 9, 1.f / 90, 1.f / 180}; const float *inptr = weight.data(); + +#if __aarch64__ + int remain_start = 0; +#else int remain_start = out_channel & 0xFFFC; #pragma omp parallel for @@ -256,6 +260,7 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight, "q13", "r0"); } } +#endif // __aarch64__ // remain output channel #pragma omp parallel for @@ -358,6 +363,90 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input, const float *in2 = in1 + width; const float *in3 = in2 + width; float *d_bt_ptr = d_bt; +#if __aarch64__ + int steps = 4 * width; + float32x4_t _q0 = vld1q_f32(transform_matrix); + float32x4_t _q1 = vld1q_f32(transform_matrix + 4); + for (int l = 0; l < 2; ++l) { + float32x4x2_t _q23, _q45, _q67, _q89; + _q23.val[0] = vld1q_f32(in0); + _q45.val[0] = vld1q_f32(in0 + 4); + _q23.val[1] = vld1q_f32(in1); + _q45.val[1] = vld1q_f32(in1 + 4); + _q67.val[0] = vld1q_f32(in2); + _q89.val[0] = vld1q_f32(in2 + 4); + _q67.val[1] = vld1q_f32(in3); + _q89.val[1] = vld1q_f32(in3 + 4); + _q23 = vtrnq_f32(_q23.val[0], _q23.val[1]); + _q45 = vtrnq_f32(_q45.val[0], _q45.val[1]); + _q67 = vtrnq_f32(_q67.val[0], _q67.val[1]); + _q89 = vtrnq_f32(_q89.val[0], _q89.val[1]); + float32x4_t _q2 = vcombine_f32(vget_low_f32(_q23.val[0]), + vget_low_f32(_q67.val[0])); + float32x4_t _q4 = vcombine_f32(vget_low_f32(_q23.val[1]), + vget_low_f32(_q67.val[1])); + float32x4_t _q3 = vcombine_f32(vget_low_f32(_q45.val[0]), + vget_low_f32(_q89.val[0])); + float32x4_t _q5 = vcombine_f32(vget_low_f32(_q45.val[1]), + vget_low_f32(_q89.val[1])); + float32x4_t _q6 = vcombine_f32(vget_high_f32(_q23.val[0]), + vget_high_f32(_q67.val[0])); + float32x4_t _q8 = vcombine_f32(vget_high_f32(_q23.val[1]), + vget_high_f32(_q67.val[1])); + float32x4_t _q7 = vcombine_f32(vget_high_f32(_q45.val[0]), + vget_high_f32(_q89.val[0])); + float32x4_t _q9 = vcombine_f32(vget_high_f32(_q45.val[1]), + vget_high_f32(_q89.val[1])); + + float32x4_t _q10 = vsubq_f32(_q2, _q7); + float32x4_t _q11 = vsubq_f32(_q3, _q6); + _q10 = vmlaq_lane_f32(_q10, _q11, vget_low_f32(_q0), 0); + vst1q_f32(d_bt_ptr, _q10); + + _q10 = vaddq_f32(_q6, _q7); + _q11 = vaddq_f32(_q4, _q5); + _q10 = vmlaq_lane_f32(_q10, _q3, vget_high_f32(_q0), 0); + _q11 = vmlaq_lane_f32(_q11, _q8, vget_high_f32(_q0), 0); + float32x4_t _q12 = vaddq_f32(_q10, _q11); + float32x4_t _q13 = vsubq_f32(_q10, _q11); + vst1q_f32(d_bt_ptr + 4, _q12); + vst1q_f32(d_bt_ptr + 8, _q13); + + _q10 = vmulq_lane_f32(_q6, vget_high_f32(_q1), 1); + _q11 = vmulq_lane_f32(_q4, vget_high_f32(_q1), 0); + _q10 = vaddq_f32(_q10, _q7); + _q11 = vmlaq_lane_f32(_q11, _q5, vget_low_f32(_q1), 0); + _q10 = vmlaq_lane_f32(_q10, _q3, vget_low_f32(_q1), 1); + _q11 = vmlaq_lane_f32(_q11, _q8, vget_high_f32(_q0), 1); + _q12 = vaddq_f32(_q10, _q11); + _q13 = vsubq_f32(_q10, _q11); + vst1q_f32(d_bt_ptr + 12, _q12); + vst1q_f32(d_bt_ptr + 16, _q13); + + _q10 = vmulq_lane_f32(_q6, vget_low_f32(_q1), 0); + _q11 = vmulq_lane_f32(_q4, vget_low_f32(_q1), 0); + _q10 = vmlaq_lane_f32(_q10, _q3, vget_high_f32(_q0), 1); + _q11 = vmlaq_lane_f32(_q11, _q8, vget_high_f32(_q0), 1); + _q10 = vmlaq_lane_f32(_q10, _q7, vget_high_f32(_q1), 0); + _q11 = vmlaq_lane_f32(_q11, _q5, vget_high_f32(_q1), 0); + _q10 = vmulq_lane_f32(_q10, vget_low_f32(_q1), 0); + _q12 = vaddq_f32(_q10, _q11); + _q13 = vsubq_f32(_q10, _q11); + vst1q_f32(d_bt_ptr + 20, _q12); + vst1q_f32(d_bt_ptr + 24, _q13); + + _q10 = vsubq_f32(_q9, _q4); + _q11 = vsubq_f32(_q8, _q5); + _q10 = vmlaq_lane_f32(_q10, _q11, vget_low_f32(_q0), 0); + vst1q_f32(d_bt_ptr + 28, _q10); + + in0 += steps; + in1 += steps; + in2 += steps; + in3 += steps; + d_bt_ptr += 32; + } +#else int steps = 4 * width * sizeof(float); asm volatile( "vld1.32 {d0-d3}, [%[tm_ptr]] \n" @@ -434,7 +523,7 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input, : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "r0"); - +#endif // __aarch64__ float *ptr0 = d_bt; float *ptr1 = ptr0 + 32; int tile_indics = h * w_tiles + w; @@ -450,6 +539,120 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input, float *out5 = out4 + channel * 8; float *out6 = out5 + channel * 8; float *out7 = out6 + channel * 8; +#if __aarch64__ + steps = 8 * channel * 8; + for (int l = 0; l < 2; ++l) { + float32x4x2_t _q23, _q45, _q67, _q89; + _q23.val[0] = vld1q_f32(ptr0); + _q23.val[0] = vld1q_f32(ptr0 + 4); + _q45.val[1] = vld1q_f32(ptr0 + 8); + _q45.val[1] = vld1q_f32(ptr0 + 12); + _q67.val[0] = vld1q_f32(ptr1); + _q67.val[0] = vld1q_f32(ptr1 + 4); + _q89.val[1] = vld1q_f32(ptr1 + 8); + _q89.val[1] = vld1q_f32(ptr1 + 12); + _q23 = vtrnq_f32(_q23.val[0], _q23.val[1]); + _q45 = vtrnq_f32(_q45.val[0], _q45.val[1]); + _q67 = vtrnq_f32(_q67.val[0], _q67.val[1]); + _q89 = vtrnq_f32(_q89.val[0], _q89.val[1]); + float32x4_t _q2 = vcombine_f32(vget_low_f32(_q23.val[0]), + vget_low_f32(_q45.val[0])); + float32x4_t _q4 = vcombine_f32(vget_high_f32(_q23.val[0]), + vget_high_f32(_q45.val[0])); + float32x4_t _q3 = vcombine_f32(vget_low_f32(_q23.val[1]), + vget_low_f32(_q45.val[1])); + float32x4_t _q5 = vcombine_f32(vget_high_f32(_q23.val[1]), + vget_high_f32(_q45.val[1])); + float32x4_t _q6 = vcombine_f32(vget_low_f32(_q67.val[0]), + vget_low_f32(_q89.val[0])); + float32x4_t _q8 = vcombine_f32(vget_high_f32(_q67.val[0]), + vget_high_f32(_q89.val[0])); + float32x4_t _q7 = vcombine_f32(vget_low_f32(_q67.val[1]), + vget_low_f32(_q89.val[1])); + float32x4_t _q9 = vcombine_f32(vget_high_f32(_q67.val[1]), + vget_high_f32(_q89.val[1])); + + float32x4_t _q10 = vsubq_f32(_q2, _q8); + float32x4_t _q11 = vsubq_f32(_q6, _q4); + _q10 = vmlaq_lane_f32(_q10, _q11, vget_low_f32(_q0), 0); + vst1q_lane_f32(out0, _q10, 0); + vst1q_lane_f32(out0 + steps, _q10, 1); + vst1q_lane_f32(out0 + 2 * steps, _q10, 2); + vst1q_lane_f32(out0 + 3 * steps, _q10, 3); + + _q10 = vaddq_f32(_q4, _q8); + _q11 = vaddq_f32(_q3, _q7); + _q10 = vmlaq_lane_f32(_q10, _q6, vget_high_f32(_q0), 0); + _q11 = vmlaq_lane_f32(_q11, _q5, vget_high_f32(_q0), 0); + float32x4_t _q12 = vaddq_f32(_q10, _q11); + vst1q_lane_f32(out1, _q12, 0); + vst1q_lane_f32(out1 + steps, _q12, 1); + vst1q_lane_f32(out1 + 2 * steps, _q12, 2); + vst1q_lane_f32(out1 + 3 * steps, _q12, 3); + + _q12 = vsubq_f32(_q10, _q11); + vst1q_lane_f32(out2, _q12, 0); + vst1q_lane_f32(out2 + steps, _q12, 1); + vst1q_lane_f32(out2 + 2 * steps, _q12, 2); + vst1q_lane_f32(out2 + 3 * steps, _q12, 3); + + _q10 = vmulq_lane_f32(_q4, vget_high_f32(_q1), 1); + _q11 = vmulq_lane_f32(_q3, vget_high_f32(_q1), 0); + _q10 = vaddq_f32(_q10, _q8); + _q11 = vmlaq_lane_f32(_q11, _q7, vget_low_f32(_q1), 0); + _q10 = vmlaq_lane_f32(_q10, _q6, vget_low_f32(_q1), 1); + _q11 = vmlaq_lane_f32(_q11, _q5, vget_high_f32(_q0), 1); + _q12 = vaddq_f32(_q10, _q11); + vst1q_lane_f32(out3, _q12, 0); + vst1q_lane_f32(out3 + steps, _q12, 1); + vst1q_lane_f32(out3 + 2 * steps, _q12, 2); + vst1q_lane_f32(out3 + 3 * steps, _q12, 3); + + _q12 = vsubq_f32(_q10, _q11); + vst1q_lane_f32(out4, _q12, 0); + vst1q_lane_f32(out4 + steps, _q12, 1); + vst1q_lane_f32(out4 + 2 * steps, _q12, 2); + vst1q_lane_f32(out4 + 3 * steps, _q12, 3); + + _q10 = vmulq_lane_f32(_q4, vget_low_f32(_q1), 0); + _q11 = vmulq_lane_f32(_q3, vget_low_f32(_q1), 0); + _q10 = vmlaq_lane_f32(_q10, _q6, vget_high_f32(_q0), 1); + _q11 = vmlaq_lane_f32(_q11, _q5, vget_high_f32(_q0), 1); + _q10 = vmlaq_lane_f32(_q10, _q8, vget_high_f32(_q1), 0); + _q11 = vmlaq_lane_f32(_q11, _q7, vget_high_f32(_q1), 0); + _q10 = vmulq_lane_f32(_q10, vget_low_f32(_q1), 0); + _q12 = vaddq_f32(_q10, _q11); + vst1q_lane_f32(out5, _q12, 0); + vst1q_lane_f32(out5 + steps, _q12, 1); + vst1q_lane_f32(out5 + 2 * steps, _q12, 2); + vst1q_lane_f32(out5 + 3 * steps, _q12, 3); + + _q12 = vsubq_f32(_q10, _q11); + vst1q_lane_f32(out6, _q12, 0); + vst1q_lane_f32(out6 + steps, _q12, 1); + vst1q_lane_f32(out6 + 2 * steps, _q12, 2); + vst1q_lane_f32(out6 + 3 * steps, _q12, 3); + + _q10 = vsubq_f32(_q9, _q3); + _q11 = vsubq_f32(_q5, _q7); + _q10 = vmlaq_lane_f32(_q10, _q11, vget_low_f32(_q0), 0); + vst1q_lane_f32(out7, _q10, 0); + vst1q_lane_f32(out7 + steps, _q10, 1); + vst1q_lane_f32(out7 + 2 * steps, _q10, 2); + vst1q_lane_f32(out7 + 3 * steps, _q10, 3); + + ptr0 += 16; + ptr1 += 16; + out0 += 4 * steps; + out1 += 4 * steps; + out2 += 4 * steps; + out3 += 4 * steps; + out4 += 4 * steps; + out5 += 4 * steps; + out6 += 4 * steps; + out7 += 4 * steps; + } +#else steps = 8 * channel * 8 * sizeof(float); asm volatile( "mov r0, #2 \n" @@ -555,6 +758,7 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input, : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "r0"); +#endif // __aarch64__ } } } @@ -587,6 +791,71 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, const float *in_ptr = input_ptr + (j * 64 + k) * in_channel * 8; int inter_channel = in_channel >> 1; int remain_channel = in_channel & 0x1; +#if __aarch64__ + asm volatile( + "dup v8.4s, wzr \n" + "dup v9.4s, wzr \n" + "dup v10.4s, wzr \n" + "dup v11.4s, wzr \n" + "dup v12.4s, wzr \n" + "dup v13.4s, wzr \n" + "dup v14.4s, wzr \n" + "dup v15.4s, wzr \n" + + "cmp %[inter], #0 \n" + "ble loop_1c_%= \n" + // loop 2 channels + "loop_2c_%=: \n" + "ld1 {v0.4s, v1.4s}, [%[w_ptr]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[in_ptr]], #32 \n" + "ld1 {v4.4s, v5.4s}, [%[in_ptr]], #32 \n" + + "fmla v8.4s, v2.4s, v0.s[0] \n" + "fmla v9.4s, v3.4s, v0.s[0] \n" + "fmla v10.4s, v2.4s, v0.s[1] \n" + "fmla v11.4s, v3.4s, v0.s[1] \n" + "fmla v12.4s, v2.4s, v0.s[2] \n" + "fmla v13.4s, v3.4s, v0.s[2] \n" + "fmla v14.4s, v2.4s, v0.s[3] \n" + "fmla v15.4s, v3.4s, v0.s[3] \n" + + "fmla v8.4s, v4.4s, v1.s[0] \n" + "fmla v9.4s, v5.4s, v1.s[0] \n" + "fmla v10.4s, v4.4s, v1.s[1] \n" + "fmla v11.4s, v5.4s, v1.s[1] \n" + "fmla v12.4s, v4.4s, v1.s[2] \n" + "fmla v13.4s, v5.4s, v1.s[2] \n" + "fmla v14.4s, v4.4s, v1.s[3] \n" + "fmla v15.4s, v5.4s, v1.s[3] \n" + + "subs %[inter], %[inter], #1 \n" + "bne loop_2c_%= \n" + + // loop 1 channel + "loop_1c_%=: \n" + "cmp %[remain], #0 \n" + "ble store_res_%= \n" + + "ld1 {v0.4s, v1.4s}, [%[w_ptr]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[in_ptr]], #32 \n" + "fmla v8.4s, v2.4s, v0.s[0] \n" + "fmla v9.4s, v3.4s, v0.s[0] \n" + "fmla v10.4s, v2.4s, v0.s[1] \n" + "fmla v11.4s, v3.4s, v0.s[1] \n" + "fmla v12.4s, v2.4s, v0.s[2] \n" + "fmla v13.4s, v3.4s, v0.s[2] \n" + "fmla v14.4s, v2.4s, v0.s[3] \n" + "fmla v15.4s, v3.4s, v0.s[3] \n" + + "store_res_%=: \n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[uv_ptr]], #64 \n" + "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[uv_ptr]], #64 \n" + : [w_ptr] "+r"(w_ptr), [in_ptr] "+r"(in_ptr), [uv_ptr] "+r"(uv_ptr), + [inter] "+r"(inter_channel) + : [remain] "r"(remain_channel) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"); +#else asm volatile( "veor q8, q8, q8 \n" "veor q9, q9, q9 \n" @@ -651,6 +920,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, : [remain_channel] "r"(remain_channel) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ } } } @@ -686,6 +956,116 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, int tile_block = tile_indics >> 3; int block_indics = tile_indics & 0x7; const float *uv_ptr0 = uv_ptr + tile_block * 64 * 32 + block_indics; +#if __aarch64__ + float32x4_t _q0 = vld1q_f32(transform_matrix); + for (int l = 0; l < 2; ++l) { + float32x4_t _q1, _q2, _q3, _q4, _q5, _q6, _q7, _q8; + _q1 = vsetq_lane_f32(*uv_ptr0, _q1, 0); + uv_ptr0 += 32; + _q3 = vsetq_lane_f32(*uv_ptr0, _q3, 0); + uv_ptr0 += 32; + _q5 = vsetq_lane_f32(*uv_ptr0, _q5, 0); + uv_ptr0 += 32; + _q7 = vsetq_lane_f32(*uv_ptr0, _q7, 0); + uv_ptr0 += 32; + _q2 = vsetq_lane_f32(*uv_ptr0, _q2, 0); + uv_ptr0 += 32; + _q4 = vsetq_lane_f32(*uv_ptr0, _q4, 0); + uv_ptr0 += 32; + _q6 = vsetq_lane_f32(*uv_ptr0, _q6, 0); + uv_ptr0 += 32; + _q8 = vsetq_lane_f32(*uv_ptr0, _q8, 0); + uv_ptr0 += 32; + + _q1 = vsetq_lane_f32(*uv_ptr0, _q1, 1); + uv_ptr0 += 32; + _q3 = vsetq_lane_f32(*uv_ptr0, _q3, 1); + uv_ptr0 += 32; + _q5 = vsetq_lane_f32(*uv_ptr0, _q5, 1); + uv_ptr0 += 32; + _q7 = vsetq_lane_f32(*uv_ptr0, _q7, 1); + uv_ptr0 += 32; + _q2 = vsetq_lane_f32(*uv_ptr0, _q2, 1); + uv_ptr0 += 32; + _q4 = vsetq_lane_f32(*uv_ptr0, _q4, 1); + uv_ptr0 += 32; + _q6 = vsetq_lane_f32(*uv_ptr0, _q6, 1); + uv_ptr0 += 32; + _q8 = vsetq_lane_f32(*uv_ptr0, _q8, 1); + uv_ptr0 += 32; + + _q1 = vsetq_lane_f32(*uv_ptr0, _q1, 2); + uv_ptr0 += 32; + _q3 = vsetq_lane_f32(*uv_ptr0, _q3, 2); + uv_ptr0 += 32; + _q5 = vsetq_lane_f32(*uv_ptr0, _q5, 2); + uv_ptr0 += 32; + _q7 = vsetq_lane_f32(*uv_ptr0, _q7, 2); + uv_ptr0 += 32; + _q2 = vsetq_lane_f32(*uv_ptr0, _q2, 2); + uv_ptr0 += 32; + _q4 = vsetq_lane_f32(*uv_ptr0, _q4, 2); + uv_ptr0 += 32; + _q6 = vsetq_lane_f32(*uv_ptr0, _q6, 2); + uv_ptr0 += 32; + _q8 = vsetq_lane_f32(*uv_ptr0, _q8, 2); + uv_ptr0 += 32; + + _q1 = vsetq_lane_f32(*uv_ptr0, _q1, 3); + uv_ptr0 += 32; + _q3 = vsetq_lane_f32(*uv_ptr0, _q3, 3); + uv_ptr0 += 32; + _q5 = vsetq_lane_f32(*uv_ptr0, _q5, 3); + uv_ptr0 += 32; + _q7 = vsetq_lane_f32(*uv_ptr0, _q7, 3); + uv_ptr0 += 32; + _q2 = vsetq_lane_f32(*uv_ptr0, _q2, 3); + uv_ptr0 += 32; + _q4 = vsetq_lane_f32(*uv_ptr0, _q4, 3); + uv_ptr0 += 32; + _q6 = vsetq_lane_f32(*uv_ptr0, _q6, 3); + uv_ptr0 += 32; + _q8 = vsetq_lane_f32(*uv_ptr0, _q8, 3); + uv_ptr0 += 32; + + float32x4_t _q9 = vaddq_f32(_q3, _q5); + float32x4_t _q10 = vaddq_f32(_q7, _q2); + float32x4_t _q11 = vaddq_f32(_q4, _q6); + float32x4_t _q12 = vsubq_f32(_q3, _q5); + float32x4_t _q13 = vsubq_f32(_q7, _q2); + float32x4_t _q14 = vsubq_f32(_q4, _q6); + _q2 = vmulq_lane_f32(_q13, vget_low_f32(_q0), 0); + _q3 = vmulq_lane_f32(_q11, vget_low_f32(_q0), 0); + + float32x4_t _q15 = vaddq_f32(_q1, _q9); + _q15 = vaddq_f32(_q15, _q10); + _q15 = vmlaq_lane_f32(_q15, _q3, vget_high_f32(_q0), 1); + vst1q_f32(at_m_ptr, _q15); + + _q15 = vaddq_f32(_q12, _q2); + _q15 = vmlaq_lane_f32(_q15, _q14, vget_high_f32(_q0), 1); + vst1q_f32(at_m_ptr + 4, _q15); + + _q15 = vmlaq_lane_f32(_q9, _q10, vget_low_f32(_q0), 1); + _q15 = vmlaq_lane_f32(_q15, _q11, vget_high_f32(_q0), 0); + vst1q_f32(at_m_ptr + 8, _q15); + + _q15 = vmlaq_lane_f32(_q12, _q13, vget_high_f32(_q0), 0); + _q15 = vmlaq_lane_f32(_q15, _q14, vget_low_f32(_q0), 1); + vst1q_f32(at_m_ptr + 12, _q15); + + _q15 = vaddq_f32(_q9, _q3); + _q15 = vmlaq_lane_f32(_q15, _q10, vget_high_f32(_q0), 1); + vst1q_f32(at_m_ptr + 16, _q15); + + _q15 = vaddq_f32(_q12, _q8); + _q15 = vaddq_f32(_q15, _q14); + _q15 = vmlaq_lane_f32(_q15, _q2, vget_high_f32(_q0), 1); + vst1q_f32(at_m_ptr + 20, _q15); + + at_m_ptr += 24; + } +#else int steps = 32 * sizeof(float); asm volatile( "vld1.32 {d0-d1}, [%[tm_ptr]] \n" @@ -771,6 +1151,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); +#endif // __aarch64__ float *at_m_ptr0 = at_m; float *at_m_ptr1 = at_m + 24; @@ -782,6 +1163,133 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, float *out_ptr3 = output_tmp + 18; float *out_ptr4 = output_tmp + 24; float *out_ptr5 = output_tmp + 30; +#if __aarch64__ + float32x4_t _q0 = vld1q_f32(transform_matrix); + float32x4x2_t _q23, _q45, _q67, _q89; + _q23.val[0] = vld1q_f32(at_m_ptr0); + _q23.val[0] = vld1q_f32(at_m_ptr0 + 4); + _q45.val[1] = vld1q_f32(at_m_ptr0 + 8); + _q45.val[1] = vld1q_f32(at_m_ptr0 + 12); + _q67.val[0] = vld1q_f32(at_m_ptr1); + _q67.val[0] = vld1q_f32(at_m_ptr1 + 4); + _q89.val[1] = vld1q_f32(at_m_ptr1 + 8); + _q89.val[1] = vld1q_f32(at_m_ptr1 + 12); + _q23 = vtrnq_f32(_q23.val[0], _q23.val[1]); + _q45 = vtrnq_f32(_q45.val[0], _q45.val[1]); + _q67 = vtrnq_f32(_q67.val[0], _q67.val[1]); + _q89 = vtrnq_f32(_q89.val[0], _q89.val[1]); + float32x4_t _q1 = vcombine_f32(vget_low_f32(_q23.val[0]), + vget_low_f32(_q45.val[0])); + float32x4_t _q3 = vcombine_f32(vget_high_f32(_q23.val[0]), + vget_high_f32(_q45.val[0])); + float32x4_t _q2 = vcombine_f32(vget_low_f32(_q23.val[1]), + vget_low_f32(_q45.val[1])); + float32x4_t _q4 = vcombine_f32(vget_high_f32(_q23.val[1]), + vget_high_f32(_q45.val[1])); + float32x4_t _q5 = vcombine_f32(vget_low_f32(_q67.val[0]), + vget_low_f32(_q89.val[0])); + float32x4_t _q7 = vcombine_f32(vget_high_f32(_q67.val[0]), + vget_high_f32(_q89.val[0])); + float32x4_t _q6 = vcombine_f32(vget_low_f32(_q67.val[1]), + vget_low_f32(_q89.val[1])); + float32x4_t _q8 = vcombine_f32(vget_high_f32(_q67.val[1]), + vget_high_f32(_q89.val[1])); + + float32x4_t _q9 = vaddq_f32(_q2, _q3); + float32x4_t _q10 = vaddq_f32(_q4, _q5); + float32x4_t _q11 = vaddq_f32(_q6, _q7); + float32x4_t _q12 = vsubq_f32(_q2, _q3); + float32x4_t _q13 = vsubq_f32(_q4, _q5); + float32x4_t _q14 = vsubq_f32(_q6, _q7); + _q6 = vmulq_lane_f32(_q13, vget_low_f32(_q0), 0); + _q7 = vmulq_lane_f32(_q11, vget_low_f32(_q0), 0); + + _q1 = vaddq_f32(_q1, _q9); + _q1 = vaddq_f32(_q1, _q10); + _q1 = vmlaq_lane_f32(_q1, _q7, vget_high_f32(_q0), 1); + + _q2 = vaddq_f32(_q12, _q6); + _q2 = vmlaq_lane_f32(_q2, _q14, vget_high_f32(_q0), 1); + + _q3 = vmlaq_lane_f32(_q9, _q10, vget_low_f32(_q0), 1); + _q3 = vmlaq_lane_f32(_q3, _q11, vget_high_f32(_q0), 0); + + _q4 = vmlaq_lane_f32(_q12, _q13, vget_high_f32(_q0), 0); + _q4 = vmlaq_lane_f32(_q4, _q14, vget_low_f32(_q0), 1); + + _q23 = vtrnq_f32(_q1, _q2); + _q45 = vtrnq_f32(_q3, _q4); + vst1_f32(out_ptr0, vget_low_f32(_q23.val[0])); + vst1_f32(out_ptr0 + 2, vget_low_f32(_q45.val[0])); + vst1_f32(out_ptr1, vget_low_f32(_q23.val[1])); + vst1_f32(out_ptr1 + 2, vget_low_f32(_q45.val[1])); + vst1_f32(out_ptr2, vget_high_f32(_q23.val[0])); + vst1_f32(out_ptr2 + 2, vget_high_f32(_q45.val[0])); + vst1_f32(out_ptr3, vget_high_f32(_q23.val[1])); + vst1_f32(out_ptr3 + 2, vget_high_f32(_q45.val[1])); + + _q1 = vaddq_f32(_q9, _q7); + _q1 = vmlaq_lane_f32(_q1, _q10, vget_high_f32(_q0), 1); + _q2 = vaddq_f32(_q12, _q8); + _q2 = vmlaq_lane_f32(_q2, _q6, vget_high_f32(_q0), 1); + _q23 = vtrnq_f32(_q1, _q2); + vst1_f32(out_ptr0 + 4, vget_low_f32(_q23.val[0])); + vst1_f32(out_ptr1 + 4, vget_low_f32(_q23.val[1])); + vst1_f32(out_ptr2 + 4, vget_high_f32(_q23.val[0])); + vst1_f32(out_ptr3 + 4, vget_high_f32(_q23.val[1])); + + // remain 2 rows + _q1 = vld1q_f32(at_m_ptr0 + 16); + _q2 = vld1q_f32(at_m_ptr0 + 20); + _q3 = vld1q_f32(at_m_ptr1 + 16); + _q4 = vld1q_f32(at_m_ptr1 + 20); + _q23 = vtrnq_f32(_q1, _q2); + _q45 = vtrnq_f32(_q3, _q4); + + float32x2_t _d2 = vget_low_f32(_q23.val[0]); + float32x2_t _d3 = vget_high_f32(_q23.val[0]); + float32x2_t _d4 = vget_low_f32(_q23.val[1]); + float32x2_t _d5 = vget_high_f32(_q23.val[1]); + float32x2_t _d6 = vget_low_f32(_q45.val[0]); + float32x2_t _d7 = vget_high_f32(_q45.val[0]); + float32x2_t _d8 = vget_low_f32(_q45.val[1]); + float32x2_t _d9 = vget_high_f32(_q45.val[1]); + + float32x2_t _d10 = vadd_f32(_d4, _d3); + float32x2_t _d11 = vadd_f32(_d5, _d6); + float32x2_t _d12 = vadd_f32(_d8, _d7); + float32x2_t _d13 = vsub_f32(_d4, _d3); + float32x2_t _d14 = vsub_f32(_d5, _d6); + float32x2_t _d15 = vsub_f32(_d8, _d7); + float32x2_t _d16 = vmul_lane_f32(_d14, vget_low_f32(_q0), 0); + float32x2_t _d17 = vmul_lane_f32(_d12, vget_low_f32(_q0), 0); + + float32x2_t _d18 = vadd_f32(_d2, _d10); + float32x2_t _d20 = vadd_f32(_d13, _d16); + float32x2_t _d19 = vmla_lane_f32(_d10, _d11, vget_low_f32(_q0), 1); + float32x2_t _d21 = vmla_lane_f32(_d13, _d14, vget_high_f32(_q0), 0); + _d18 = vadd_f32(_d18, _d11); + _d18 = vmla_lane_f32(_d18, _d17, vget_high_f32(_q0), 1); + _d20 = vmla_lane_f32(_d20, _d15, vget_high_f32(_q0), 1); + _d19 = vmla_lane_f32(_d19, _d12, vget_high_f32(_q0), 0); + _d21 = vmla_lane_f32(_d21, _d15, vget_low_f32(_q0), 1); + + float32x2x2_t _d18d20 = vtrn_f32(_d18, _d20); + float32x2x2_t _d19d21 = vtrn_f32(_d19, _d21); + vst1_f32(out_ptr4, _d18d20.val[0]); + vst1_f32(out_ptr4 + 2, _d19d21.val[0]); + vst1_f32(out_ptr5, _d18d20.val[1]); + vst1_f32(out_ptr5 + 2, _d19d21.val[1]); + + _d18 = vadd_f32(_d10, _d17); + _d18 = vmla_lane_f32(_d18, _d11, vget_high_f32(_q0), 1); + _d20 = vadd_f32(_d13, _d9); + _d20 = vadd_f32(_d20, _d15); + _d20 = vmla_lane_f32(_d20, _d16, vget_high_f32(_q0), 1); + _d18d20 = vtrn_f32(_d18, _d20); + vst1_f32(out_ptr4 + 4, _d18); + vst1_f32(out_ptr5 + 4, _d20); +#else asm volatile( "vld1.32 {d0-d1}, [%[tm_ptr]] \n" // process 4 rows @@ -898,6 +1406,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, : [tm_ptr] "r"((float *)transform_matrix) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ size_t offset = (oc * out_h + 6 * tile_h) * out_w + 6 * tile_w; float *out_ptr = output_ptr + offset; int remain_row = out_h - 6 * tile_h; @@ -915,6 +1424,130 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, float *out_ptr3 = out_ptr2 + out_w; float *out_ptr4 = out_ptr3 + out_w; float *out_ptr5 = out_ptr4 + out_w; +#if __aarch64__ + float32x4_t _q0 = vld1q_f32(transform_matrix); + float32x4x2_t _q23, _q45, _q67, _q89; + _q23.val[0] = vld1q_f32(at_m_ptr0); + _q23.val[0] = vld1q_f32(at_m_ptr0 + 4); + _q45.val[1] = vld1q_f32(at_m_ptr0 + 8); + _q45.val[1] = vld1q_f32(at_m_ptr0 + 12); + _q67.val[0] = vld1q_f32(at_m_ptr1); + _q67.val[0] = vld1q_f32(at_m_ptr1 + 4); + _q89.val[1] = vld1q_f32(at_m_ptr1 + 8); + _q89.val[1] = vld1q_f32(at_m_ptr1 + 12); + _q23 = vtrnq_f32(_q23.val[0], _q23.val[1]); + _q45 = vtrnq_f32(_q45.val[0], _q45.val[1]); + _q67 = vtrnq_f32(_q67.val[0], _q67.val[1]); + _q89 = vtrnq_f32(_q89.val[0], _q89.val[1]); + float32x4_t _q1 = vcombine_f32(vget_low_f32(_q23.val[0]), + vget_low_f32(_q45.val[0])); + float32x4_t _q3 = vcombine_f32(vget_high_f32(_q23.val[0]), + vget_high_f32(_q45.val[0])); + float32x4_t _q2 = vcombine_f32(vget_low_f32(_q23.val[1]), + vget_low_f32(_q45.val[1])); + float32x4_t _q4 = vcombine_f32(vget_high_f32(_q23.val[1]), + vget_high_f32(_q45.val[1])); + float32x4_t _q5 = vcombine_f32(vget_low_f32(_q67.val[0]), + vget_low_f32(_q89.val[0])); + float32x4_t _q7 = vcombine_f32(vget_high_f32(_q67.val[0]), + vget_high_f32(_q89.val[0])); + float32x4_t _q6 = vcombine_f32(vget_low_f32(_q67.val[1]), + vget_low_f32(_q89.val[1])); + float32x4_t _q8 = vcombine_f32(vget_high_f32(_q67.val[1]), + vget_high_f32(_q89.val[1])); + + float32x4_t _q9 = vaddq_f32(_q2, _q3); + float32x4_t _q10 = vaddq_f32(_q4, _q5); + float32x4_t _q11 = vaddq_f32(_q6, _q7); + float32x4_t _q12 = vsubq_f32(_q2, _q3); + float32x4_t _q13 = vsubq_f32(_q4, _q5); + float32x4_t _q14 = vsubq_f32(_q6, _q7); + _q6 = vmulq_lane_f32(_q13, vget_low_f32(_q0), 0); + _q7 = vmulq_lane_f32(_q11, vget_low_f32(_q0), 0); + + _q1 = vaddq_f32(_q1, _q9); + _q1 = vaddq_f32(_q1, _q10); + _q1 = vmlaq_lane_f32(_q1, _q7, vget_high_f32(_q0), 1); + _q2 = vaddq_f32(_q12, _q6); + _q2 = vmlaq_lane_f32(_q2, _q14, vget_high_f32(_q0), 1); + _q3 = vmlaq_lane_f32(_q9, _q10, vget_low_f32(_q0), 1); + _q3 = vmlaq_lane_f32(_q3, _q11, vget_high_f32(_q0), 0); + _q4 = vmlaq_lane_f32(_q12, _q13, vget_high_f32(_q0), 0); + _q4 = vmlaq_lane_f32(_q4, _q14, vget_low_f32(_q0), 1); + + _q23 = vtrnq_f32(_q1, _q2); + _q45 = vtrnq_f32(_q3, _q4); + vst1_f32(out_ptr0, vget_low_f32(_q23.val[0])); + vst1_f32(out_ptr0 + 2, vget_low_f32(_q45.val[0])); + vst1_f32(out_ptr1, vget_low_f32(_q23.val[1])); + vst1_f32(out_ptr1 + 2, vget_low_f32(_q45.val[1])); + vst1_f32(out_ptr2, vget_high_f32(_q23.val[0])); + vst1_f32(out_ptr2 + 2, vget_high_f32(_q45.val[0])); + vst1_f32(out_ptr3, vget_high_f32(_q23.val[1])); + vst1_f32(out_ptr3 + 2, vget_high_f32(_q45.val[1])); + + _q1 = vaddq_f32(_q9, _q7); + _q1 = vmlaq_lane_f32(_q1, _q10, vget_high_f32(_q0), 1); + _q2 = vaddq_f32(_q12, _q8); + _q2 = vmlaq_lane_f32(_q2, _q6, vget_high_f32(_q0), 1); + _q23 = vtrnq_f32(_q1, _q2); + vst1_f32(out_ptr0 + 4, vget_low_f32(_q23.val[0])); + vst1_f32(out_ptr1 + 4, vget_low_f32(_q23.val[1])); + vst1_f32(out_ptr2 + 4, vget_high_f32(_q23.val[0])); + vst1_f32(out_ptr3 + 4, vget_high_f32(_q23.val[1])); + + // remain 2 rows + _q1 = vld1q_f32(at_m_ptr0 + 16); + _q2 = vld1q_f32(at_m_ptr0 + 20); + _q3 = vld1q_f32(at_m_ptr1 + 16); + _q4 = vld1q_f32(at_m_ptr1 + 20); + _q23 = vtrnq_f32(_q1, _q2); + _q45 = vtrnq_f32(_q3, _q4); + + float32x2_t _d2 = vget_low_f32(_q23.val[0]); + float32x2_t _d3 = vget_high_f32(_q23.val[0]); + float32x2_t _d4 = vget_low_f32(_q23.val[1]); + float32x2_t _d5 = vget_high_f32(_q23.val[1]); + float32x2_t _d6 = vget_low_f32(_q45.val[0]); + float32x2_t _d7 = vget_high_f32(_q45.val[0]); + float32x2_t _d8 = vget_low_f32(_q45.val[1]); + float32x2_t _d9 = vget_high_f32(_q45.val[1]); + + float32x2_t _d10 = vadd_f32(_d4, _d3); + float32x2_t _d11 = vadd_f32(_d5, _d6); + float32x2_t _d12 = vadd_f32(_d8, _d7); + float32x2_t _d13 = vsub_f32(_d4, _d3); + float32x2_t _d14 = vsub_f32(_d5, _d6); + float32x2_t _d15 = vsub_f32(_d8, _d7); + float32x2_t _d16 = vmul_lane_f32(_d14, vget_low_f32(_q0), 0); + float32x2_t _d17 = vmul_lane_f32(_d12, vget_low_f32(_q0), 0); + + float32x2_t _d18 = vadd_f32(_d2, _d10); + float32x2_t _d20 = vadd_f32(_d13, _d16); + float32x2_t _d19 = vmla_lane_f32(_d10, _d11, vget_low_f32(_q0), 1); + float32x2_t _d21 = vmla_lane_f32(_d13, _d14, vget_high_f32(_q0), 0); + _d18 = vadd_f32(_d18, _d11); + _d18 = vmla_lane_f32(_d18, _d17, vget_high_f32(_q0), 1); + _d20 = vmla_lane_f32(_d20, _d15, vget_high_f32(_q0), 1); + _d19 = vmla_lane_f32(_d19, _d12, vget_high_f32(_q0), 0); + _d21 = vmla_lane_f32(_d21, _d15, vget_low_f32(_q0), 1); + + float32x2x2_t _d18d20 = vtrn_f32(_d18, _d20); + float32x2x2_t _d19d21 = vtrn_f32(_d19, _d21); + vst1_f32(out_ptr4, _d18d20.val[0]); + vst1_f32(out_ptr4 + 2, _d19d21.val[0]); + vst1_f32(out_ptr5, _d18d20.val[1]); + vst1_f32(out_ptr5 + 2, _d19d21.val[1]); + + _d18 = vadd_f32(_d10, _d17); + _d18 = vmla_lane_f32(_d18, _d11, vget_high_f32(_q0), 1); + _d20 = vadd_f32(_d13, _d9); + _d20 = vadd_f32(_d20, _d15); + _d20 = vmla_lane_f32(_d20, _d16, vget_high_f32(_q0), 1); + _d18d20 = vtrn_f32(_d18, _d20); + vst1_f32(out_ptr4 + 4, _d18); + vst1_f32(out_ptr5 + 4, _d20); +#else asm volatile( "vld1.32 {d0-d1}, [%[tm_ptr]] \n" // process 4 rows @@ -1031,6 +1664,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, : [tm_ptr] "r"((float *)transform_matrix) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ } } } @@ -1041,5 +1675,5 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, } // namespace operators } // namespace paddle_mobile -#endif // __aarch64__ #endif // CONV_OP +#endif // __ARM_NEON__ diff --git a/src/operators/math/winograd/winograd_transform_f6k3_arm64.cpp b/src/operators/math/winograd/winograd_transform_f6k3_arm64.cpp deleted file mode 100644 index 5ef9c194f23ba28791f673137313aacc262d39d5..0000000000000000000000000000000000000000 --- a/src/operators/math/winograd/winograd_transform_f6k3_arm64.cpp +++ /dev/null @@ -1,413 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -// We refer https://github.com/andravin/wincnn to access the winograd transform -// matrixs - -#ifdef CONV_OP -#ifdef __aarch64__ - -#include "operators/math/winograd/winograd_transform.h" - -namespace paddle_mobile { -namespace operators { -namespace math { - -template <> -void winograd_transform_weight<8, 3>(const framework::Tensor &weight, - framework::Tensor *output) { - // weight shape is [out_channel, in_channel, kernel_h, kernel_w] - int out_channel = weight.dims()[0]; - int in_channel = weight.dims()[1]; - // reshape and alloc transformed weight - framework::DDim transformed_shape = - framework::make_ddim(std::vector{out_channel, in_channel, 64}); - float *outptr = output->mutable_data(transformed_shape); - const float *inptr = weight.data(); - for (int oc = 0; oc < out_channel; ++oc) { - for (int ic = 0; ic < in_channel; ++ic) { - size_t offset = oc * in_channel + ic; - float *kout = outptr + offset * 64; - const float *k = inptr + offset * 9; - - float gw[3][8]; - for (int i = 0; i < 3; ++i, k += 3) { - float g0 = k[0]; - float g1 = k[1]; - float g2 = k[2]; - float d0 = g0 + g2; - float d1 = g0 + 4 * g2; - float d2 = g2 + 4 * g0; - float d3 = 2 * g1; - gw[i][0] = g0; - gw[i][1] = -2.f / 9 * (d0 + g1); // -2.f/9 * (g0 + g1 + g2) - gw[i][2] = -2.f / 9 * (d0 - g1); // -2.f/9 * (g0 - g1 + g2) - gw[i][3] = 1.f / 90 * (d1 + d3); // 1.f/90 * (g0 + 2 * g1 + 4 * g2) - gw[i][4] = 1.f / 90 * (d1 - d3); // 1.f/90 * (g0 - 2 * g1 + 4 * g2) - gw[i][5] = 1.f / 180 * (d2 + d3); // 1.f/180 * (4 * g0 + 2 * g1 + g2) - gw[i][6] = 1.f / 180 * (d2 - d3); // 1.f/180 * (4 * g0 - 2 * g1 + g2) - gw[i][7] = g2; - } - for (int i = 0; i < 8; ++i, kout += 8) { - float g0 = gw[0][i]; - float g1 = gw[1][i]; - float g2 = gw[2][i]; - float d0 = g0 + g2; - float d1 = g0 + 4 * g2; - float d2 = g2 + 4 * g0; - float d3 = 2 * g1; - kout[0] = g0; - kout[1] = -2.f / 9 * (d0 + g1); // -2.f/9 * (k0 + k1 + k2) - kout[2] = -2.f / 9 * (d0 - g1); // -2.f/9 * (k0 - k1 + k2) - kout[3] = 1.f / 90 * (d1 + d3); // 1.f/90 * (k0 + 2 * k1 + 4 * k2) - kout[4] = 1.f / 90 * (d1 - d3); // 1.f/90 * (k0 - 2 * k1 + 4 * k2) - kout[5] = 1.f / 180 * (d2 + d3); // 8.f/45 * (4 * k0 + 2 * k1 + k2) - kout[6] = 1.f / 180 * (d2 - d3); // 8.f/45 * (4 * k0 - 2 * k1 + k2) - kout[7] = g2; - } - } - } -} - -template <> -void winograd_transform_input<8, 3>(const framework::Tensor &input, - framework::Tensor *output) { - // tile input to [c, roundup(h/6), roundup(w/6), 64] and do transformation - int channel = input.dims()[1]; - int height = input.dims()[2]; - int width = input.dims()[3]; - int h_tiles = (height + 3) / 6; // (height + 5 - 2) / 6 - int w_tiles = (width + 3) / 6; // (width + 5 - 2) / 6 - framework::DDim transformed_shape = - framework::make_ddim(std::vector{channel, h_tiles, w_tiles, 64}); - float *outptr = output->mutable_data(transformed_shape); - memset(outptr, 0, channel * h_tiles * w_tiles * 64 * sizeof(float)); - const float *inptr = input.data(); - // pack input to tiles - for (int c = 0; c < channel; ++c) { - int inter_h = (height - 2) / 6; - int inter_w = (width - 2) / 6; - int remain_h = height - (inter_h * 6); - int remain_w = width - (inter_w * 6); - const float *in0 = inptr + c * height * width; - const float *in1 = in0 + width; - const float *in2 = in1 + width; - const float *in3 = in2 + width; - const float *in4 = in3 + width; - const float *in5 = in4 + width; - const float *in6 = in5 + width; - const float *in7 = in6 + width; - float *out = outptr + c * h_tiles * w_tiles * 64; - - for (int h = 0; h < inter_h; ++h) { - for (int w = 0; w < inter_w; ++w) { - memcpy(out, in0, 8 * sizeof(float)); - memcpy(out + 8, in1, 8 * sizeof(float)); - memcpy(out + 16, in2, 8 * sizeof(float)); - memcpy(out + 24, in3, 8 * sizeof(float)); - memcpy(out + 32, in4, 8 * sizeof(float)); - memcpy(out + 40, in5, 8 * sizeof(float)); - memcpy(out + 48, in6, 8 * sizeof(float)); - memcpy(out + 56, in7, 8 * sizeof(float)); - in0 += 6; - in1 += 6; - in2 += 6; - in3 += 6; - in4 += 6; - in5 += 6; - in6 += 6; - in7 += 6; - out += 64; - } - // remain width - if (remain_w > 2) { - memcpy(out, in0, remain_w * sizeof(float)); - memcpy(out + 8, in1, remain_w * sizeof(float)); - memcpy(out + 16, in2, remain_w * sizeof(float)); - memcpy(out + 24, in3, remain_w * sizeof(float)); - memcpy(out + 32, in4, remain_w * sizeof(float)); - memcpy(out + 40, in5, remain_w * sizeof(float)); - memcpy(out + 48, in6, remain_w * sizeof(float)); - memcpy(out + 56, in7, remain_w * sizeof(float)); - out += 64; - } - in0 += 5 * width + remain_w; - in1 += 5 * width + remain_w; - in2 += 5 * width + remain_w; - in3 += 5 * width + remain_w; - in4 += 5 * width + remain_w; - in5 += 5 * width + remain_w; - in6 += 5 * width + remain_w; - in7 += 5 * width + remain_w; - } - // remain height - if (remain_h > 2) { - for (int w = 0; w < inter_w; ++w) { - for (int rh = 0; rh < remain_h; ++rh) { - memcpy(out + rh * 8, in0 + rh * width, 8 * sizeof(float)); - } - out += 64; - in0 += 6; - } - // remain width - if (remain_w > 2) { - for (int rh = 0; rh < remain_h; ++rh) { - memcpy(out + rh * 8, in0 + rh * width, remain_w * sizeof(float)); - } - } - } - } - // transform tiles, compute B_T * d(c, b) * B - for (int c = 0; c < channel; ++c) { - for (int tile = 0; tile < h_tiles * w_tiles; ++tile) { - float *out = outptr + (c * h_tiles * w_tiles + tile) * 64; - // compute B_T * d(c, b) - float bd[8][8]; - for (int i = 0; i < 8; ++i) { - float d0 = out[8 * i + 0]; - float d1 = out[8 * i + 1]; - float d2 = out[8 * i + 2]; - float d3 = out[8 * i + 3]; - float d4 = out[8 * i + 4]; - float d5 = out[8 * i + 5]; - float d6 = out[8 * i + 6]; - float d7 = out[8 * i + 7]; - - bd[i][0] = d0 - d6 + (d4 - d2) * 5.25; - float v1 = d2 - 4.25 * d4 + d6; - float v2 = d1 - 4.25 * d3 + d5; - // d1 + d2 - 4.25 * d3 - 4.25 * d4 + d5 + d6 - bd[i][1] = v1 + v2; - // -d1 + d2 + 4.25 * d3 - 4.25 * d4 - d5 + d6 - bd[i][2] = v1 - v2; - v1 = 0.25 * d2 - 1.25 * d4 + d6; - v2 = 0.5 * d1 - 2.5 * d3 + 2 * d5; - // 0.5 * d1 + 0.25 * d2 - 2.5 * d3 - 1.25 * d4 + 2 * d5 + d6 - bd[i][3] = v1 + v2; - // -0.5 * d1 + 0.25 * d2 + 2.5 * d3 - 1.25 * d4 - 2 * d5 + d6 - bd[i][4] = v1 - v2; - v1 = 4 * d2 - 5 * d4 + d6; - v2 = 2 * d1 - 2.5 * d3 + 0.5 * d5; - // 2 * d1 + 4 * d2 - 2.5 * d3 - 5 * d4 + 0.5 * d5 + d6 - bd[i][5] = v1 + v2; - // -2 * d1 + 4 * d2 + 2.5 * d3 - 5 * d4 - 0.5 * d5 + d6 - bd[i][6] = v1 - v2; - bd[i][7] = d7 - d1 + (d3 - d5) * 5.25; - } - // compute B_T * d(c, b) * B - for (int i = 0; i < 8; ++i, out += 8) { - float d0 = bd[0][i]; - float d1 = bd[1][i]; - float d2 = bd[2][i]; - float d3 = bd[3][i]; - float d4 = bd[4][i]; - float d5 = bd[5][i]; - float d6 = bd[6][i]; - float d7 = bd[7][i]; - - out[0] = d0 - d6 + (d4 - d2) * 5.25; - float v1 = d2 - 4.25 * d4 + d6; - float v2 = d1 - 4.25 * d3 + d5; - // d1 + d2 - 4.25 * d3 - 4.25 * d4 + d5 + d6 - out[1] = v1 + v2; - // -d1 + d2 + 4.25 * d3 - 4.25 * d4 - d5 + d6 - out[2] = v1 - v2; - v1 = 0.25 * d2 - 1.25 * d4 + d6; - v2 = 0.5 * d1 - 2.5 * d3 + 2 * d5; - // 0.5 * d1 + 0.25 * d2 - 2.5 * d3 - 1.25 * d4 + 2 * d5 + d6 - out[3] = v1 + v2; - // -0.5 * d1 + 0.25 * d2 + 2.5 * d3 - 1.25 * d4 - 2 * d5 + d6 - out[4] = v1 - v2; - v1 = 4 * d2 - 5 * d4 + d6; - v2 = 2 * d1 - 2.5 * d3 + 0.5 * d5; - // 2 * d1 + 4 * d2 - 2.5 * d3 - 5 * d4 + 0.5 * d5 + d6 - out[5] = v1 + v2; - // -2 * d1 + 4 * d2 + 2.5 * d3 - 5 * d4 - 0.5 * d5 + d6 - out[6] = v1 - v2; - out[7] = d7 - d1 + (d3 - d5) * 5.25; - } - } - } -} - -template <> -void winograd_transform_output<8, 3>(const framework::Tensor &input, - const framework::Tensor &weight, - framework::Tensor *output) { - // input shape is [in_channel, h_tiles, w_tiles, 64] - // weight shape is [out_channel, in_channel, 64] - int in_channel = input.dims()[0]; - int h_tiles = input.dims()[1]; - int w_tiles = input.dims()[2]; - int tiles = h_tiles * w_tiles; - int out_channel = weight.dims()[0]; - // compute U*V first - framework::Tensor output_m; - framework::DDim shape = - framework::make_ddim(std::vector{out_channel, tiles, 64}); - float *output_m_ptr = output_m.mutable_data(shape); - memset(output_m_ptr, 0, output_m.numel() * sizeof(float)); - const float *input_ptr = input.data(); - const float *weight_ptr = weight.data(); - for (int i = 0; i < out_channel; ++i) { - for (int j = 0; j < tiles; ++j) { - const float *w_ptr = weight_ptr + i * in_channel * 64; - const float *in_ptr = input_ptr + j * 64; - float *m_ptr = output_m_ptr + (i * tiles + j) * 64; - for (int c = 0; c < in_channel; ++c) { - for (int k = 0; k < 64; ++k) { - m_ptr[k] += w_ptr[k] * in_ptr[k]; - } - w_ptr += 64; - in_ptr += tiles * 64; - } - } - } - - for (int oc = 0; oc < out_channel; ++oc) { - for (int tile = 0; tile < tiles; ++tile) { - float *m = output_m_ptr + (oc * tiles + tile) * 64; - // compute A_T * m - float am[6][8]; - for (int i = 0; i < 8; ++i) { - float d0 = m[i * 8 + 0]; - float d1 = m[i * 8 + 1]; - float d2 = m[i * 8 + 2]; - float d3 = m[i * 8 + 3]; - float d4 = m[i * 8 + 4]; - float d5 = m[i * 8 + 5]; - float d6 = m[i * 8 + 6]; - float d7 = m[i * 8 + 7]; - float v0 = d1 + d2; - float v1 = d1 - d2; - float v2 = d3 + d4; - float v3 = d3 - d4; - float v4 = d5 + d6; - float v5 = d5 - d6; - - am[0][i] = d0 + v0 + v2 + 32 * v4; - am[1][i] = v1 + 2 * v3 + 16 * v5; - am[2][i] = v0 + 4 * v2 + 8 * v4; - am[3][i] = v1 + 8 * v3 + 4 * v5; - am[4][i] = v0 + 16 * v2 + 2 * v4; - am[5][i] = v1 + 32 * v3 + v5 + d7; - } - // compute A_T * m * A - for (int i = 0; i < 6; ++i, m += 8) { - float d0 = am[i][0]; - float d1 = am[i][1]; - float d2 = am[i][2]; - float d3 = am[i][3]; - float d4 = am[i][4]; - float d5 = am[i][5]; - float d6 = am[i][6]; - float d7 = am[i][7]; - float v0 = d1 + d2; - float v1 = d1 - d2; - float v2 = d3 + d4; - float v3 = d3 - d4; - float v4 = d5 + d6; - float v5 = d5 - d6; - - m[0] = d0 + v0 + v2 + 32 * v4; - m[1] = v1 + 2 * v3 + 16 * v5; - m[2] = v0 + 4 * v2 + 8 * v4; - m[3] = v1 + 8 * v3 + 4 * v5; - m[4] = v0 + 16 * v2 + 2 * v4; - m[5] = v1 + 32 * v3 + v5 + d7; - } - } - } - - int out_h = output->dims()[2]; - int out_w = output->dims()[3]; - float *output_ptr = output->mutable_data(); - // copy valid region to final output - for (int oc = 0; oc < out_channel; ++oc) { - int inter_h = out_h / 6; - int inter_w = out_w / 6; - int remain_h = out_h - inter_h * 6; - int remain_w = out_w - inter_w * 6; - - float *out_ptr0 = output_ptr + oc * out_h * out_w; - float *out_ptr1 = out_ptr0 + out_w; - float *out_ptr2 = out_ptr1 + out_w; - float *out_ptr3 = out_ptr2 + out_w; - float *out_ptr4 = out_ptr3 + out_w; - float *out_ptr5 = out_ptr4 + out_w; - const float *m_ptr = output_m_ptr + oc * tiles * 64; - for (int tile_h = 0; tile_h < inter_h; ++tile_h) { - for (int tile_w = 0; tile_w < inter_w; ++tile_w) { - const float *m = m_ptr + (tile_h * w_tiles + tile_w) * 64; - memcpy(out_ptr0, m, 6 * sizeof(float)); - memcpy(out_ptr1, m + 8, 6 * sizeof(float)); - memcpy(out_ptr2, m + 16, 6 * sizeof(float)); - memcpy(out_ptr3, m + 24, 6 * sizeof(float)); - memcpy(out_ptr4, m + 32, 6 * sizeof(float)); - memcpy(out_ptr5, m + 40, 6 * sizeof(float)); - out_ptr0 += 6; - out_ptr1 += 6; - out_ptr2 += 6; - out_ptr3 += 6; - out_ptr4 += 6; - out_ptr5 += 6; - } - // remain w - if (remain_w > 0) { - const float *m = m_ptr + (tile_h * w_tiles + inter_w) * 64; - memcpy(out_ptr0, m, remain_w * sizeof(float)); - memcpy(out_ptr1, m + 8, remain_w * sizeof(float)); - memcpy(out_ptr2, m + 16, remain_w * sizeof(float)); - memcpy(out_ptr3, m + 24, remain_w * sizeof(float)); - memcpy(out_ptr4, m + 32, remain_w * sizeof(float)); - memcpy(out_ptr5, m + 40, remain_w * sizeof(float)); - out_ptr0 += remain_w; - out_ptr1 += remain_w; - out_ptr2 += remain_w; - out_ptr3 += remain_w; - out_ptr4 += remain_w; - out_ptr5 += remain_w; - } - out_ptr0 += 5 * out_w; - out_ptr1 += 5 * out_w; - out_ptr2 += 5 * out_w; - out_ptr3 += 5 * out_w; - out_ptr4 += 5 * out_w; - out_ptr5 += 5 * out_w; - } - // remain h - if (remain_h > 0) { - for (int tile_w = 0; tile_w < inter_w; ++tile_w) { - const float *m = m_ptr + (inter_h * w_tiles + tile_w) * 64; - for (int rh = 0; rh < remain_h; ++rh) { - memcpy(out_ptr0 + rh * out_w, m + rh * 8, 6 * sizeof(float)); - } - out_ptr0 += 6; - } - if (remain_w > 0) { - const float *m = m_ptr + (inter_h * w_tiles + inter_w) * 64; - for (int rh = 0; rh < remain_h; ++rh) { - memcpy(out_ptr0 + rh * out_w, m + rh * 8, remain_w * sizeof(float)); - } - } - } - } -} - -} // namespace math -} // namespace operators -} // namespace paddle_mobile - -#endif // __aarch64__ -#endif // CONV_OP