diff --git a/src/common/types.h b/src/common/types.h index e3c65a74ba459ca6e76e5bab2b30b09fb77a791f..ee2ff998ecd3dd9af5ec78b35b595f5552f9c042 100644 --- a/src/common/types.h +++ b/src/common/types.h @@ -145,6 +145,18 @@ struct PaddleMobileConfigInternal { std::string model_obfuscate_key = ""; }; +enum ARMArch { + APPLE = 0, + A53 = 53, + A55 = 55, + A57 = 57, + A72 = 72, + A73 = 73, + A75 = 75, + A76 = 76, + ARM_UNKOWN = -1 +}; + extern const char *G_OP_TYPE_CONV; extern const char *G_OP_TYPE_BATCHNORM; extern const char *G_OP_TYPE_BOX_CODER; diff --git a/src/framework/context.cpp b/src/framework/context.cpp index c7319ba02cd8e9516201bff4ea12da3224d7b06e..66ac0d088a945f5d035967ae6739e5441cd2088f 100644 --- a/src/framework/context.cpp +++ b/src/framework/context.cpp @@ -261,7 +261,8 @@ int set_sched_affinity(const std::vector &cpu_ids) { return 0; } -int get_cpu_info_by_name(int *cpu_num, std::vector *big_core_ids, +int get_cpu_info_by_name(int *cpu_num, ARMArch *arch, + std::vector *big_core_ids, std::vector *little_core_ids, std::vector *l1_cache_sizes, std::vector *l2_cache_sizes, @@ -270,6 +271,7 @@ int get_cpu_info_by_name(int *cpu_num, std::vector *big_core_ids, /* Snapdragon */ if (hardware_name.find("SDM845") != std::string::npos) { // 845 *cpu_num = 8; + *arch = A75; *big_core_ids = {4, 5, 6, 7}; *little_core_ids = {0, 1, 2, 3}; l1_cache_sizes->resize(*cpu_num); @@ -282,6 +284,7 @@ int get_cpu_info_by_name(int *cpu_num, std::vector *big_core_ids, return 0; } else if (hardware_name.find("SDM710") != std::string::npos) { // 710 *cpu_num = 8; + *arch = A75; *big_core_ids = {6, 7}; *little_core_ids = {0, 1, 2, 3, 4, 5}; l1_cache_sizes->resize(*cpu_num); @@ -295,6 +298,7 @@ int get_cpu_info_by_name(int *cpu_num, std::vector *big_core_ids, return 0; } else if (hardware_name.find("MSM8998") != std::string::npos) { // 835 *cpu_num = 8; + *arch = A73; *big_core_ids = {4, 5, 6, 7}; *little_core_ids = {0, 1, 2, 3}; l1_cache_sizes->resize(*cpu_num); @@ -313,8 +317,9 @@ int get_cpu_info_by_name(int *cpu_num, std::vector *big_core_ids, return 0; } else if (hardware_name.find("MSM8976") != std::string::npos) { // 652,653 *cpu_num = 8; - *big_core_ids = {0, 1, 2, 3, 4, 5, 6, 7}; - *little_core_ids = {}; + *arch = A72; + *big_core_ids = {4, 5, 6, 7}; + *little_core_ids = {0, 1, 2, 3}; l1_cache_sizes->resize(*cpu_num); l2_cache_sizes->resize(*cpu_num); l3_cache_sizes->resize(*cpu_num); @@ -322,6 +327,42 @@ int get_cpu_info_by_name(int *cpu_num, std::vector *big_core_ids, fill_cpu_cache_size(l2_cache_sizes, 1024 * 1024); fill_cpu_cache_size(l3_cache_sizes, 0); return 0; + } else if (hardware_name.find("SDM660") != std::string::npos || + hardware_name.find("SDM636") != std::string::npos) { // 660, 636 + *cpu_num = 8; + *arch = A73; + *big_core_ids = {4, 5, 6, 7}; + *little_core_ids = {0, 1, 2, 3}; + l1_cache_sizes->resize(*cpu_num); + l2_cache_sizes->resize(*cpu_num); + l3_cache_sizes->resize(*cpu_num); + fill_cpu_cache_size(l1_cache_sizes, 64 * 1024); + fill_cpu_cache_size(l2_cache_sizes, 1024 * 1024); + fill_cpu_cache_size(l3_cache_sizes, 0); + return 0; + + /* MediaTek */ + } else if (hardware_name.find("MT6799") != std::string::npos) { // X30 + *cpu_num = 10; + *arch = A73; + *big_core_ids = {8, 9}; + *little_core_ids = {0, 1, 2, 3, 4, 5, 6, 7}; + return 0; + } else if (hardware_name.find("MT6771") != std::string::npos) { // P60 + *cpu_num = 8; + *arch = A73; + *big_core_ids = {4, 5, 6, 7}; + *little_core_ids = {0, 1, 2, 3}; + return 0; + + /* Kirin */ + } else if (hardware_name.find("KIRIN970") 
!= + std::string::npos) { // Kirin 970 + *cpu_num = 8; + *arch = A73; + *big_core_ids = {4, 5, 6, 7}; + *little_core_ids = {0, 1, 2, 3}; + return 0; } return -1; } @@ -410,7 +451,7 @@ CPUContext::CPUContext() { // probe cpu info, and set big&litte clusters, L1, L2 and L3 cache sizes std::string cpu_name = get_cpu_name(); bool failed = - get_cpu_info_by_name(&_cpu_num, &_big_core_ids, &_little_core_ids, + get_cpu_info_by_name(&_cpu_num, &_arch, &_big_core_ids, &_little_core_ids, &_l1_cache_sizes, &_l2_cache_sizes, &_l3_cache_sizes, cpu_name) != 0; if (failed) { diff --git a/src/framework/context.h b/src/framework/context.h index 4efab6c3a0427e18ee404b7a0ac1d158e26aaa7f..b3164e5f80b86650d8030bec989f2db7583e4bc3 100644 --- a/src/framework/context.h +++ b/src/framework/context.h @@ -43,12 +43,14 @@ struct CPUContext { int get_thread_num(); PowerMode get_power_mode() const { return _power_mode; } int get_cache_size(int level); + ARMArch get_arch() const { return _arch; } int get_l1_cache_size() { return get_cache_size(1); } int get_l2_cache_size() { return get_cache_size(2); } int get_l3_cache_size() { return get_cache_size(3); } void* get_work_space(int size_in_byte); int _cpu_num; + ARMArch _arch; PowerMode _power_mode; std::vector _big_core_ids; std::vector _little_core_ids; diff --git a/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp index fc6e6b03172021f0cca9d5c3d97ef79f54a150b6..a40e8c4a2bb78b34acbeb1004e24bd315fd4d053 100644 --- a/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp @@ -126,6 +126,9 @@ void ConvAddBNReluKernel::Compute( case ConvParam::EXEC_GEMM_FLOAT: GemmConv(param); break; + case ConvParam::EXEC_GEMM1x1s1_FLOAT: + GemmConv1x1s1(param); + break; case ConvParam::EXEC_SLIDINGWINDOW3x3S1_FLOAT: case ConvParam::EXEC_SLIDINGWINDOW3x3S2_FLOAT: SlidingwindowConv3x3(param); diff --git a/src/operators/kernel/arm/convolution/conv_add_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_kernel.cpp index b16c446592e184ee27ead987e5703cc5b7a10aeb..20474a904f554a2b138053835992918b3abe914b 100644 --- a/src/operators/kernel/arm/convolution/conv_add_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_add_kernel.cpp @@ -44,6 +44,9 @@ void ConvAddKernel::Compute(const FusionConvAddParam ¶m) { case ConvParam::EXEC_GEMM_FLOAT: GemmConv(param); break; + case ConvParam::EXEC_GEMM1x1s1_FLOAT: + GemmConv1x1s1(param); + break; case ConvParam::EXEC_SLIDINGWINDOW3x3S1_FLOAT: case ConvParam::EXEC_SLIDINGWINDOW3x3S2_FLOAT: SlidingwindowConv3x3(param); diff --git a/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp index 24ed34351d59bf4e6f555f9e4b888608cd7880e8..bfdd58e944d28aec6292cd20cf11891fc9449a15 100644 --- a/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp @@ -45,6 +45,9 @@ void ConvAddReluKernel::Compute( case ConvParam::EXEC_GEMM_FLOAT: GemmConv(param); break; + case ConvParam::EXEC_GEMM1x1s1_FLOAT: + GemmConv1x1s1(param); + break; case ConvParam::EXEC_SLIDINGWINDOW3x3S1_FLOAT: case ConvParam::EXEC_SLIDINGWINDOW3x3S2_FLOAT: SlidingwindowConv3x3(param); diff --git a/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp index 
509bfd4df456116fb9fe27762978105dcc42f54b..6df3cb7f1bd19b7551481a0c7fe312648bc454b2 100644
--- a/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp
+++ b/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp
@@ -64,6 +64,9 @@ void ConvBNAddReluKernel<CPU, float>::Compute(
     case ConvParam<CPU>::EXEC_GEMM_FLOAT:
       GemmConv<float, float>(param);
       break;
+    case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
+      GemmConv1x1s1<float, float>(param);
+      break;
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
       SlidingwindowConv3x3<float, float>(param);
diff --git a/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp
index 9a1b9e199a89c4bc9fcd195d2069808e754d16fb..2f0387125f98b3b67aabc62582375fefa2a105b9 100644
--- a/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp
+++ b/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp
@@ -77,6 +77,9 @@ void ConvBNReluKernel<CPU, float>::Compute(
     case ConvParam<CPU>::EXEC_GEMM_FLOAT:
       GemmConv<float, float>(param);
       break;
+    case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
+      GemmConv1x1s1<float, float>(param);
+      break;
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
       SlidingwindowConv3x3<float, float>(param);
diff --git a/src/operators/kernel/arm/convolution/conv_common.cpp b/src/operators/kernel/arm/convolution/conv_common.cpp
index e403f51357e2c46bf1e59be92c54778a5abfa595..c0906e23a39040c76729756f05bcad9b7bdd4b07 100644
--- a/src/operators/kernel/arm/convolution/conv_common.cpp
+++ b/src/operators/kernel/arm/convolution/conv_common.cpp
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "operators/kernel/arm/convolution/conv_common.h"
+#include "framework/context.h"
+#include "operators/math/gemm/gemm1x1s1.h"
 #include "operators/math/slidingwindow_utils.h"
 #include "operators/math/winograd/winograd_transform.h"
@@ -20,6 +22,8 @@ namespace paddle_mobile {
 namespace operators {
 
 void InitBaseConvKernel(ConvParam<CPU> *param) {
+  bool conv1x1 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
+                 param->Filter()->dims()[2] == 1;
   bool conv3x3 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
                  param->Filter()->dims()[2] == 3;
   bool conv5x5 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
@@ -83,6 +87,22 @@ void InitBaseConvKernel(ConvParam<CPU> *param) {
     math::slidingwindow_transform_weight<float>(*param->Filter(),
                                                 param->transformed_filter_);
     param->ExecMode() = ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT;
+  } else if (conv1x1 && param->Groups() == 1 &&
+             param->Paddings()[0] == param->Paddings()[1] &&
+             param->Paddings()[0] == 0 && param->Input()->dims()[1] > 1 &&
+             param->Strides()[0] == param->Strides()[1] &&
+             param->Dilations()[0] == param->Dilations()[1] &&
+             param->Strides()[0] == 1 && param->Dilations()[0] == 1 &&
+             param->Output()->dims()[2] * param->Output()->dims()[3] > 1) {
+    // transform weight
+    Variable *transformed_var = param->GetScope()->Var();
+    ARMArch arch = framework::CPUContext::Context()->get_arch();
+    param->transformed_filter_ =
+        transformed_var->GetMutable<framework::LoDTensor>();
+    math::gemm1x1s1_transform_weight(*param->Filter(), *param->Output(),
+                                     param->transformed_filter_,
+                                     param->groups, arch);
+    param->ExecMode() = ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT;
   } else {
     param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
   }
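Editorial aside, not part of the patch: the conv1x1 branch above routes 1x1, stride-1, pad-0, dilation-1 convolutions to EXEC_GEMM1x1s1_FLOAT because for that shape im2col is the identity, so each group reduces to a plain GEMM of the (chout/group x chin/group) weights against the (chin/group x hout*wout) input plane. A minimal reference sketch of that mapping for one batch, with illustrative names and naive loops standing in for the prepacked sgemm_prepack path:

#include <cstddef>

// Reference-only: naive per-group GEMM equivalent of the 1x1 stride-1 path
// for a single NCHW float batch (no bias, no relu, matching the patch's
// flag_bias = flag_relu = false). Not the patch's implementation.
void conv1x1s1_reference(const float *din, const float *weights, float *dout,
                         int chin, int chout, int hout, int wout, int group) {
  const int m = chout / group;  // output channels per group
  const int k = chin / group;   // input channels per group
  const int n = hout * wout;    // one output pixel per input pixel
  for (int g = 0; g < group; ++g) {
    const float *w_g = weights + static_cast<size_t>(g) * m * k;  // m x k
    const float *in_g = din + static_cast<size_t>(g) * k * n;     // k x n
    float *out_g = dout + static_cast<size_t>(g) * m * n;         // m x n
    for (int i = 0; i < m; ++i) {
      for (int j = 0; j < n; ++j) {
        float acc = 0.f;
        for (int p = 0; p < k; ++p) {
          acc += w_g[i * k + p] * in_g[p * n + j];
        }
        out_g[i * n + j] = acc;
      }
    }
  }
}

The packed path computes the same product, but with the weights prepacked into hblock-row panels at load time (gemm1x1s1_transform_weight) and the B panel packed per column block (loadb) inside sgemm_prepack.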
diff --git a/src/operators/kernel/arm/convolution/conv_kernel.cpp b/src/operators/kernel/arm/convolution/conv_kernel.cpp
index d73046489d7ee1547a9ad4390f0566bb028fac4c..7a3e8471310fef451e15afbe967b692bf15c87fa 100644
--- a/src/operators/kernel/arm/convolution/conv_kernel.cpp
+++ b/src/operators/kernel/arm/convolution/conv_kernel.cpp
@@ -54,6 +54,9 @@ void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
     case ConvParam<CPU>::EXEC_GEMM_FLOAT:
       GemmConv<float, float>(param);
       break;
+    case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
+      GemmConv1x1s1<float, float>(param);
+      break;
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
       SlidingwindowConv3x3<float, float>(param);
diff --git a/src/operators/kernel/arm/convolution/conv_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_relu_kernel.cpp
index 58c00dabcb6891aed706a392177dab1272b46c00..c9c42639b7f70a35f27f70f77fdb5a38e955972d 100644
--- a/src/operators/kernel/arm/convolution/conv_relu_kernel.cpp
+++ b/src/operators/kernel/arm/convolution/conv_relu_kernel.cpp
@@ -45,6 +45,9 @@ void ConvReluKernel<CPU, float>::Compute(
     case ConvParam<CPU>::EXEC_GEMM_FLOAT:
       GemmConv<float, float>(param);
       break;
+    case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
+      GemmConv1x1s1<float, float>(param);
+      break;
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S1_FLOAT:
     case ConvParam<CPU>::EXEC_SLIDINGWINDOW3x3S2_FLOAT:
       SlidingwindowConv3x3<float, float>(param);
diff --git a/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp
index fa3a424a5e4ec5a253679e1dd9f6a2eb9797b20d..7b5bf86038ccffe0a0c6922dba687754202ec7e3 100644
--- a/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp
+++ b/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp
@@ -76,6 +76,9 @@ void DWConvBNReluKernel<CPU, float>::Compute(
     case ConvParam<CPU>::EXEC_GEMM_FLOAT:
       GemmConv<float, float>(param);
       break;
+    case ConvParam<CPU>::EXEC_GEMM1x1s1_FLOAT:
+      GemmConv1x1s1<float, float>(param);
+      break;
     default:
       PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
                                     param.ExecMode());
diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.cpp b/src/operators/kernel/central-arm-func/conv_arm_func.cpp
index 9cd7cff4a4a9ccca41df932513c68db2baf8c6b6..a12562a88392278d43210fa6c697afdb6e26d7f5 100644
--- a/src/operators/kernel/central-arm-func/conv_arm_func.cpp
+++ b/src/operators/kernel/central-arm-func/conv_arm_func.cpp
@@ -14,9 +14,11 @@ limitations under the License.
 */
 #include "operators/kernel/central-arm-func/conv_arm_func.h"
 #include <vector>
+#include "framework/context.h"
 #include "operators/math/depthwise/faster_depthwise_conv3x3.h"
 #include "operators/math/depthwise_conv3x3.h"
 #include "operators/math/depthwise_conv5x5.h"
+#include "operators/math/gemm/gemm1x1s1.h"
 #include "operators/math/im2col.h"
 #include "operators/math/math_function.h"
 #include "operators/math/pad.h"
@@ -137,6 +139,61 @@ void GemmConv(const ConvParam<CPU> &param) {
   }
 }
 
+template <typename Itype, typename Otype>
+void GemmConv1x1s1(const ConvParam<CPU> &param) {
+  const Tensor *input = param.Input();
+  Tensor filter = *param.transformed_filter_;
+  Tensor *output = param.Output();
+  output->mutable_data<Otype>();
+
+  const float *din = input->data<float>();
+  float *dout = output->mutable_data<float>();
+  const int num = input->dims()[0];
+  const int chin = input->dims()[1];
+  const int hin = input->dims()[2];
+  const int win = input->dims()[3];
+  const int chout = output->dims()[1];
+  const int hout = output->dims()[2];
+  const int wout = output->dims()[3];
+  const float *weights = filter.mutable_data<float>();
+  const float *bias = nullptr;
+
+  int channel_size_out = wout * hout;
+  int channel_size_in = win * hin;
+  const int group = param.Groups();
+  const int m = chout / group;
+  const int n = hout * wout;
+  const int k = chin / group;
+
+  bool flag_relu = false;
+  bool flag_bias = false;
+  ARMArch arch = framework::CPUContext::Context()->get_arch();
+  int hblock = math::get_hblock(arch);
+
+  int m_roundup = hblock * ((m + hblock - 1) / hblock);
+  int weights_size_per_group = m * k;
+  if (n > 1) {
+    weights_size_per_group = ((m_roundup * k + 15) / 16) * 16;
+  }
+
+  for (int b = 0; b < num; ++b) {
+    // dC
+    for (int g = 0; g < group; ++g) {
+      float *dout_group =
+          static_cast<float *>(dout) + (b * chout + g * m) * channel_size_out;
+      const float *din_group = static_cast<const float *>(din) +
+                               (b * chin + g * k) * channel_size_in;
+      const float *weights_group =
+          static_cast<const float *>(weights) + g * weights_size_per_group;
+      const float *bias_group = static_cast<const float *>(bias) + g * m;
+      if (n > 1) {
+        math::sgemm_prepack(weights_group, din_group, bias_group, dout_group, m,
+                            n, k, flag_bias, flag_relu, false, arch);
+      }
+    }
+  }
+}
+
 template <int tile, int kernel>
 void WinogradConv3x3(const ConvParam<CPU> &param) {
   const Tensor *input = param.Input();
@@ -293,6 +350,7 @@ void SlidingwindowConv3x3(const ConvParam<CPU> &param) {
 }
 
 template void GemmConv<float, float>(const ConvParam<CPU> &param);
+template void GemmConv1x1s1<float, float>(const ConvParam<CPU> &param);
 template void WinogradConv3x3<8, 3>(const ConvParam<CPU> &param);
 template void DepthwiseConv3x3<float, float>(const ConvParam<CPU> &param);
 template void DepthwiseConv5x5<float, float>(const ConvParam<CPU> &param);
diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h
index 8cd99aba4603e77ac95f60e90fd0cc28415837c6..f2c1070fa0f5e11f8f92cef7f8089ada00b73216 100644
--- a/src/operators/kernel/central-arm-func/conv_arm_func.h
+++ b/src/operators/kernel/central-arm-func/conv_arm_func.h
@@ -32,6 +32,9 @@ bool IsExpand(const std::vector<int64_t> &filter_dim,
 template <typename Itype, typename Otype>
 void GemmConv(const ConvParam<CPU> &param);
 
+template <typename Itype, typename Otype>
+void GemmConv1x1s1(const ConvParam<CPU> &param);
+
 template <int tile, int kernel>
 void WinogradConv3x3(const ConvParam<CPU> &param);
 
diff --git a/src/operators/math/gemm/gemm1x1s1.cpp b/src/operators/math/gemm/gemm1x1s1.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4beae5833b17df00fb78090fbfae9b02cf77d495
--- /dev/null
+++ b/src/operators/math/gemm/gemm1x1s1.cpp
@@ -0,0 +1,2198 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#if defined(__ARM_NEON) || defined(__ARM_NEON__) +#ifdef CONV_OP + +#include "operators/math/gemm/gemm1x1s1.h" +#include +#include "framework/context.h" +#include "iostream" + +namespace paddle_mobile { +namespace operators { +namespace math { + +#ifdef __aarch64__ +void prepackA_8x12(float *out, const float *in, const int ldin, const int m0, + const int mmax, const int k0, const int kmax) { + int x_len = kmax - k0; + uint32_t zerobuff[x_len]; + memset(zerobuff, 0, sizeof(uint32_t) * x_len); + + uint32_t *dout = reinterpret_cast(out); + const uint32_t *inptr = reinterpret_cast(in); + int stride = x_len * 8; + +#pragma omp parallel for + for (int y = m0; y < mmax; y += 8) { + uint32_t *outptr = dout + stride * (y - m0) / 8; + + const uint32_t *inptr0 = inptr + y * ldin + k0; + const uint32_t *inptr1 = inptr0 + ldin; + const uint32_t *inptr2 = inptr1 + ldin; + const uint32_t *inptr3 = inptr2 + ldin; + const uint32_t *inptr4 = inptr3 + ldin; + const uint32_t *inptr5 = inptr4 + ldin; + const uint32_t *inptr6 = inptr5 + ldin; + const uint32_t *inptr7 = inptr6 + ldin; + + asm volatile( + "prfm pldl1keep, [%[ptr0]] \n" + "prfm pldl1keep, [%[ptr0], #64] \n" + "prfm pldl1keep, [%[ptr1]] \n" + "prfm pldl1keep, [%[ptr1], #64] \n" + "prfm pldl1keep, [%[ptr2]] \n" + "prfm pldl1keep, [%[ptr2], #64] \n" + "prfm pldl1keep, [%[ptr3]] \n" + "prfm pldl1keep, [%[ptr3], #64] \n" + "prfm pldl1keep, [%[ptr4]] \n" + "prfm pldl1keep, [%[ptr4], #64] \n" + "prfm pldl1keep, [%[ptr5]] \n" + "prfm pldl1keep, [%[ptr5], #64] \n" + "prfm pldl1keep, [%[ptr6]] \n" + "prfm pldl1keep, [%[ptr6], #64] \n" + "prfm pldl1keep, [%[ptr7]] \n" + "prfm pldl1keep, [%[ptr7], #64] \n" + : + : [ptr0] "r"(inptr0), [ptr1] "r"(inptr1), [ptr2] "r"(inptr2), + [ptr3] "r"(inptr3), [ptr4] "r"(inptr4), [ptr5] "r"(inptr5), + [ptr6] "r"(inptr6), [ptr7] "r"(inptr7) + : "memory"); + + int x = x_len; + //! cope with row index exceed real size, set to zero buffer + if ((y + 7) >= mmax) { + switch ((y + 7) - mmax) { + case 6: + inptr1 = zerobuff; + case 5: + inptr2 = zerobuff; + case 4: + inptr3 = zerobuff; + case 3: + inptr4 = zerobuff; + case 2: + inptr5 = zerobuff; + case 1: + inptr6 = zerobuff; + case 0: + inptr7 = zerobuff; + default: + break; + } + } + for (; x > 7; x -= 8) { + asm volatile( + // Load up 8 elements (2 vectors) from each of 8 sources. 
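+            // Editorial note, not in the original patch: the LDP/ZIP1/ZIP2
+            // sequence below is effectively an 8x8 transpose, so every 8
+            // consecutive packed floats hold one k-column taken across the 8
+            // rows of this A panel.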
+ "LDP q0, q1, [%[inptr0]], #32\n" // q0=A0A1A2A3 + "LDP q2, q3, [%[inptr1]], #32\n" // q2=B0B1B2B3 + "LDP q4, q5, [%[inptr2]], #32\n" // q4=C0C1C2C3 + "ZIP1 v16.4s, v0.4s, v4.4s\n" // q16=A0C0A1C1 + "prfm pldl1keep, [%[inptr0], #128] \n" + "LDP q6, q7, [%[inptr3]], #32\n" // q6=D0D1D2D3 + "ZIP1 v17.4s, v2.4s, v6.4s\n" // q17=B0D0B1D1 + "LDP q8, q9, [%[inptr4]], #32\n" + "LDP q10, q11, [%[inptr5]], #32\n" + "LDP q12, q13, [%[inptr6]], #32\n" + "ZIP1 v18.4s, v8.4s, v12.4s\n" + "prfm pldl1keep, [%[inptr1], #128]\n" + "LDP q14, q15, [%[inptr7]], #32\n" + "ZIP1 v19.4s, v10.4s, v14.4s\n" + + "ZIP1 v20.4s, v16.4s, v17.4s\n" // q20=A0B0C0D0 + "prfm pldl1keep, [%[inptr2], #128]\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + + "ZIP2 v16.4s, v0.4s, v4.4s\n" + "prfm pldl1keep, [%[inptr3], #128]\n" + "ZIP2 v17.4s, v2.4s, v6.4s\n" + "STP q20, q21, [%[outptr]], #32\n" // Write back the first + // element of each source + + "ZIP2 v18.4s, v8.4s, v12.4s\n" + "ZIP2 v19.4s, v10.4s, v14.4s\n" + "STP q22, q23, [%[outptr]], #32\n" // Write back the second + // element of each source + + "ZIP1 v20.4s, v16.4s, v17.4s\n" + "prfm pldl1keep, [%[inptr4], #128]\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + + "ZIP1 v16.4s, v1.4s, v5.4s\n" + "prfm pldl1keep, [%[inptr5], #128]\n" + "ZIP1 v17.4s, v3.4s, v7.4s\n" + "STP q20, q21, [%[outptr]], #32\n" // Third element + + "ZIP1 v18.4s, v9.4s, v13.4s\n" + "ZIP1 v19.4s, v11.4s, v15.4s\n" + "STP q22, q23, [%[outptr]], #32\n" // Fourth element + + "ZIP1 v20.4s, v16.4s, v17.4s\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "prfm pldl1keep, [%[inptr6], #128]\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + + "ZIP2 v16.4s, v1.4s, v5.4s\n" + "ZIP2 v17.4s, v3.4s, v7.4s\n" + "STP q20, q21, [%[outptr]], #32\n" // Fifth element + + "ZIP2 v18.4s, v9.4s, v13.4s\n" + "prfm pldl1keep, [%[inptr7], #128]\n" + "ZIP2 v19.4s, v11.4s, v15.4s\n" + "STP q22, q23, [%[outptr]], #32\n" // Sixth element + + "ZIP1 v20.4s, v16.4s, v17.4s\n" + "ZIP1 v21.4s, v18.4s, v19.4s\n" + "STP q20, q21, [%[outptr]], #32\n" // Seventh element + + "ZIP2 v22.4s, v16.4s, v17.4s\n" + "ZIP2 v23.4s, v18.4s, v19.4s\n" + "STP q22, q23, [%[outptr]], #32\n" // Eighth element + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [inptr6] "+r"(inptr6), [inptr7] "+r"(inptr7), [outptr] "+r"(outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "cc", "memory"); + } + + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + } + } +} + +#else //__aarch64__ +void prepackA_6x8(float* out, const float* in, const int ldin, const int m0, + const int mmax, const int k0, const int kmax) { + int x_len = kmax - k0; + uint32_t zerobuff[x_len]; + memset(zerobuff, 0, sizeof(uint32_t) * x_len); + + uint32_t* dout = reinterpret_cast(out); + const uint32_t* inptr = reinterpret_cast(in); + uint32_t* outptr = dout; + + //! 
data A is not transposed, transpose A to k * 6 + for (int y = m0; y < mmax; y += 6) { + const uint32_t* inptr0 = inptr + y * ldin + k0; + const uint32_t* inptr1 = inptr0 + ldin; + const uint32_t* inptr2 = inptr1 + ldin; + const uint32_t* inptr3 = inptr2 + ldin; + const uint32_t* inptr4 = inptr3 + ldin; + const uint32_t* inptr5 = inptr4 + ldin; + + int x = x_len; + //! cope with row index exceed real size, set to zero buffer + if ((y + 5) >= mmax) { + switch ((y + 5) - mmax) { + case 4: + inptr1 = zerobuff; + case 3: + inptr2 = zerobuff; + case 2: + inptr3 = zerobuff; + case 1: + inptr4 = zerobuff; + case 0: + inptr5 = zerobuff; + default: + break; + } + } + + for (; x > 7; x -= 8) { + //! zip load 8 elements (2 neon Q registers) from each of 6 rows + asm volatile( + "vld4.32 {d0-d3}, [%[inptr0]]! @ zip load r0, " + "q0,q1=r00,r04,r01,r05,r02,r06,r03,r07\n" + "vld4.32 {d4-d7}, [%[inptr1]]! @ zip load r1, " + "q2,q3=r10,r14,r11,r15,r12,r16,r13,r17\n" + "vld4.32 {d8-d11}, [%[inptr2]]! @ zip load r2, " + "q4,q5=r20,r24,r21,r25,r22,r26,r23,r27\n" + "vld4.32 {d12-d15}, [%[inptr3]]! @ zip load r3, " + "q6,q7=r30,r34,r31,r35,r32,r36,r33,r37\n" + "vld4.32 {d16-d19}, [%[inptr4]]! @ zip load r4, " + "q8,q9=r40,r44,r41,r45,r42,r46,r43,r47\n" + "vld4.32 {d20-d23}, [%[inptr5]]! @ zip load r5, " + "q10,q11=r50,r54,r51,r55,r52,r56,r53,r57\n" + + "vtrn.32 q0, q2 @ trans data: q0=r00,r10,r01,r11; " + "q2=r04,r14,r05,r15\n" + "vtrn.32 q4, q6 @ trans data: q4=r20,r30,r21,r31; " + "q6=r24,r34,r25,r35\n" + "vtrn.32 q8, q10 @ trans data: q8=r40,r50,r41,r51; " + "q10=r44,r54,r45,r55\n" + + "vswp d1, d8 @ swap d1, d8, q0=r00,r10,r20,r30; " + "q4=r01,r11,r21,r31\n" + "vst1.32 {d0-d1}, [%[outptr]]! @ write q0:r00,r10,r20,r30\n" + "vst1.32 {d16}, [%[outptr]]! @ write d16(q8,low),r40,r50\n" + "vst1.32 {d8-d9}, [%[outptr]]! @ write q4:r01,r11,r21,r31\n" + "vst1.32 {d17}, [%[outptr]]! @ write d16(q8,high),r41,r51\n" + + "vtrn.32 q1, q3 @ trans data: q1=r02,r12,r03,r13; " + "q3=r06,r16,r07,r17\n" + "vtrn.32 q5, q7 @ trans data: q5=r22,r32,r23,r33; " + "q7=r26,r36,r27,r37\n" + "vtrn.32 q9, q11 @ trans data: q9=r42,r52,r43,r53; " + "q11=r46,r56,r47,r57\n" + + "vswp d3, d10 @ swap d3, d10, " + "q1=r02,r12,r22,r32; q5=r03,r13,r23,r33\n" + "vst1.32 {d2-d3}, [%[outptr]]! @ write q1:r02,r12,r22,r32\n" + "vst1.32 {d18}, [%[outptr]]! @ write d18(q9,low),r42,r52\n" + "vst1.32 {d10-d11},[%[outptr]]! @ write q5:r03,r13,r23,r33\n" + "vst1.32 {d19}, [%[outptr]]! @ write d19(q9,high),r43,r53\n" + + "vswp d5, d12 @ swap d5, d12,q2=r04,r14,r24,r34; " + "q6=r05,r15,r25,r35\n" + "vst1.32 {d4-d5}, [%[outptr]]! @ write q2:r04,r14,r24,r34\n" + "vst1.32 {d20}, [%[outptr]]! @ write d20(q10,low),r44,r54\n" + "vst1.32 {d12-d13},[%[outptr]]! @ write q6:r05,r15,r25,r35\n" + "vst1.32 {d21}, [%[outptr]]! @ write d21(q10,high),r45,r55\n" + + "vswp d7, d14 @ swap d7, d14, " + "q3=r06,r16,r26,r36; q7=r07,r17,r27,r37\n" + "vst1.32 {d6-d7}, [%[outptr]]! @ write q3:r06,r16,r26,r36\n" + "vst1.32 {d22}, [%[outptr]]! @ write d22(q11,low),r46,r56\n" + "vst1.32 {d14-d15},[%[outptr]]! @ write q7:r07,r17,r27,r37\n" + "vst1.32 {d23}, [%[outptr]]! 
@ write d23(q11,high),r47,r57\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [inptr4] "+r"(inptr4), [inptr5] "+r"(inptr5), + [outptr] "+r"(outptr) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "cc", "memory"); + } + + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + } + } +} + +void prepackA_4x8(float* out, const float* in, const int ldin, const int m0, + const int mmax, const int k0, const int kmax) { + int x_len = kmax - k0; + uint32_t zerobuff[x_len]; + memset(zerobuff, 0, sizeof(uint32_t) * x_len); + + uint32_t* dout = reinterpret_cast(out); + const uint32_t* inptr = reinterpret_cast(in); + + uint32_t* outptr = dout; + //! data A is not transposed, transpose A to k * 4 + for (int y = m0; y < mmax; y += 4) { + const uint32_t* inptr0 = inptr + y * ldin + k0; + const uint32_t* inptr1 = inptr0 + ldin; + const uint32_t* inptr2 = inptr1 + ldin; + const uint32_t* inptr3 = inptr2 + ldin; + + int x = x_len; + //! cope with row index exceed real size, set to zero buffer + if ((y + 3) >= mmax) { + switch ((y + 3) - mmax) { + case 2: + inptr1 = zerobuff; + case 1: + inptr2 = zerobuff; + case 0: + inptr3 = zerobuff; + default: + break; + } + } + + for (; x > 7; x -= 8) { + //! zip load 8 elements (2 neon Q registers) from each of 4 rows + asm volatile( + "vld4.32 {d0-d3}, [%[inptr0]]! @ zip load r0, " + "q0,q1=r00,r04,r01,r05,r02,r06,r03,r07\n" + "vld4.32 {d4-d7}, [%[inptr1]]! @ zip load r1, " + "q2,q3=r10,r14,r11,r15,r12,r16,r13,r17\n" + "vld4.32 {d8-d11}, [%[inptr2]]! @ zip load r2, " + "q4,q5=r20,r24,r21,r25,r22,r26,r23,r27\n" + "vld4.32 {d12-d15}, [%[inptr3]]! @ zip load r3, " + "q6,q7=r30,r34,r31,r35,r32,r36,r33,r37\n" + + "vtrn.32 q0, q2 @ trans data: q0=r00,r10,r01,r11; " + "q2=r04,r14,r05,r15\n" + "vtrn.32 q4, q6 @ trans data: q4=r20,r30,r21,r31; " + "q6=r24,r34,r25,r35\n" + + "vswp d1, d8 @ swap d1, d8, q0=r00,r10,r20,r30; " + "q4=r01,r11,r21,r31\n" + "vst1.32 {d0-d1}, [%[outptr]]! @ write q0:r00,r10,r20,r30\n" + "vst1.32 {d8-d9}, [%[outptr]]! @ write q4:r01,r11,r21,r31\n" + + "vtrn.32 q1, q3 @ trans data: q1=r02,r12,r03,r13; " + "q3=r06,r16,r07,r17\n" + "vtrn.32 q5, q7 @ trans data: q5=r22,r32,r23,r33; " + "q7=r26,r36,r27,r37\n" + + "vswp d3, d10 @ swap d3, d10, " + "q1=r02,r12,r22,r32; q5=r03,r13,r23,r33\n" + "vst1.32 {d2-d3}, [%[outptr]]! @ write q1:r02,r12,r22,r32\n" + "vst1.32 {d10-d11},[%[outptr]]! @ write q5:r03,r13,r23,r33\n" + + "vswp d5, d12 @ swap d5, d12,q2=r04,r14,r24,r34; " + "q6=r05,r15,r25,r35\n" + "vst1.32 {d4-d5}, [%[outptr]]! @ write q2:r04,r14,r24,r34\n" + "vst1.32 {d12-d13},[%[outptr]]! @ write q6:r05,r15,r25,r35\n" + + "vswp d7, d14 @ swap d7, d14, " + "q3=r06,r16,r26,r36; q7=r07,r17,r27,r37\n" + "vst1.32 {d6-d7}, [%[outptr]]! @ write q3:r06,r16,r26,r36\n" + "vst1.32 {d14-d15},[%[outptr]]! 
@ write q7:r07,r17,r27,r37\n" + : [inptr0] "+r"(inptr0), [inptr1] "+r"(inptr1), [inptr2] "+r"(inptr2), + [inptr3] "+r"(inptr3), [outptr] "+r"(outptr) + : + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "cc", "memory"); + } + + for (; x > 0; x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + } + } +} +#endif //__aarch64__ + +void prepackA(float *out, const float *in, const int ldin, const int m0, + const int mmax, const int k0, const int kmax, bool is_trans, + ARMArch arch) { +#ifdef __aarch64__ + if (!is_trans) { + prepackA_8x12(out, in, ldin, m0, mmax, k0, kmax); + } +#else + if (arch == A73) { + if (!is_trans) { + prepackA_4x8(out, in, ldin, m0, mmax, k0, kmax); + } + } else { + if (!is_trans) { + prepackA_6x8(out, in, ldin, m0, mmax, k0, kmax); + } + } +#endif +} + +void gemm1x1s1_transform_weight(const framework::Tensor &weight, + const framework::Tensor &output, + framework::Tensor *trans_weight, + const int group, ARMArch arch) { + const int chout = weight.dims()[0]; + const int chin = weight.dims()[1]; + const int hout = output.dims()[2]; + const int wout = output.dims()[3]; + const int m = chout / group; + const int n = hout * wout; + const int k = chin / group; + + if (n > 1) { + int hblock = get_hblock(arch); + int m_roundup = hblock * ((m + hblock - 1) / hblock); + int weights_size_per_group = ((m_roundup * k + 15) / 16) * 16; + int weight_worksize = sizeof(float) * weights_size_per_group * group; + float *w_trans_ptr = trans_weight->mutable_data({weight_worksize}); + for (int g = 0; g < group; ++g) { + const float *weights_group = weight.data() + g * m * k; + float *weights_trans_ptr = w_trans_ptr + g * weights_size_per_group; + prepackA(weights_trans_ptr, weights_group, k, 0, m, 0, k, false, arch); + } + } +} + +#ifdef __aarch64__ +void loadb(float *out, const float *in, const int ldin, const int k0, + const int kmax, const int n0, const int nmax) { + uint32_t *outptr = reinterpret_cast(out); + const uint32_t *inptr = + reinterpret_cast(in) + k0 * ldin + n0; + uint32_t mask_buffer[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + int x_len = nmax - n0; + int y_len = kmax - k0; + int right_remain = x_len - 12 * (x_len / 12); + int right_pad = 12 - right_remain; + const size_t copy_len_remain = sizeof(float) * right_remain; + const size_t copy_len_pad = sizeof(float) * right_pad; + const size_t size_ldin = sizeof(float) * ldin; + + uint32_t *outptr_row = outptr; + int stride_out = 12 * y_len; + + uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = + vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain)); + uint32x4_t vmask2 = + vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain)); + uint32x4_t vmask3 = + vcltq_u32(vld1q_u32(mask_buffer + 8), vdupq_n_u32(right_remain)); + +#pragma omp parallel for + for (int y = 0; y < y_len - 3; y += 4) { + const uint32_t *ptr0 = inptr + y * ldin; + const uint32_t *ptr1 = ptr0 + ldin; + const uint32_t *ptr2 = ptr1 + ldin; + const uint32_t *ptr3 = ptr2 + ldin; + asm volatile( + "prfm pldl1keep, [%[ptr0]] \n" + "prfm pldl1keep, [%[ptr0], #64] \n" + "prfm pldl1keep, [%[ptr1]] \n" + "prfm pldl1keep, [%[ptr1], #64] \n" + "prfm pldl1keep, [%[ptr2]] \n" + "prfm pldl1keep, [%[ptr2], #64] \n" + "prfm pldl1keep, [%[ptr3]] \n" + "prfm pldl1keep, [%[ptr3], #64] \n" + : + : [ptr0] "r"(ptr0), [ptr1] "r"(ptr1), [ptr2] "r"(ptr2), [ptr3] "r"(ptr3) + : "memory"); + + uint32_t *outptr_row_col = outptr_row + y * 12; + + int i = 0; + for (; i < x_len - 11; 
i += 12) { + uint32x4_t vr00 = vld1q_u32(ptr0); + uint32x4_t vr01 = vld1q_u32(ptr0 + 4); + uint32x4_t vr02 = vld1q_u32(ptr0 + 8); + + uint32x4_t vr10 = vld1q_u32(ptr1); + uint32x4_t vr11 = vld1q_u32(ptr1 + 4); + uint32x4_t vr12 = vld1q_u32(ptr1 + 8); + + vst1q_u32(outptr_row_col, vr00); + vst1q_u32(outptr_row_col + 4, vr01); + vst1q_u32(outptr_row_col + 8, vr02); + + uint32x4_t vr20 = vld1q_u32(ptr2); + uint32x4_t vr21 = vld1q_u32(ptr2 + 4); + uint32x4_t vr22 = vld1q_u32(ptr2 + 8); + + vst1q_u32(outptr_row_col + 12, vr10); + vst1q_u32(outptr_row_col + 16, vr11); + vst1q_u32(outptr_row_col + 20, vr12); + + uint32x4_t vr30 = vld1q_u32(ptr3); + uint32x4_t vr31 = vld1q_u32(ptr3 + 4); + uint32x4_t vr32 = vld1q_u32(ptr3 + 8); + + vst1q_u32(outptr_row_col + 24, vr20); + vst1q_u32(outptr_row_col + 28, vr21); + vst1q_u32(outptr_row_col + 32, vr22); + + vst1q_u32(outptr_row_col + 36, vr30); + vst1q_u32(outptr_row_col + 40, vr31); + vst1q_u32(outptr_row_col + 44, vr32); + + ptr0 += 12; + ptr1 += 12; + ptr2 += 12; + ptr3 += 12; + + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32x4_t vr00 = vld1q_u32(ptr0); + uint32x4_t vr01 = vld1q_u32(ptr0 + 4); + uint32x4_t vr02 = vld1q_u32(ptr0 + 8); + + uint32x4_t vr10 = vld1q_u32(ptr1); + uint32x4_t vr11 = vld1q_u32(ptr1 + 4); + uint32x4_t vr12 = vld1q_u32(ptr1 + 8); + + uint32x4_t vr00_1 = vbslq_u32(vmask1, vr00, vzero); + uint32x4_t vr01_1 = vbslq_u32(vmask2, vr01, vzero); + uint32x4_t vr02_1 = vbslq_u32(vmask3, vr02, vzero); + + uint32x4_t vr20 = vld1q_u32(ptr2); + uint32x4_t vr21 = vld1q_u32(ptr2 + 4); + uint32x4_t vr22 = vld1q_u32(ptr2 + 8); + + vst1q_u32(outptr_row_col, vr00_1); + vst1q_u32(outptr_row_col + 4, vr01_1); + vst1q_u32(outptr_row_col + 8, vr02_1); + + uint32x4_t vr10_1 = vbslq_u32(vmask1, vr10, vzero); + uint32x4_t vr11_1 = vbslq_u32(vmask2, vr11, vzero); + uint32x4_t vr12_1 = vbslq_u32(vmask3, vr12, vzero); + + uint32x4_t vr30 = vld1q_u32(ptr3); + uint32x4_t vr31 = vld1q_u32(ptr3 + 4); + uint32x4_t vr32 = vld1q_u32(ptr3 + 8); + + vst1q_u32(outptr_row_col + 12, vr10_1); + vst1q_u32(outptr_row_col + 16, vr11_1); + vst1q_u32(outptr_row_col + 20, vr12_1); + + uint32x4_t vr20_1 = vbslq_u32(vmask1, vr20, vzero); + uint32x4_t vr21_1 = vbslq_u32(vmask2, vr21, vzero); + uint32x4_t vr22_1 = vbslq_u32(vmask3, vr22, vzero); + + uint32x4_t vr30_1 = vbslq_u32(vmask1, vr30, vzero); + uint32x4_t vr31_1 = vbslq_u32(vmask2, vr31, vzero); + uint32x4_t vr32_1 = vbslq_u32(vmask3, vr32, vzero); + + vst1q_u32(outptr_row_col + 24, vr20_1); + vst1q_u32(outptr_row_col + 28, vr21_1); + vst1q_u32(outptr_row_col + 32, vr22_1); + + vst1q_u32(outptr_row_col + 36, vr30_1); + vst1q_u32(outptr_row_col + 40, vr31_1); + vst1q_u32(outptr_row_col + 44, vr32_1); + } + } + +#pragma omp parallel for + for (int y = 4 * (y_len / 4); y < y_len; ++y) { + const uint32_t *ptr0 = inptr + y * ldin; + uint32_t *outptr_row_col = outptr_row + y * 12; + + int i = 0; + for (; i < x_len - 11; i += 12) { + uint32x4_t vr0 = vld1q_u32(ptr0); + uint32x4_t vr1 = vld1q_u32(ptr0 + 4); + uint32x4_t vr2 = vld1q_u32(ptr0 + 8); + vst1q_u32(outptr_row_col, vr0); + vst1q_u32(outptr_row_col + 4, vr1); + vst1q_u32(outptr_row_col + 8, vr2); + + ptr0 += 12; + + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32x4_t vr0 = vld1q_u32(ptr0); + uint32x4_t vr1 = vld1q_u32(ptr0 + 4); + uint32x4_t vr2 = vld1q_u32(ptr0 + 8); + + uint32x4_t vr0_1 = vbslq_u32(vmask1, vr0, vzero); + uint32x4_t vr1_1 = vbslq_u32(vmask2, vr1, vzero); + uint32x4_t vr2_1 = vbslq_u32(vmask3, vr2, vzero); + + 
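// Editorial note, not in the original patch: vmask1/vmask2/vmask3 come from
+      // vcltq_u32 against right_remain, so the vbslq_u32 selects above keep
+      // the valid lanes and zero-fill the padded tail before these final stores.
+      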
vst1q_u32(outptr_row_col, vr0_1); + vst1q_u32(outptr_row_col + 4, vr1_1); + vst1q_u32(outptr_row_col + 8, vr2_1); + } + } +} +#else //__aarch64__ +void loadb(float* out, const float* in, const int ldin, const int k0, + const int kmax, const int n0, const int nmax) { + uint32_t* outptr = reinterpret_cast(out); + const uint32_t* inptr = + reinterpret_cast(in) + k0 * ldin + n0; + uint32_t mask_buffer[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + int x_len = nmax - n0; + int y_len = kmax - k0; + int right_remain = x_len - 8 * (x_len / 8); + int right_pad = 8 - right_remain; + const size_t copy_len_remain = sizeof(float) * right_remain; + const size_t copy_len_pad = sizeof(float) * right_pad; + const size_t size_ldin = sizeof(float) * ldin; + + uint32_t* outptr_row = outptr; + int stride_out = 8 * y_len; + + uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = + vcltq_u32(vld1q_u32(mask_buffer), vdupq_n_u32(right_remain)); + uint32x4_t vmask2 = + vcltq_u32(vld1q_u32(mask_buffer + 4), vdupq_n_u32(right_remain)); + +#pragma omp parallel for + for (int y = 0; y < y_len - 3; y += 4) { + const uint32_t* ptr0 = inptr + y * ldin; + const uint32_t* ptr1 = ptr0 + ldin; + const uint32_t* ptr2 = ptr1 + ldin; + const uint32_t* ptr3 = ptr2 + ldin; + uint32_t* outptr_row_col = outptr_row + y * 8; + int i = 0; + for (; i < x_len - 7; i += 8) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n" + "vld1.32 {d4-d7}, [%[ptr1]]! @ load r1, 8 elements\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d7}, [%[outptr]]! @ write to output ptr\n" + + "vld1.32 {d0-d3}, [%[ptr2]]! @ load r2, 8 elements\n" + "vld1.32 {d4-d7}, [%[ptr3]]! @ load r3, 8 elements\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d7}, [%[outptr]]! @ write to output ptr\n" + : [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3) + : + : "q0", "q1", "q2", "q3", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n" + "vld1.32 {d4-d7}, [%[ptr1]]! @ load r1, 8 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q1, %q[vzero], %q[vmask2] @ bit select, pad zero\n" + //"vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q3, %q[vzero], %q[vmask2] @ bit select, pad zero\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d7}, [%[outptr]]! @ write to output ptr\n" + + "vld1.32 {d0-d3}, [%[ptr2]]! @ load r2, 8 elements\n" + "vld1.32 {d4-d7}, [%[ptr3]]! @ load r3, 8 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q1, %q[vzero], %q[vmask2] @ bit select, pad zero\n" + //"vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vbif q2, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q3, %q[vzero], %q[vmask2] @ bit select, pad zero\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + "vst1.32 {d4-d7}, [%[outptr]]! 
@ write to output ptr\n" + : [outptr] "+r"(ptr_out), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), + [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero) + : "q0", "q1", "q2", "q3", "cc", "memory"); + } + } +#pragma omp parallel for + for (int y = 4 * (y_len / 4); y < y_len; ++y) { + const uint32_t* ptr0 = inptr + y * ldin; + uint32_t* outptr_row_col = outptr_row + y * 8; + int i = 0; + for (; i < x_len - 7; i += 8) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : + : "q0", "q1", "cc", "memory"); + outptr_row_col += stride_out; + } + if (right_remain > 0) { + uint32_t* ptr_out = outptr_row_col; + asm volatile( + "vld1.32 {d0-d3}, [%[ptr0]]! @ load r0, 8 elements\n" + "vbif q0, %q[vzero], %q[vmask1] @ bit select, pad zero\n" + "vbif q1, %q[vzero], %q[vmask2] @ bit select, pad zero\n" + "vst1.32 {d0-d3}, [%[outptr]]! @ write to output ptr\n" + : [ptr0] "+r"(ptr0), [outptr] "+r"(ptr_out) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero) + : "q0", "q1", "cc", "memory"); + } + } +} +#endif //__aarch64__ + +#ifdef __aarch64__ +void sgemm_conv_8x12(const float *A_packed, const float *B, const float *bias, + float *C, int M, int N, int K, bool is_bias, bool is_relu, + bool transB) { + const int threads = framework::CPUContext::Context()->get_thread_num(); + int l2_size = + framework::CPUContext::Context()->get_l2_cache_size() / sizeof(float); + int l2_cache = l2_size > 0 ? l2_size : 512 * 1024; + + //! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2 + int x_block = (l2_cache - (MBLOCK * K)) / (sizeof(float) * (K + MBLOCK)); + x_block /= NBLOCK; + x_block *= NBLOCK; + int x_num = (N + (x_block - 1)) / x_block; + x_block = (N + x_num - 1) / x_num; + x_block = (x_block + NBLOCK - 1) / NBLOCK; + x_block *= NBLOCK; + x_block = x_block < NBLOCK ? NBLOCK : x_block; + + // unroll 2 loop + int tail_pre = (K & (KBLOCK - 1)); + int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1; + + bool flag_p_remain = false; + int remain = 0; + + //! apanel is pre_compute outside gemm + for (unsigned int x0 = 0; x0 < N; x0 += x_block) { + unsigned int xmax = x0 + x_block; + if (xmax > N) { + xmax = N; + } + int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK; + remain = xmax - x0 - (bblocks - 1) * NBLOCK; + if (remain > 0) { + flag_p_remain = true; + } + //! 
load bpanel + float *b_pannel = + static_cast(framework::CPUContext::Context()->get_work_space( + K * (xmax - x0) * sizeof(float))); + if (!transB) { + loadb(b_pannel, B, N, 0, K, x0, xmax); + } +#pragma omp parallel for num_threads(threads) + for (unsigned int y = 0; y < M; y += MBLOCK) { + unsigned int ymax = y + MBLOCK; + if (ymax > M) { + ymax = M; + } + + float bias_local[8] = {0}; + if (is_bias) { + bias_local[0] = bias[y]; + bias_local[1] = bias[y + 1]; + bias_local[2] = bias[y + 2]; + bias_local[3] = bias[y + 3]; + bias_local[4] = bias[y + 4]; + bias_local[5] = bias[y + 5]; + bias_local[6] = bias[y + 6]; + bias_local[7] = bias[y + 7]; + } + + float cout0[NBLOCK]; + float cout1[NBLOCK]; + float cout2[NBLOCK]; + float cout3[NBLOCK]; + float cout4[NBLOCK]; + float cout5[NBLOCK]; + float cout6[NBLOCK]; + float cout7[NBLOCK]; + + float *c_ptr0 = C + y * N + x0; + float *c_ptr1 = c_ptr0 + N; + float *c_ptr2 = c_ptr1 + N; + float *c_ptr3 = c_ptr2 + N; + float *c_ptr4 = c_ptr3 + N; + float *c_ptr5 = c_ptr4 + N; + float *c_ptr6 = c_ptr5 + N; + float *c_ptr7 = c_ptr6 + N; + + float *pout0 = c_ptr0; + float *pout1 = c_ptr1; + float *pout2 = c_ptr2; + float *pout3 = c_ptr3; + float *pout4 = c_ptr4; + float *pout5 = c_ptr5; + float *pout6 = c_ptr6; + float *pout7 = c_ptr7; + + const float *a_ptr_l = A_packed + y * K; + const float *b_ptr = b_pannel; + for (int xb = 0; xb < bblocks; xb++) { + if ((y + 7) >= ymax) { + switch ((y + 7) - ymax) { + case 6: + c_ptr1 = cout1; + case 5: + c_ptr2 = cout2; + case 4: + c_ptr3 = cout3; + case 3: + c_ptr4 = cout4; + case 2: + c_ptr5 = cout5; + case 1: + c_ptr6 = cout6; + case 0: + c_ptr7 = cout7; + default: + break; + } + } + if (flag_p_remain && (xb == bblocks - 1)) { + pout0 = c_ptr0; + pout1 = c_ptr1; + pout2 = c_ptr2; + pout3 = c_ptr3; + pout4 = c_ptr4; + pout5 = c_ptr5; + pout6 = c_ptr6; + pout7 = c_ptr7; + + c_ptr0 = cout0; + c_ptr1 = cout1; + c_ptr2 = cout2; + c_ptr3 = cout3; + c_ptr4 = cout4; + c_ptr5 = cout5; + c_ptr6 = cout6; + c_ptr7 = cout7; + } + const float *a_ptr = a_ptr_l; + int tail = tail_pre; + int k = k_pre; + + asm volatile( + // Initialize result registers, load initial operands, prime + // prefetches. 
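+            // Editorial note, not in the original patch: v8-v31 hold the 8x12
+            // output tile (8 rows x 3 vectors of 4 floats), each row seeded by
+            // broadcasting one bias value; q0-q3 cycle packed-A values and
+            // q4-q7 cycle B-panel values through the 4x-unrolled K loop below.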
+ "ldp q2, q3, [%[bias_ptr]]\n" /* load bias to q2, q3*/ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00,a01 to q0, q1*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/ + "dup v8.4s, v2.s[0]\n" /* out0 = 0 */ + "dup v9.4s, v2.s[0]\n" /* out1 = 0*/ + "dup v10.4s, v2.s[0]\n" /* out2 = 0*/ + "dup v11.4s, v2.s[1]\n" /* out3 = 0*/ + "dup v12.4s, v2.s[1]\n" /* out4 = 0*/ + "prfm pldl1keep, [%[b_ptr], #64]\n" /* preload b*/ + "dup v13.4s, v2.s[1]\n" /* out5 = 0*/ + "prfm pldl1keep, [%[a_ptr], #64]\n" /* preload a*/ + "dup v14.4s, v2.s[2]\n" /* out6 = 0*/ + "prfm pldl1keep, [%[b_ptr], #128]\n" /* preload b*/ + "dup v15.4s, v2.s[2]\n" /* out7 = 0*/ + "prfm pldl1keep, [%[a_ptr], #128]\n" /* preload a*/ + "dup v16.4s, v2.s[2]\n" /* out8 = 0*/ + "prfm pldl1keep, [%[b_ptr], #192]\n" /* preload b*/ + "dup v17.4s, v2.s[3]\n" /* out9 = 0*/ + "prfm pldl1keep, [%[b_ptr], #256]\n" /* preload b*/ + "dup v18.4s, v2.s[3]\n" /* out10 = 0*/ + "prfm pldl1keep, [%[a_ptr], #192]\n" /* preload a*/ + "dup v19.4s, v2.s[3]\n" /* out11 = 0*/ + "prfm pldl1keep, [%[b_ptr], #320]\n" /* preload b*/ + "dup v20.4s, v3.s[0]\n" /* out12 = 0*/ + "prfm pldl1keep, [%[a_ptr], #256]\n" /* preload a*/ + "dup v21.4s, v3.s[0]\n" /* out13 = 0*/ + "prfm pldl1keep, [%[b_ptr], #384]\n" /* preload b*/ + "dup v22.4s, v3.s[0]\n" /* out14 = 0*/ + "dup v23.4s, v3.s[1]\n" /* out15 = 0*/ + "dup v24.4s, v3.s[1]\n" /* out16 = 0*/ + "dup v25.4s, v3.s[1]\n" /* out17 = 0*/ + "dup v26.4s, v3.s[2]\n" /* out18 = 0*/ + "dup v27.4s, v3.s[2]\n" /* out19 = 0*/ + "dup v28.4s, v3.s[2]\n" /* out20 = 0*/ + "dup v29.4s, v3.s[3]\n" /* out21 = 0*/ + "dup v30.4s, v3.s[3]\n" /* out22 = 0*/ + "dup v31.4s, v3.s[3]\n" /* out23 = 0*/ + "cbz %w[k], 2f\n" /* check loop count > 0 */ + /* main loop */ + /* unrool 0*/ + "1:\n" /* main loop */ + "fmla v8.4s , v4.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 = q4 + */ + "fmla v11.4s , v4.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 = q4 + */ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b0 to q6, q7 */ + "fmla v14.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 = q4 + */ + "fmla v17.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 = q4 + */ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4 */ + "fmla v20.4s, v4.4s, v1.s[0]\n" /* out4 = b0 * a01[0], b0 = q4 + */ + "fmla v23.4s, v4.4s, v1.s[1]\n" /* out5 = b0 * a01[1], b0 = q4 + */ + "fmla v26.4s, v4.4s, v1.s[2]\n" /* out6 = b0 * a01[2], b0 = q4 + */ + "fmla v29.4s, v4.4s, v1.s[3]\n" /* out7 = b0 * a01[3], b0 = q4 + */ + + "fmla v9.4s, v5.4s, v0.s[0]\n" /* out8 = b1 * a00[0], b1 = q5 */ + "fmla v12.4s, v5.4s, v0.s[1]\n" /* out9 = b1 * a00[1], b1 = q5 + */ + "fmla v15.4s, v5.4s, v0.s[2]\n" /* out10 = b1 * a00[2], b1 = + q5*/ + "fmla v18.4s, v5.4s, v0.s[3]\n" /* out11 = b1 * a00[3], b1 = + q5*/ + "fmla v21.4s, v5.4s, v1.s[0]\n" /* out12 = b1 * a01[0], b1 = + q5*/ + "fmla v24.4s, v5.4s, v1.s[1]\n" /* out13 = b1 * a01[1], b1 = + q5*/ + "fmla v27.4s, v5.4s, v1.s[2]\n" /* out14 = b1 * a01[2], b1 = + q5*/ + "fmla v30.4s, v5.4s, v1.s[3]\n" /* out15 = b1 * a01[3], b1 = + q5*/ + + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b1, b2 to q4, q5 */ + + "fmla v10.4s, v6.4s, v0.s[0]\n" /* out16 = b2 * a00[0], b2 = + q6*/ + "fmla v13.4s, v6.4s, v0.s[1]\n" /* out17 = b2 * a00[1], b2 = + q6*/ + "prfm pldl1keep, [%[b_ptr], #384]\n" + "fmla v16.4s, v6.4s, v0.s[2]\n" /* out18 = b2 * a00[2], b2 = + q6*/ + "fmla v19.4s, v6.4s, v0.s[3]\n" /* out19 = b2 * a00[3], b2 = + q6*/ + "fmla v22.4s, v6.4s, v1.s[0]\n" /* out20 = b2 * a00[0], b2 = + q6*/ + "fmla v25.4s, v6.4s, v1.s[1]\n" /* out21 = b2 * 
a00[1], b2 = + q6*/ + "fmla v28.4s, v6.4s, v1.s[2]\n" /* out22 = b2 * a00[2], b2 = + q6*/ + "fmla v31.4s, v6.4s, v1.s[3]\n" /* out23 = b2 * a00[3], b2 = + q6*/ + + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1 */ + + /* unrool 1 */ + "fmla v8.4s , v7.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 = q7 + */ + "fmla v11.4s , v7.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 = q7 + */ + "fmla v14.4s, v7.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 = q7 + */ + "prfm pldl1keep, [%[a_ptr], #256]\n" + "fmla v17.4s, v7.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 = q7 + */ + "fmla v20.4s, v7.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 = q7 + */ + "fmla v23.4s, v7.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q7 + */ + "fmla v26.4s, v7.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 = q7 + */ + "fmla v29.4s, v7.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 = q7 + */ + + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b0, b1 to q6, q7 */ + + "fmla v9.4s, v4.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 = q4 */ + "fmla v12.4s, v4.4s, v2.s[1]\n" /* out9 = b0 * a10[1], b1 = q4 + */ + "fmla v15.4s, v4.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 = + q4*/ + "fmla v18.4s, v4.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 = + q4*/ + "fmla v21.4s, v4.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 = + q4*/ + "fmla v24.4s, v4.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 = + q4*/ + "fmla v27.4s, v4.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 = + q4*/ + "fmla v30.4s, v4.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 = + q4*/ + + "fmla v10.4s, v5.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 = + q5*/ + "fmla v13.4s, v5.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 = + q5*/ + "fmla v16.4s, v5.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 = + q5*/ + "fmla v19.4s, v5.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 = + q5*/ + "fmla v22.4s, v5.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 = + q5*/ + "fmla v25.4s, v5.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 = + q5*/ + "fmla v28.4s, v5.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 = + q5*/ + "fmla v31.4s, v5.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 = + q5*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b2, b0 to q4, q5 */ + /* unrool 2*/ + "fmla v8.4s , v6.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 = q6 + */ + "fmla v11.4s , v6.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 = q6 + */ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4*/ + "fmla v14.4s, v6.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 = q6*/ + "fmla v17.4s, v6.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 = q6*/ + "fmla v20.4s, v6.4s, v1.s[0]\n" /* out4 = b0 * a01[0], b0 = q6*/ + "fmla v23.4s, v6.4s, v1.s[1]\n" /* out5 = b0 * a01[1], b0 = q6*/ + "fmla v26.4s, v6.4s, v1.s[2]\n" /* out6 = b0 * a01[2], b0 = q6*/ + "fmla v29.4s, v6.4s, v1.s[3]\n" /* out7 = b0 * a01[3], b0 = q6*/ + "fmla v9.4s, v7.4s, v0.s[0]\n" /* out8 = b1 * a00[0], b1 = q7*/ + "fmla v12.4s, v7.4s, v0.s[1]\n" /* out9 = b1 * a00[1], b1 = q7*/ + "prfm pldl1keep, [%[b_ptr], #384]\n" + "fmla v15.4s, v7.4s, v0.s[2]\n" /* out10 = b1 * a00[2], b1 = + q7*/ + "fmla v18.4s, v7.4s, v0.s[3]\n" /* out11 = b1 * a00[3], b1 = + q7*/ + "fmla v21.4s, v7.4s, v1.s[0]\n" /* out12 = b1 * a01[0], b1 = + q7*/ + "fmla v24.4s, v7.4s, v1.s[1]\n" /* out13 = b1 * a01[1], b1 = + q7*/ + "fmla v27.4s, v7.4s, v1.s[2]\n" /* out14 = b1 * a01[2], b1 = + q7*/ + "fmla v30.4s, v7.4s, v1.s[3]\n" /* out15 = b1 * a01[3], b1 = + q7*/ + + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b1, b2 to q6, q7*/ + + "fmla v10.4s, v4.4s, v0.s[0]\n" /* out16 = b2 * a00[0], b2 = + q4*/ + "fmla v13.4s, v4.4s, v0.s[1]\n" /* out17 = b2 * a00[1], b2 = + q4*/ + "fmla v16.4s, v4.4s, v0.s[2]\n" /* out18 
= b2 * a00[2], b2 = + q4*/ + "fmla v19.4s, v4.4s, v0.s[3]\n" /* out19 = b2 * a00[3], b2 = + q4*/ + "fmla v22.4s, v4.4s, v1.s[0]\n" /* out20 = b2 * a00[0], b2 = + q4*/ + "fmla v25.4s, v4.4s, v1.s[1]\n" /* out21 = b2 * a00[1], b2 = + q4*/ + "fmla v28.4s, v4.4s, v1.s[2]\n" /* out22 = b2 * a00[2], b2 = + q4*/ + "fmla v31.4s, v4.4s, v1.s[3]\n" /* out23 = b2 * a00[3], b2 = + q4*/ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1*/ + /* unrool 3*/ + "fmla v8.4s , v5.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 = q5*/ + "fmla v11.4s , v5.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 = + q5*/ + "fmla v14.4s, v5.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 = q5*/ + "fmla v17.4s, v5.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 = q5*/ + "fmla v20.4s, v5.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 = q5*/ + "fmla v23.4s, v5.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ + "fmla v26.4s, v5.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 = q5*/ + "fmla v29.4s, v5.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 = q5*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b0, b1 to q4, q5*/ + "fmla v9.4s, v6.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 = q6*/ + "fmla v12.4s, v6.4s, v2.s[1]\n" /* out9 = b0 * a10[1], b1 = q6*/ + "prfm pldl1keep, [%[a_ptr], #256]\n" + "fmla v15.4s, v6.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 = + q6*/ + "fmla v18.4s, v6.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 = + q6*/ + "fmla v21.4s, v6.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 = + q6*/ + "fmla v24.4s, v6.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 = + q6*/ + "fmla v27.4s, v6.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 = + q6*/ + "prfm pldl1keep, [%[b_ptr], #384]\n" + "fmla v30.4s, v6.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 = + q6*/ + "fmla v10.4s, v7.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 = + q7*/ + "fmla v13.4s, v7.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 = + q7*/ + "fmla v16.4s, v7.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 = + q7*/ + "fmla v19.4s, v7.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 = + q7*/ + "fmla v22.4s, v7.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 = + q7*/ + "fmla v25.4s, v7.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 = + q7*/ + "subs %w[k], %w[k], #1\n" /* loop count - 1*/ + "fmla v28.4s, v7.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 = + q7*/ + "fmla v31.4s, v7.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 = + q7*/ + "bne 1b\n" + /* Target to use when K is 1 or 2 (i.e. 
zero iterations of main + loop)*/ + "2:\n" /* process tail*/ + "subs %w[tail], %w[tail], #1\n" /* tail--*/ + "beq 3f\n" /*jump to tail = 1*/ + /* final unrool 0*/ + /* unrool 0, tail > 1*/ + "fmla v8.4s , v4.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 = q4*/ + "fmla v11.4s , v4.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 = + q4*/ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b2, b0 to q6, q7*/ + "fmla v14.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 = q4*/ + "fmla v17.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 = q4*/ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q2, q3*/ + "fmla v20.4s, v4.4s, v1.s[0]\n" /* out4 = b0 * a01[0], b0 = q4*/ + "fmla v23.4s, v4.4s, v1.s[1]\n" /* out5 = b0 * a01[1], b0 = q4*/ + "fmla v26.4s, v4.4s, v1.s[2]\n" /* out6 = b0 * a01[2], b0 = q4*/ + "fmla v29.4s, v4.4s, v1.s[3]\n" /* out7 = b0 * a01[3], b0 = q4*/ + "subs %w[tail], %w[tail], #1\n" /* tail--*/ + "fmla v9.4s, v5.4s, v0.s[0]\n" /* out8 = b1 * a00[0], b1 = q5*/ + "fmla v12.4s, v5.4s, v0.s[1]\n" /* out9 = b1 * a00[1], b1 = q5*/ + "fmla v15.4s, v5.4s, v0.s[2]\n" /* out10 = b1 * a00[2], b1 = + q5*/ + "fmla v18.4s, v5.4s, v0.s[3]\n" /* out11 = b1 * a00[3], b1 = + q5*/ + "fmla v21.4s, v5.4s, v1.s[0]\n" /* out12 = b1 * a01[0], b1 = + q5*/ + "fmla v24.4s, v5.4s, v1.s[1]\n" /* out13 = b1 * a01[1], b1 = + q5*/ + "fmla v27.4s, v5.4s, v1.s[2]\n" /* out14 = b1 * a01[2], b1 = + q5*/ + "fmla v30.4s, v5.4s, v1.s[3]\n" /* out15 = b1 * a01[3], b1 = + q5*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b1, b2 to q4, q5*/ + "fmla v10.4s, v6.4s, v0.s[0]\n" /* out16 = b2 * a00[0], b2 = + q6*/ + "fmla v13.4s, v6.4s, v0.s[1]\n" /* out17 = b2 * a00[1], b2 = + q6*/ + "fmla v16.4s, v6.4s, v0.s[2]\n" /* out18 = b2 * a00[2], b2 = + q6*/ + "fmla v19.4s, v6.4s, v0.s[3]\n" /* out19 = b2 * a00[3], b2 = + q6*/ + "fmla v22.4s, v6.4s, v1.s[0]\n" /* out20 = b2 * a00[0], b2 = + q6*/ + "fmla v25.4s, v6.4s, v1.s[1]\n" /* out21 = b2 * a00[1], b2 = + q6*/ + "fmla v28.4s, v6.4s, v1.s[2]\n" /* out22 = b2 * a00[2], b2 = + q6*/ + "fmla v31.4s, v6.4s, v1.s[3]\n" /* out23 = b2 * a00[3], b2 = + q6*/ + "beq 4f\n" /*jump to tail = 2*/ + /* unrool 1, tail > 2*/ + "ldp q0, q1, [%[a_ptr]], #32\n" /* load a00, a01 to q0, q1*/ + "fmla v8.4s , v7.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 = q7*/ + "fmla v11.4s , v7.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 = + q7*/ + "fmla v14.4s, v7.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 = q7*/ + "fmla v17.4s, v7.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 = q7*/ + "fmla v20.4s, v7.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 = q7*/ + "fmla v23.4s, v7.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q7*/ + "fmla v26.4s, v7.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 = q7*/ + "fmla v29.4s, v7.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 = q7*/ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b0, b1 to q6, q7*/ + "fmla v9.4s, v4.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 = q4*/ + "fmla v12.4s, v4.4s, v2.s[1]\n" /* out9 = b0 * a10[1], b1 = q4*/ + "fmla v15.4s, v4.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 = + q4*/ + "fmla v18.4s, v4.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 = + q4*/ + "fmla v21.4s, v4.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 = + q4*/ + "fmla v24.4s, v4.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 = + q4*/ + "fmla v27.4s, v4.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 = + q4*/ + "fmla v30.4s, v4.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 = + q4*/ + "subs %w[tail], %w[tail], #1\n" /* tail--*/ + "fmla v10.4s, v5.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 = + q5*/ + "fmla v13.4s, v5.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 = + q5*/ + "fmla v16.4s, 
v5.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 = + q5*/ + "fmla v19.4s, v5.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 = + q5*/ + "fmla v22.4s, v5.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 = + q5*/ + "fmla v25.4s, v5.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 = + q5*/ + "fmla v28.4s, v5.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 = + q5*/ + "fmla v31.4s, v5.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 = + q5*/ + "beq 5f\n" /*jump to tail = 3*/ + /* unrool 2, tail = 4*/ + "ldp q4, q5, [%[b_ptr]], #32\n" /* load b2, b0 to q4, q5*/ + "fmla v8.4s , v6.4s, v0.s[0]\n" /* out0 = b0 * a00[0], b0 = q6*/ + "fmla v11.4s , v6.4s, v0.s[1]\n" /* out1 = b0 * a00[1], b0 = + q6*/ + "ldp q2, q3, [%[a_ptr]], #32\n" /* load a10, a11 to q3, q4*/ + "fmla v14.4s, v6.4s, v0.s[2]\n" /* out2 = b0 * a00[2], b0 = q6*/ + "fmla v17.4s, v6.4s, v0.s[3]\n" /* out3 = b0 * a00[3], b0 = q6*/ + "fmla v20.4s, v6.4s, v1.s[0]\n" /* out4 = b0 * a01[0], b0 = q6*/ + "fmla v23.4s, v6.4s, v1.s[1]\n" /* out5 = b0 * a01[1], b0 = q6*/ + "fmla v26.4s, v6.4s, v1.s[2]\n" /* out6 = b0 * a01[2], b0 = q6*/ + "fmla v29.4s, v6.4s, v1.s[3]\n" /* out7 = b0 * a01[3], b0 = q6*/ + "fmla v9.4s, v7.4s, v0.s[0]\n" /* out8 = b1 * a00[0], b1 = q7*/ + "fmla v12.4s, v7.4s, v0.s[1]\n" /* out9 = b1 * a00[1], b1 = q7*/ + "fmla v15.4s, v7.4s, v0.s[2]\n" /* out10 = b1 * a00[2], b1 = + q7*/ + "fmla v18.4s, v7.4s, v0.s[3]\n" /* out11 = b1 * a00[3], b1 = + q7*/ + "fmla v21.4s, v7.4s, v1.s[0]\n" /* out12 = b1 * a01[0], b1 = + q7*/ + "fmla v24.4s, v7.4s, v1.s[1]\n" /* out13 = b1 * a01[1], b1 = + q7*/ + "fmla v27.4s, v7.4s, v1.s[2]\n" /* out14 = b1 * a01[2], b1 = + q7*/ + "fmla v30.4s, v7.4s, v1.s[3]\n" /* out15 = b1 * a01[3], b1 = + q7*/ + "ldp q6, q7, [%[b_ptr]], #32\n" /* load b1, b2 to q6, q7*/ + "fmla v10.4s, v4.4s, v0.s[0]\n" /* out16 = b2 * a00[0], b2 = + q4*/ + "fmla v13.4s, v4.4s, v0.s[1]\n" /* out17 = b2 * a00[1], b2 = + q4*/ + "fmla v16.4s, v4.4s, v0.s[2]\n" /* out18 = b2 * a00[2], b2 = + q4*/ + "fmla v19.4s, v4.4s, v0.s[3]\n" /* out19 = b2 * a00[3], b2 = + q4*/ + "fmla v22.4s, v4.4s, v1.s[0]\n" /* out20 = b2 * a00[0], b2 = + q4*/ + "fmla v25.4s, v4.4s, v1.s[1]\n" /* out21 = b2 * a00[1], b2 = + q4*/ + "fmla v28.4s, v4.4s, v1.s[2]\n" /* out22 = b2 * a00[2], b2 = + q4*/ + "fmla v31.4s, v4.4s, v1.s[3]\n" /* out23 = b2 * a00[3], b2 = + q4*/ + /* unrool 3, tail = 4*/ + "fmla v8.4s , v5.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 = q5*/ + "fmla v11.4s , v5.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 = + q5*/ + "fmla v14.4s, v5.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 = q5*/ + "fmla v17.4s, v5.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 = q5*/ + "fmla v20.4s, v5.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 = q5*/ + "fmla v23.4s, v5.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ + "fmla v26.4s, v5.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 = q5*/ + "fmla v29.4s, v5.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 = q5*/ + "fmla v9.4s, v6.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 = q6*/ + "fmla v12.4s, v6.4s, v2.s[1]\n" /* out9 = b1 * a10[1], b1 = q6*/ + "fmla v15.4s, v6.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 = + q6*/ + "fmla v18.4s, v6.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 = + q6*/ + "fmla v21.4s, v6.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 = + q6*/ + "fmla v24.4s, v6.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 = + q6*/ + "fmla v27.4s, v6.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 = + q6*/ + "fmla v30.4s, v6.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 = + q6*/ + "fmla v10.4s, v7.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 = + q7*/ + "fmla v13.4s, v7.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 = + 
q7*/ + "fmla v16.4s, v7.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 = + q7*/ + "fmla v19.4s, v7.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 = + q7*/ + "fmla v22.4s, v7.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 = + q7*/ + "fmla v25.4s, v7.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 = + q7*/ + "fmla v28.4s, v7.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 = + q7*/ + "fmla v31.4s, v7.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 = + q7*/ + "b 11f\n" + /* tails==1 final tail*/ + "3: \n" /* tail=1*/ + "ldr q6, [%[b_ptr]], #16\n" /* load b2 to q6*/ + "fmla v8.4s , v4.4s, v0.s[0]\n" /* out0 = b0 * a10[0], b0 = q5*/ + "fmla v11.4s , v4.4s, v0.s[1]\n" /* out1 = b0 * a10[1], b0 = + q5*/ + "fmla v14.4s, v4.4s, v0.s[2]\n" /* out2 = b0 * a10[2], b0 = q5*/ + "fmla v17.4s, v4.4s, v0.s[3]\n" /* out3 = b0 * a10[3], b0 = q5*/ + "fmla v20.4s, v4.4s, v1.s[0]\n" /* out4 = b0 * a11[0], b0 = q5*/ + "fmla v23.4s, v4.4s, v1.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ + "fmla v26.4s, v4.4s, v1.s[2]\n" /* out6 = b0 * a11[2], b0 = q5*/ + "fmla v29.4s, v4.4s, v1.s[3]\n" /* out7 = b0 * a11[3], b0 = q5*/ + "fmla v9.4s, v5.4s, v0.s[0]\n" /* out8 = b0 * a10[0], b1 = q6*/ + "fmla v12.4s, v5.4s, v0.s[1]\n" /* out9 = b1 * a10[1], b1 = q6*/ + "fmla v15.4s, v5.4s, v0.s[2]\n" /* out10 = b1 * a10[2], b1 = + q6*/ + "fmla v18.4s, v5.4s, v0.s[3]\n" /* out11 = b1 * a10[3], b1 = + q6*/ + "fmla v21.4s, v5.4s, v1.s[0]\n" /* out12 = b1 * a10[0], b1 = + q6*/ + "fmla v24.4s, v5.4s, v1.s[1]\n" /* out13 = b1 * a10[1], b1 = + q6*/ + "fmla v27.4s, v5.4s, v1.s[2]\n" /* out14 = b1 * a10[2], b1 = + q6*/ + "fmla v30.4s, v5.4s, v1.s[3]\n" /* out15 = b1 * a10[3], b1 = + q6*/ + "fmla v10.4s, v6.4s, v0.s[0]\n" /* out16 = b2 * a10[0], b2 = + q7*/ + "fmla v13.4s, v6.4s, v0.s[1]\n" /* out17 = b2 * a10[0], b2 = + q7*/ + "fmla v16.4s, v6.4s, v0.s[2]\n" /* out18 = b2 * a10[0], b2 = + q7*/ + "fmla v19.4s, v6.4s, v0.s[3]\n" /* out19 = b2 * a10[0], b2 = + q7*/ + "fmla v22.4s, v6.4s, v1.s[0]\n" /* out20 = b2 * a10[0], b2 = + q7*/ + "fmla v25.4s, v6.4s, v1.s[1]\n" /* out21 = b2 * a10[0], b2 = + q7*/ + "fmla v28.4s, v6.4s, v1.s[2]\n" /* out22 = b2 * a10[0], b2 = + q7*/ + "fmla v31.4s, v6.4s, v1.s[3]\n" /* out23 = b2 * a10[0], b2 = + q7*/ + "b 11f\n" + /* tails==2 final tail*/ + "4:\n" /* tail = 2*/ + "fmla v8.4s , v7.4s, v2.s[0]\n" /* out0 = b0 * a10[0], b0 = q5*/ + "fmla v11.4s , v7.4s, v2.s[1]\n" /* out1 = b0 * a10[1], b0 = + q5*/ + "fmla v14.4s, v7.4s, v2.s[2]\n" /* out2 = b0 * a10[2], b0 = q5*/ + "fmla v17.4s, v7.4s, v2.s[3]\n" /* out3 = b0 * a10[3], b0 = q5*/ + "fmla v20.4s, v7.4s, v3.s[0]\n" /* out4 = b0 * a11[0], b0 = q5*/ + "fmla v23.4s, v7.4s, v3.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ + "fmla v26.4s, v7.4s, v3.s[2]\n" /* out6 = b0 * a11[2], b0 = q5*/ + "fmla v29.4s, v7.4s, v3.s[3]\n" /* out7 = b0 * a11[3], b0 = q5*/ + "fmla v9.4s, v4.4s, v2.s[0]\n" /* out8 = b0 * a10[0], b1 = q6*/ + "fmla v12.4s, v4.4s, v2.s[1]\n" /* out9 = b1 * a10[1], b1 = q6*/ + "fmla v15.4s, v4.4s, v2.s[2]\n" /* out10 = b1 * a10[2], b1 = + q6*/ + "fmla v18.4s, v4.4s, v2.s[3]\n" /* out11 = b1 * a10[3], b1 = + q6*/ + "fmla v21.4s, v4.4s, v3.s[0]\n" /* out12 = b1 * a10[0], b1 = + q6*/ + "fmla v24.4s, v4.4s, v3.s[1]\n" /* out13 = b1 * a10[1], b1 = + q6*/ + "fmla v27.4s, v4.4s, v3.s[2]\n" /* out14 = b1 * a10[2], b1 = + q6*/ + "fmla v30.4s, v4.4s, v3.s[3]\n" /* out15 = b1 * a10[3], b1 = + q6*/ + "fmla v10.4s, v5.4s, v2.s[0]\n" /* out16 = b2 * a10[0], b2 = + q7*/ + "fmla v13.4s, v5.4s, v2.s[1]\n" /* out17 = b2 * a10[0], b2 = + q7*/ + "fmla v16.4s, v5.4s, v2.s[2]\n" /* out18 = b2 * a10[0], b2 = + q7*/ + 
"fmla v19.4s, v5.4s, v2.s[3]\n" /* out19 = b2 * a10[0], b2 = + q7*/ + "fmla v22.4s, v5.4s, v3.s[0]\n" /* out20 = b2 * a10[0], b2 = + q7*/ + "fmla v25.4s, v5.4s, v3.s[1]\n" /* out21 = b2 * a10[0], b2 = + q7*/ + "fmla v28.4s, v5.4s, v3.s[2]\n" /* out22 = b2 * a10[0], b2 = + q7*/ + "fmla v31.4s, v5.4s, v3.s[3]\n" /* out23 = b2 * a10[0], b2 = + q7*/ + "b 11f\n" + /* tails==3 final tail*/ + "5:\n" /* tail = 3*/ + "ldr q4, [%[b_ptr]], #16\n" /* load b2, b0 to q4*/ + "fmla v8.4s , v6.4s, v0.s[0]\n" /* out0 = b0 * a10[0], b0 = q5*/ + "fmla v11.4s , v6.4s, v0.s[1]\n" /* out1 = b0 * a10[1], b0 = + q5*/ + "fmla v14.4s, v6.4s, v0.s[2]\n" /* out2 = b0 * a10[2], b0 = q5*/ + "fmla v17.4s, v6.4s, v0.s[3]\n" /* out3 = b0 * a10[3], b0 = q5*/ + "fmla v20.4s, v6.4s, v1.s[0]\n" /* out4 = b0 * a11[0], b0 = q5*/ + "fmla v23.4s, v6.4s, v1.s[1]\n" /* out5 = b0 * a11[1], b0 = q5*/ + "fmla v26.4s, v6.4s, v1.s[2]\n" /* out6 = b0 * a11[2], b0 = q5*/ + "fmla v29.4s, v6.4s, v1.s[3]\n" /* out7 = b0 * a11[3], b0 = q5*/ + "fmla v9.4s, v7.4s, v0.s[0]\n" /* out8 = b0 * a10[0], b1 = q6*/ + "fmla v12.4s, v7.4s, v0.s[1]\n" /* out9 = b1 * a10[1], b1 = q6*/ + "fmla v15.4s, v7.4s, v0.s[2]\n" /* out10 = b1 * a10[2], b1 = + q6*/ + "fmla v18.4s, v7.4s, v0.s[3]\n" /* out11 = b1 * a10[3], b1 = + q6*/ + "fmla v21.4s, v7.4s, v1.s[0]\n" /* out12 = b1 * a10[0], b1 = + q6*/ + "fmla v24.4s, v7.4s, v1.s[1]\n" /* out13 = b1 * a10[1], b1 = + q6*/ + "fmla v27.4s, v7.4s, v1.s[2]\n" /* out14 = b1 * a10[2], b1 = + q6*/ + "fmla v30.4s, v7.4s, v1.s[3]\n" /* out15 = b1 * a10[3], b1 = + q6*/ + "fmla v10.4s, v4.4s, v0.s[0]\n" /* out16 = b2 * a10[0], b2 = + q7*/ + "fmla v13.4s, v4.4s, v0.s[1]\n" /* out17 = b2 * a10[0], b2 = + q7*/ + "fmla v16.4s, v4.4s, v0.s[2]\n" /* out18 = b2 * a10[0], b2 = + q7*/ + "fmla v19.4s, v4.4s, v0.s[3]\n" /* out19 = b2 * a10[0], b2 = + q7*/ + "fmla v22.4s, v4.4s, v1.s[0]\n" /* out20 = b2 * a10[0], b2 = + q7*/ + "fmla v25.4s, v4.4s, v1.s[1]\n" /* out21 = b2 * a10[0], b2 = + q7*/ + "fmla v28.4s, v4.4s, v1.s[2]\n" /* out22 = b2 * a10[0], b2 = + q7*/ + "fmla v31.4s, v4.4s, v1.s[3]\n" /* out23 = b2 * a10[0], b2 = + q7*/ + "11: \n" /* check if relu */ + "cbz %w[relu], 12f\n" /* skip relu */ + "movi v2.4s, #0\n" /* for relu*/ + "fmax v8.4s, v8.4s, v2.4s\n" /* relu*/ + "fmax v9.4s, v9.4s, v2.4s\n" /* relu*/ + "fmax v10.4s, v10.4s, v2.4s\n" /* relu*/ + "fmax v11.4s, v11.4s, v2.4s\n" /* relu*/ + "fmax v12.4s, v12.4s, v2.4s\n" /* relu*/ + "fmax v13.4s, v13.4s, v2.4s\n" /* relu*/ + "fmax v14.4s, v14.4s, v2.4s\n" /* relu*/ + "fmax v15.4s, v15.4s, v2.4s\n" /* relu*/ + "fmax v16.4s,v16.4s,v2.4s\n" /* relu*/ + "fmax v17.4s,v17.4s,v2.4s\n" /* relu*/ + "fmax v18.4s, v18.4s, v2.4s\n" /* relu*/ + "fmax v19.4s, v19.4s, v2.4s\n" /* relu*/ + "fmax v20.4s, v20.4s, v2.4s\n" /* relu*/ + "fmax v21.4s, v21.4s, v2.4s\n" /* relu*/ + "fmax v22.4s, v22.4s, v2.4s\n" /* relu*/ + "fmax v23.4s, v23.4s, v2.4s\n" /* relu*/ + "fmax v24.4s,v24.4s,v2.4s\n" /* relu*/ + "fmax v25.4s,v25.4s,v2.4s\n" /* relu*/ + "fmax v26.4s, v26.4s, v2.4s\n" /* relu*/ + "fmax v27.4s, v27.4s, v2.4s\n" /* relu*/ + "fmax v28.4s, v28.4s, v2.4s\n" /* relu*/ + "fmax v29.4s, v29.4s, v2.4s\n" /* relu*/ + "fmax v30.4s, v30.4s, v2.4s\n" /* relu*/ + "fmax v31.4s, v31.4s, v2.4s\n" /* relu*/ + "12: \n" + "st1 {v8.4s, v9.4s, v10.4s},[%[c_ptr0]], #48\n" /* store r0 */ + "st1 {v11.4s, v12.4s, v13.4s},[%[c_ptr1]], #48\n" /* store r1 */ + "st1 {v14.4s, v15.4s, v16.4s},[%[c_ptr2]], #48\n" /* store r2 */ + "st1 {v17.4s, v18.4s, v19.4s},[%[c_ptr3]], #48\n" /* store r3 */ + "st1 {v20.4s, v21.4s, 
v22.4s},[%[c_ptr4]], #48\n" /* store r4 */ + "st1 {v23.4s, v24.4s, v25.4s},[%[c_ptr5]], #48\n" /* store r5 */ + "st1 {v26.4s, v27.4s, v28.4s},[%[c_ptr6]], #48\n" /* store r6 */ + "st1 {v29.4s, v30.4s, v31.4s},[%[c_ptr7]], #48\n" /* store r7 */ + + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), + [tail] "+r"(tail), [c_ptr0] "+r"(c_ptr0), [c_ptr1] "+r"(c_ptr1), + [c_ptr2] "+r"(c_ptr2), [c_ptr3] "+r"(c_ptr3), + [c_ptr4] "+r"(c_ptr4), [c_ptr5] "+r"(c_ptr5), + [c_ptr6] "+r"(c_ptr6), [c_ptr7] "+r"(c_ptr7) + : [bias_ptr] "r"(bias_local), [relu] "r"(is_relu) + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31"); + if (flag_p_remain && (xb == bblocks - 1)) { + for (int i = 0; i < remain; ++i) { + *pout0++ = cout0[i]; + *pout1++ = cout1[i]; + *pout2++ = cout2[i]; + *pout3++ = cout3[i]; + *pout4++ = cout4[i]; + *pout5++ = cout5[i]; + *pout6++ = cout6[i]; + *pout7++ = cout7[i]; + } + } + } + } + } +} +#else //__aarch64__ +/** + * \brief gemm with ablock = 6, bblock = 8, output 6x8 + * @param A + * @param B + * @param C + * @param M + * @param N + * @param K + * @param threads + * @param workspace + */ +void sgemm_conv_6x8(const float* A_packed, const float* B, const float* bias, + float* C, int M, int N, int K, bool is_bias, bool is_relu, + bool transB) { + const int threads = framework::CPUContext::Context()->get_thread_num(); + int l2_size = + framework::CPUContext::Context()->get_l2_cache_size() / sizeof(float); + int l2_cache = l2_size > 0 ? l2_size : 512 * 1024; + + //! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2 + int x_block = + (l2_cache - (MBLOCK_OTH * K)) / (sizeof(float) * (K + MBLOCK_OTH)); + x_block /= NBLOCK; + x_block *= NBLOCK; + int x_num = (N + (x_block - 1)) / x_block; + x_block = (N + x_num - 1) / x_num; + x_block = (x_block + NBLOCK - 1) / NBLOCK; + x_block *= NBLOCK; + x_block = x_block < NBLOCK ? NBLOCK : x_block; + int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1; + int tail_pre = (K & (KBLOCK - 1)); + if (tail_pre == 0) { + tail_pre = KBLOCK; + } + + bool flag_p_remain = false; + int remain = 0; + + //! apanel is pre_compute outside gemm + for (unsigned int x0 = 0; x0 < N; x0 += x_block) { + unsigned int xmax = x0 + x_block; + if (xmax > N) { + xmax = N; + } + int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK; + remain = xmax - x0 - (bblocks - 1) * NBLOCK; + if (remain > 0) { + flag_p_remain = true; + } + //! 
load bpanel + float* b_pannel = + static_cast(framework::CPUContext::Context()->get_work_space( + K * (xmax - x0) * sizeof(float))); + if (!transB) { + loadb(b_pannel, B, N, 0, K, x0, xmax); + } +#pragma omp parallel for num_threads(threads) + for (unsigned int y = 0; y < M; y += MBLOCK_OTH) { + unsigned int ymax = y + MBLOCK_OTH; + if (ymax > M) { + ymax = M; + } + float* c_ptr0 = C + y * N + x0; + float* c_ptr1 = c_ptr0 + N; + float* c_ptr2 = c_ptr1 + N; + float* c_ptr3 = c_ptr2 + N; + float* c_ptr4 = c_ptr3 + N; + float* c_ptr5 = c_ptr4 + N; + + float* pout0 = c_ptr0; + float* pout1 = c_ptr1; + float* pout2 = c_ptr2; + float* pout3 = c_ptr3; + float* pout4 = c_ptr4; + float* pout5 = c_ptr5; + + float bias_local[6] = {0}; + if (is_bias) { + bias_local[0] = bias[y]; + bias_local[1] = bias[y + 1]; + bias_local[2] = bias[y + 2]; + bias_local[3] = bias[y + 3]; + bias_local[4] = bias[y + 4]; + bias_local[5] = bias[y + 5]; + } + + float cout0[NBLOCK]; + float cout1[NBLOCK]; + float cout2[NBLOCK]; + float cout3[NBLOCK]; + float cout4[NBLOCK]; + float cout5[NBLOCK]; + + const float* a_ptr_l = A_packed + y * K; + const float* b_ptr = b_pannel; + for (int xb = 0; xb < bblocks; xb++) { + if ((y + 5) >= ymax) { + switch ((y + 5) - ymax) { + case 4: + c_ptr1 = cout1; + case 3: + c_ptr2 = cout2; + case 2: + c_ptr3 = cout3; + case 1: + c_ptr4 = cout4; + case 0: + c_ptr5 = cout5; + default: + break; + } + } + if (flag_p_remain && (xb == bblocks - 1)) { + pout0 = c_ptr0; + pout1 = c_ptr1; + pout2 = c_ptr2; + pout3 = c_ptr3; + pout4 = c_ptr4; + pout5 = c_ptr5; + + c_ptr0 = cout0; + c_ptr1 = cout1; + c_ptr2 = cout2; + c_ptr3 = cout3; + c_ptr4 = cout4; + c_ptr5 = cout5; + } + const float* a_ptr = a_ptr_l; + int tails = tail_pre; + int k = k_pre; + asm volatile( + // sgemm 6x8 + "vld1.32 {d2-d4}, [%[bias_ptr]] @ load bias 6 elements\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a0~a3\n" + "pld [%[a_ptr]] @ preload a\n" + "vdup.i32 q12,d4[0] @ out40=0\n" + "pld [%[b_ptr]] @ preload b\n" + "vdup.i32 q13,d4[0] @ out41=0\n" + "pld [%[a_ptr], #64] @ preload a\n" + "vdup.i32 q14,d4[1] @ out50=0\n" + "pld [%[b_ptr], #64] @ preload b\n" + "vdup.i32 q15,d4[1] @ out51=0\n" + "pld [%[a_ptr], #128] @ preload a\n" + "vdup.i32 q4, d2[0] @ out00=0\n" + "pld [%[b_ptr], #128] @ preload b\n" + "vdup.i32 q5, d2[0] @ out01=0\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vdup.i32 q6, d2[1] @ out10=0\n" + "pld [%[a_ptr], #192] @ preload a\n" + "vdup.i32 q7, d2[1] @ out11=0\n" + "pld [%[b_ptr], #192] @ preload a\n" + "vdup.i32 q8, d3[0] @ out20=0\n" + "pld [%[a_ptr], #256] @ preload a\n" + "vdup.i32 q9, d3[0] @ out21=0\n" + "pld [%[b_ptr], #256] @ preload a\n" + "vdup.i32 q10,d3[1] @ out30=0\n" + "pld [%[b_ptr], #320] @ preload b\n" + "vdup.i32 q11,d3[1] @ out31=0\n" + "pld [%[b_ptr], #384] @ preload b\n" + "cmp %[k], #0 @ check weather k is " + "bigger than 0\n" + "beq 0f @ jump to tail\n" + "1: @ main loop for k\n" + /* Unroll 0*/ + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a4, a5, and next " + "a0, a1\n" + "vmla.f32 q4, q2, d0[0] @ out0 += b1 * a0\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "vmla.f32 q6, q2, d0[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d2[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d2[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! 
@ load b1\n" + "vmla.f32 q5, q3, d0[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d0[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d1[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d1[1] @ out9 += b2 * a3\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a2~a5\n" + "vmla.f32 q13, q3, d2[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d2[1] @ out11 += b2 * a5\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + /* Unroll 1 */ + "vmla.f32 q4, q2, d3[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d3[1] @ out1 += b1 * a1\n" + /*"pld [%[a_ptr], #64] @ preload a\n"*/ + "vmla.f32 q8, q2, d0[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d0[1] @ out3 += b1 * a3\n" + /*"pld [%[b_ptr], #192]\n"*/ + "vmla.f32 q12, q2, d1[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d1[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d3[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d3[1] @ out7 += b2 * a1\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a0~a3\n" + "vmla.f32 q9, q3, d0[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d0[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d1[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d1[1] @ out11 += b2 * a5\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a4, a5, a0, a1\n" + /* Unroll 2 */ + "vmla.f32 q4, q2, d2[0] @ out0 += b1 * a0\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "vmla.f32 q6, q2, d2[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d3[1] @ out3 += b1 * a3\n" + /*"pld [%[a_ptr], #240] @ preload\n"*/ + "vmla.f32 q12, q2, d0[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d0[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d2[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d2[1] @ out7 += b2 * a1\n" + /*"pld [%[b_ptr], #208]\n"*/ + "vmla.f32 q9, q3, d3[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d3[1] @ out9 += b2 * a3\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a2~a5\n" + "vmla.f32 q13, q3, d0[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d0[1] @ out11 += b2 * a5\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + /* Unroll 3 */ + "vmla.f32 q4, q2, d1[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d1[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d2[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d2[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d3[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d3[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d1[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d1[1] @ out7 += b2 * a1\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a0~a3\n" + "vmla.f32 q9, q3, d2[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d2[1] @ out9 += b2 * a3\n" + "subs %[k], %[k], #1 @ k--\n" + "vmla.f32 q13, q3, d3[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d3[1] @ out11 += b2 * a5\n" + "bne 1b @ jump to main loop\n" + "0: @ process tail\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "beq 3f @ jump to tail = 1\n" + /* Unroll 0*/ + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "vmla.f32 q4, q2, d0[0] @ out0 += b1 * a0\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a4,5, a0, a1\n" + "vmla.f32 q6, q2, d0[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d2[0] @ out4 += b1 * a4\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "vmla.f32 q14, q2, d2[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! 
@ load b1\n" + "vmla.f32 q5, q3, d0[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d0[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d1[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d1[1] @ out9 += b2 * a3\n" + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a2~a5\n" + "vmla.f32 q13, q3, d2[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d2[1] @ out11 += b2 * a5\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "beq 4f @ jump to tail==2\n" + /* Unroll 1*/ + "vmla.f32 q4, q2, d3[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d3[1] @ out1 += b1 * a1\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "vmla.f32 q8, q2, d0[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d0[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d1[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d1[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d3[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d3[1] @ out7 += b2 * a1\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a0~a3\n" + "vmla.f32 q9, q3, d0[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d0[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d1[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d1[1] @ out11 += b2 * a5\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + "beq 5f @ jump to tail==3\n" + /* Unroll 2 */ + "vld1.32 {d0-d1}, [%[a_ptr] :64]! @ load a4,a5, a0,a1\n" + "vmla.f32 q4, q2, d2[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d2[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d3[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d0[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d0[1] @ out5 += b1 * a5\n" + "vld1.32 {d4-d5}, [%[b_ptr] :128]! @ load b1\n" + "vmla.f32 q5, q3, d2[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d2[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d3[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d3[1] @ out9 += b2 * a3\n" + "vld1.32 {d2-d3}, [%[a_ptr] :64]! @ load a2~a5\n" + "vmla.f32 q13, q3, d0[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d0[1] @ out11 += b2 * a5\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! @ load b2\n" + /* Unroll 3*/ + "vmla.f32 q4, q2, d1[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d1[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d2[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d2[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d3[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d3[1] @ out5 += b1 * a5\n" + "vmla.f32 q5, q3, d1[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d1[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d2[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d2[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d3[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d3[1] @ out11 += b2 * a5\n" + "b 2f\n" + /* tails==1 final tail*/ + "3: @ tail=1\n" + "vmla.f32 q4, q2, d0[0] @ out0 += b1 * a0\n" + "vld1.32 {d2}, [%[a_ptr] :64]! @ load a4,a5\n" + "vmla.f32 q6, q2, d0[1] @ out1 += b1 * a1\n" + "vld1.32 {d6-d7}, [%[b_ptr] :128]! 
@ load b2\n" + "vmla.f32 q8, q2, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d2[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d2[1] @ out5 += b1 * a5\n" + "vmla.f32 q5, q3, d0[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d0[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d1[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d1[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d2[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d2[1] @ out11 += b2 * a5\n" + "b 2f @ jump to end\n" + /* tails==2 final tail*/ + "4: @ tail == 2\n" + "vmla.f32 q4, q2, d3[0] @ out0 += b1 * a0\n" + "vmla.f32 q6, q2, d3[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d0[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d0[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d1[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d1[1] @ out5 += b1 * a5\n" + "vmla.f32 q5, q3, d3[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d3[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d0[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d0[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d1[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d1[1] @ out11 += b2 * a5\n" + "b 2f @ jump to end\n" + /* tails==3 final tail*/ + "5: @ tail=3\n" + "vmla.f32 q4, q2, d2[0] @ out0 += b1 * a0\n" + "vld1.32 {d0}, [%[a_ptr] :64]! @ load a4,a5\n" + "vmla.f32 q6, q2, d2[1] @ out1 += b1 * a1\n" + "vmla.f32 q8, q2, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q10, q2, d3[1] @ out3 += b1 * a3\n" + "vmla.f32 q12, q2, d0[0] @ out4 += b1 * a4\n" + "vmla.f32 q14, q2, d0[1] @ out5 += b1 * a5\n" + "vmla.f32 q5, q3, d2[0] @ out6 += b2 * a0\n" + "vmla.f32 q7, q3, d2[1] @ out7 += b2 * a1\n" + "vmla.f32 q9, q3, d3[0] @ out8 += b2 * a2\n" + "vmla.f32 q11, q3, d3[1] @ out9 += b2 * a3\n" + "vmla.f32 q13, q3, d0[0] @ out10 += b2 * a4\n" + "vmla.f32 q15, q3, d0[1] @ out11 += b2 * a5\n" + "2: @ check relu\n" + "cmp %[relu], #0 @ check if has relu\n" + "ble 6f @ skip relu if relu <= 0\n" + "vmov.u32 q0, #0 @ for relu\n" + "vmax.f32 q4, q4, q0 @ for relu\n" + "vmax.f32 q5, q5, q0 @ for relu\n" + "vmax.f32 q6, q6, q0 @ for relu\n" + "vmax.f32 q7, q7, q0 @ for relu\n" + "vmax.f32 q8, q8, q0 @ for relu\n" + "vmax.f32 q9, q9, q0 @ for relu\n" + "vmax.f32 q10, q10, q0 @ for relu\n" + "vmax.f32 q11, q11, q0 @ for relu\n" + "vmax.f32 q12, q12, q0 @ for relu\n" + "vmax.f32 q13, q13, q0 @ for relu\n" + "vmax.f32 q14, q14, q0 @ for relu\n" + "vmax.f32 q15, q15, q0 @ for relu\n" + "6: @ store result\n" + "vst1.32 {d8-d11}, [%[c_ptr0]]! @ store r0\n" + "vst1.32 {d12-d15}, [%[c_ptr1]]! @ store r1\n" + "vst1.32 {d16-d19}, [%[c_ptr2]]! @ store r2\n" + "vst1.32 {d20-d23}, [%[c_ptr3]]! @ store r3\n" + "vst1.32 {d24-d27}, [%[c_ptr4]]! @ store r4\n" + "vst1.32 {d28-d31}, [%[c_ptr5]]! 
@ store r5\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr0] "+r"(c_ptr0), + [c_ptr1] "+r"(c_ptr1), [c_ptr2] "+r"(c_ptr2), + [c_ptr3] "+r"(c_ptr3), [c_ptr4] "+r"(c_ptr4), + [c_ptr5] "+r"(c_ptr5), [k] "+r"(k), [tails] "+r"(tails) + : [bias_ptr] "r"(bias_local), [relu] "r"(is_relu) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "q12", "q13", "q14", "q15", "cc", "memory"); + + if (flag_p_remain && (xb == bblocks - 1)) { + for (int i = 0; i < remain; ++i) { + *pout0++ = cout0[i]; + *pout1++ = cout1[i]; + *pout2++ = cout2[i]; + *pout3++ = cout3[i]; + *pout4++ = cout4[i]; + *pout5++ = cout5[i]; + } + } + } + } + } +} + +void sgemm_conv_4x8(const float* A_packed, const float* B, const float* bias, + float* C, int M, int N, int K, bool is_bias, bool is_relu, + bool transB) { + const int threads = framework::CPUContext::Context()->get_thread_num(); + int l2_size = + framework::CPUContext::Context()->get_l2_cache_size() / sizeof(float); + int l2_cache = l2_size > 0 ? l2_size : 512 * 1024; + + //! MBLOCK * x (result) + MBLOCK * k (A) + x * k (B) = l2 + int x_block = + (l2_cache - (MBLOCK_A73 * K)) / (sizeof(float) * (K + MBLOCK_A73)); + x_block /= NBLOCK; + x_block *= NBLOCK; + int x_num = (N + (x_block - 1)) / x_block; + x_block = (N + x_num - 1) / x_num; + x_block = (x_block + NBLOCK - 1) / NBLOCK; + x_block *= NBLOCK; + x_block = x_block < NBLOCK ? NBLOCK : x_block; + + int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1; + int tail_pre = (K & (KBLOCK - 1)); + if (tail_pre == 0) { + tail_pre = KBLOCK; + } + + bool flag_p_remain = false; + int remain = 0; + + //! apanel is pre_compute outside gemm + for (unsigned int x0 = 0; x0 < N; x0 += x_block) { + unsigned int xmax = x0 + x_block; + if (xmax > N) { + xmax = N; + } + int bblocks = (xmax - x0 + NBLOCK - 1) / NBLOCK; + remain = xmax - x0 - (bblocks - 1) * NBLOCK; + if (remain > 0) { + flag_p_remain = true; + } + //! load bpanel + float* b_pannel = + static_cast(framework::CPUContext::Context()->get_work_space( + K * (xmax - x0) * sizeof(float))); + + if (!transB) { + loadb(b_pannel, B, N, 0, K, x0, xmax); + } +#pragma omp parallel for num_threads(threads) + for (unsigned int y = 0; y < M; y += MBLOCK_A73) { + unsigned int ymax = y + MBLOCK_A73; + if (ymax > M) { + ymax = M; + } + + float cout0[NBLOCK]; + float cout1[NBLOCK]; + float cout2[NBLOCK]; + float cout3[NBLOCK]; + + float bias_local[4] = {0}; + if (is_bias) { + bias_local[0] = bias[y]; + bias_local[1] = bias[y + 1]; + bias_local[2] = bias[y + 2]; + bias_local[3] = bias[y + 3]; + } + + float* c_ptr0 = C + y * N + x0; + float* c_ptr1 = c_ptr0 + N; + float* c_ptr2 = c_ptr1 + N; + float* c_ptr3 = c_ptr2 + N; + + float* pout0 = c_ptr0; + float* pout1 = c_ptr1; + float* pout2 = c_ptr2; + float* pout3 = c_ptr3; + + const float* a_ptr_l = A_packed + y * K; + const float* b_ptr = b_pannel; + for (int xb = 0; xb < bblocks; xb++) { + if ((y + 3) >= ymax) { + switch ((y + 3) - ymax) { + case 2: + c_ptr1 = cout1; + case 1: + c_ptr2 = cout1; + case 0: + c_ptr3 = cout1; + default: + break; + } + } + if (flag_p_remain && (xb == bblocks - 1)) { + pout0 = c_ptr0; + pout1 = c_ptr1; + pout2 = c_ptr2; + pout3 = c_ptr3; + + c_ptr0 = cout0; + c_ptr1 = cout1; + c_ptr2 = cout2; + c_ptr3 = cout3; + } + const float* a_ptr = a_ptr_l; + int tails = tail_pre; + int k = k_pre; + asm volatile( + "vld1.32 {d4-d5}, [%[bias_ptr]] @ load bias\n" + "vld1.32 {d0-d3}, [%[a_ptr] :128]! 
@ load a0~a3\n" + "vdup.32 q8, d4[0] @ add bias to out00\n" + "pld [%[a_ptr]] @ preload a, 64byte\n" + "vdup.32 q9, d4[0] @ add bias to out01\n" + "pld [%[b_ptr]] @ preload b\n" + "vdup.32 q10, d4[1] @ add bias to out10\n" + "pld [%[a_ptr], #64] @ preload a\n" + "vdup.32 q11, d4[1] @ add bias to out11\n" + "vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load b1\n" + "vdup.32 q12, d5[0] @ add bias to out20\n" + "pld [%[b_ptr], #64] @ preload b\n" + "vdup.32 q13, d5[0] @ add bias to out21\n" + "pld [%[a_ptr], #128] @ preload a\n" + "vdup.32 q14, d5[1] @ add bias to out30\n" + "pld [%[b_ptr], #128] @ preload b\n" + "vdup.32 q15, d5[1] @ add bias to out31\n" + "pld [%[b_ptr], #192] @ preload b\n" + "cmp %[k], #0 @ check weather k is " + "bigger than 0\n" + "beq 0f @ jump to tail\n" + + "1: @ main loop for k\n" + /* Unroll 0*/ + "vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1, b2\n" + "vmla.f32 q8, q4, d0[0] @ out0 += b1 * a0\n" + "vld1.32 {d4-d7}, [%[a_ptr] :128]! @ load next 2xa0~a3\n" + "vmla.f32 q10, q4, d0[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q4, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d0[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d0[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d1[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d1[1] @ out7 += b2 * a3\n" + "vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load next b1, b2\n" + /* Unroll 1 */ + "vmla.f32 q8, q6, d2[0] @ out0 += b1 * a0\n" + "pld [%[b_ptr], #64] @ preload b\n" + "vmla.f32 q10, q6, d2[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q6, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q6, d3[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q7, d2[0] @ out6 += b2 * a0\n" + "vmla.f32 q11, q7, d2[1] @ out7 += b2 * a1\n" + "vmla.f32 q13, q7, d3[0] @ out8 += b2 * a2\n" + "vmla.f32 q15, q7, d3[1] @ out9 += b2 * a3\n" + "vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1,b2\n" + /* Unroll 2 */ + "vmla.f32 q8, q4, d4[0] @ out0 += b1 * a0\n" + "vld1.32 {d0-d3}, [%[a_ptr] :128]! @ load next a0~a3\n" + "vmla.f32 q10, q4, d4[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q4, d5[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d5[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d4[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d4[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d5[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d5[1] @ out7 += b2 * a3\n" + "vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load next b1, b2\n" + /* Unroll 3 */ + "vmla.f32 q8, q6, d6[0] @ out0 += b1 * a0\n" + "pld [%[a_ptr], #64] @ preload a\n" + "vmla.f32 q10, q6, d6[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q6, d7[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q6, d7[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q7, d6[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q7, d6[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q7, d7[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q7, d7[1] @ out7 += b2 * a3\n" + "subs %[k], %[k], #1 @ k--\n" + "bne 1b @ jump to main loop\n" + + "0: @ process tail\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "beq 3f @ jump to tail = 1\n" + /* Unroll 0*/ + "vld1.32 {d12-d15}, [%[b_ptr] :128]! 
@ load next b1, b2\n" + "vmla.f32 q8, q4, d0[0] @ out0 += b1 * a0\n" // b1*a1 + "vmla.f32 q10, q4, d0[1] @ out1 += b1 * a1\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "vmla.f32 q12, q4, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d0[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d0[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d1[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d1[1] @ out7 += b2 * a3\n" + "beq 4f @ jump to tail==2\n" + /* Unroll 1 */ + "vld1.32 {d8-d11}, [%[b_ptr] :128]! @ load next b1, b2\n" + "vmla.f32 q8, q6, d2[0] @ out0 += b1 * a0\n" // b6*a2 + "vld1.32 {d4-d7}, [%[a_ptr] :128]! @ load next 2xa0~a3\n" + "vmla.f32 q10, q6, d2[1] @ out1 += b1 * a1\n" + "subs %[tails], %[tails], #1 @ tail--\n" + "vmla.f32 q12, q6, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q6, d3[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q7, d2[0] @ out6 += b2 * a0\n" + "vmla.f32 q11, q7, d2[1] @ out7 += b2 * a1\n" + "vmla.f32 q13, q7, d3[0] @ out8 += b2 * a2\n" + "vmla.f32 q15, q7, d3[1] @ out9 += b2 * a3\n" + "beq 5f @ jump to tail==3\n" + /* Unroll 2 */ + "vld1.32 {d12-d15}, [%[b_ptr] :128]! @ load next b1,b2\n" + "vmla.f32 q8, q4, d4[0] @ out0 += b1 * a0\n" // b11 + // * + // a3 + "vmla.f32 q10, q4, d4[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q4, d5[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d5[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d4[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d4[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d5[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d5[1] @ out7 += b2 * a3\n" + /* Unroll 3 */ + "vmla.f32 q8, q6, d6[0] @ out0 += b1 * a0\n" // b16 + // * + // a4 + "vmla.f32 q10, q6, d6[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q6, d7[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q6, d7[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q7, d6[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q7, d6[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q7, d7[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q7, d7[1] @ out7 += b2 * a3\n" + "b 2f\n" + /* tails==1 final tail */ + "3: @ tail=1\n" + "vmla.f32 q8, q4, d0[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q4, d0[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q4, d1[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d1[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d0[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d0[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d1[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d1[1] @ out7 += b2 * a3\n" + /*aptr - 16 */ + "sub %[a_ptr], %[a_ptr], #16 @ tail--\n" + "b 2f @ jump to end\n" + /* tails==2 final tail*/ + "4: @ tail == 2\n" + "vmla.f32 q8, q6, d2[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q6, d2[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q6, d3[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q6, d3[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q7, d2[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q7, d2[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q7, d3[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q7, d3[1] @ out7 += b2 * a3\n" + "b 2f @ jump to end\n" + /* tails==3 final tail*/ + "5: @ tail=3\n" + "vmla.f32 q8, q4, d4[0] @ out0 += b1 * a0\n" + "vmla.f32 q10, q4, d4[1] @ out1 += b1 * a1\n" + "vmla.f32 q12, q4, d5[0] @ out2 += b1 * a2\n" + "vmla.f32 q14, q4, d5[1] @ out3 += b1 * a3\n" + "vmla.f32 q9, q5, d4[0] @ out4 += b2 * a0\n" + "vmla.f32 q11, q5, d4[1] @ out5 += b2 * a1\n" + "vmla.f32 q13, q5, d5[0] @ out6 += b2 * a2\n" + "vmla.f32 q15, q5, d5[1] @ out7 += b2 * a3\n" + /*aptr - 16*/ + "sub %[a_ptr], %[a_ptr], #16 @ tail--\n" + "2: @ check relu\n" + "cmp %[relu], #0 @ check if has relu\n" + "ble 6f @ skip relu if relu <= 
0\n" + "vmov.u32 q0, #0 @ for relu\n" + "vmax.f32 q8, q8, q0 @ for relu\n" + "vmax.f32 q9, q9, q0 @ for relu\n" + "vmax.f32 q10, q10, q0 @ for relu\n" + "vmax.f32 q11, q11, q0 @ for relu\n" + "vmax.f32 q12, q12, q0 @ for relu\n" + "vmax.f32 q13, q13, q0 @ for relu\n" + "vmax.f32 q14, q14, q0 @ for relu\n" + "vmax.f32 q15, q15, q0 @ for relu\n" + "6: @ store result\n" + "vst1.32 {d16-d19}, [%[c_ptr0]]! @ store r0\n" + "vst1.32 {d20-d23}, [%[c_ptr1]]! @ store r1\n" + "vst1.32 {d24-d27}, [%[c_ptr2]]! @ store r2\n" + "vst1.32 {d28-d31}, [%[c_ptr3]]! @ store r3\n" + : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [c_ptr0] "+r"(c_ptr0), + [c_ptr1] "+r"(c_ptr1), [c_ptr2] "+r"(c_ptr2), + [c_ptr3] "+r"(c_ptr3), [k] "+r"(k), [tails] "+r"(tails) + : [bias_ptr] "r"(bias_local), [relu] "r"(is_relu) + : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", + "q11", "q12", "q13", "q14", "q15", "cc", "memory"); + + if (flag_p_remain && (xb == bblocks - 1)) { + for (int i = 0; i < remain; ++i) { + *pout0++ = cout0[i]; + *pout1++ = cout1[i]; + *pout2++ = cout2[i]; + *pout3++ = cout3[i]; + } + } + } + } + } +} + +#endif //__aarch64__ +/// a: m*k b: k*n c: m*n +void sgemm_prepack(const float *A_packed, const float *B, const float *bias, + float *C, int M, int N, int K, bool is_bias, bool is_relu, + bool is_transB, ARMArch arch) { +#ifdef __aarch64__ + sgemm_conv_8x12(A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB); +#else // armv7 + if (arch == A73) { + sgemm_conv_4x8(A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB); + } else { + sgemm_conv_6x8(A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB); + } +#endif // arm64 +} + +} // namespace math +} // namespace operators +} // namespace paddle_mobile + +#endif // CONV_OP +#endif // __ARM_NEON__ diff --git a/src/operators/math/gemm/gemm1x1s1.h b/src/operators/math/gemm/gemm1x1s1.h new file mode 100644 index 0000000000000000000000000000000000000000..e7cae8bf10db7bee27be7de7ec216c72bdad5b69 --- /dev/null +++ b/src/operators/math/gemm/gemm1x1s1.h @@ -0,0 +1,57 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef CONV_OP + +#pragma once +#include "framework/tensor.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +#ifdef __aarch64__ +const int MBLOCK = 8; +const int NBLOCK = 12; +const int KBLOCK = 4; +inline int get_hblock(ARMArch arch) { return MBLOCK; } +#else +const int MBLOCK_A73 = 4; +const int MBLOCK_OTH = 6; +const int NBLOCK = 8; +const int KBLOCK = 4; + +inline int get_hblock(ARMArch arch) { + if (arch == A73) { + return MBLOCK_A73; + } else { + return MBLOCK_OTH; + } +} +#endif // __aarch64__ + +void gemm1x1s1_transform_weight(const framework::Tensor& weight, + const framework::Tensor& output, + framework::Tensor* trans_weight, + const int group, ARMArch arch); + +void sgemm_prepack(const float* A_packed, const float* B, const float* bias, + float* C, int M, int N, int K, bool is_bias, bool is_relu, + bool is_transB, ARMArch arch); + +} // namespace math +} // namespace operators +} // namespace paddle_mobile + +#endif // CONV_OP diff --git a/src/operators/op_param.h b/src/operators/op_param.h index e438a5f35472405a471c5d6a490a9bcef567141f..4df7b5b173f49e79607ee23c816acfea90f5ea15 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -467,6 +467,7 @@ class ConvParam : public OpParam { EXEC_SLIDINGWINDOW3x3_FLOAT, EXEC_SLIDINGWINDOW5x5_FLOAT, EXEC_SLIDINGWINDOW7x7_FLOAT, + EXEC_GEMM1x1s1_FLOAT, }; ExecMode &ExecMode() const { return exec_mode_; }
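// A minimal standalone sketch (not part of the patch) of how the armv7 path
// above picks its row block and splits the K loop. sgemm_prepack dispatches
// sgemm_conv_8x12 on aarch64, sgemm_conv_4x8 when the detected arch is A73,
// and sgemm_conv_6x8 on other armv7 cores; get_hblock(arch) returns the
// matching block height (8, 4 or 6) so the packed A panel lines up with the
// kernel that will consume it. The enum here is an abridged stand-in for the
// ARMArch enum and main() is purely illustrative.
#include <iostream>

enum ArchSketch { SKETCH_A72 = 72, SKETCH_A73 = 73 };  // abridged stand-in

constexpr int MBLOCK_A73 = 4;  // rows per iteration of the 4x8 kernel (A73)
constexpr int MBLOCK_OTH = 6;  // rows per iteration of the 6x8 kernel (others)
constexpr int NBLOCK = 8;      // columns per iteration on armv7
constexpr int KBLOCK = 4;      // K unroll factor inside the asm main loop

inline int hblock_sketch(ArchSketch arch) {
  // Same rule as get_hblock() in gemm1x1s1.h: A73 packs A in blocks of 4 rows,
  // every other armv7 target packs in blocks of 6 rows.
  return arch == SKETCH_A73 ? MBLOCK_A73 : MBLOCK_OTH;
}

int main() {
  const int K = 27;  // example reduction depth
  // Same split computed before entering the asm: k_pre full KBLOCK iterations,
  // then a tail of 1..KBLOCK columns handled by the "tails" branches.
  int k_pre = ((K + KBLOCK - 1) / KBLOCK) - 1;
  int tail_pre = K & (KBLOCK - 1);
  if (tail_pre == 0) tail_pre = KBLOCK;
  std::cout << "hblock(A73)=" << hblock_sketch(SKETCH_A73)
            << " hblock(A72)=" << hblock_sketch(SKETCH_A72)
            << " k_pre=" << k_pre << " tail=" << tail_pre << "\n";
  return 0;
}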