diff --git a/src/common/types.cpp b/src/common/types.cpp index cbaf289e27ffcd34eff7b113e06219a595de5257..7c0f2924138263e6e14e58d678adf03098a8f0a1 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -83,6 +83,8 @@ const char *G_OP_TYPE_LOGICAL_NOT = "logical_not"; const char *G_OP_TYPE_LOGICAL_XOR = "logical_xor"; const char *G_OP_TYPE_WRITE_TO_ARRAY = "write_to_array"; const char *G_OP_TYPE_READ_FROM_ARRAY = "read_from_array"; +const char *G_OP_TYPE_IS_EMPTY = "is_empty"; +const char *G_OP_TYPE_INCREMENT = "increment"; const char *G_OP_TYPE_QUANTIZE = "quantize"; const char *G_OP_TYPE_DEQUANTIZE = "dequantize"; @@ -199,6 +201,8 @@ std::unordered_map< {G_OP_TYPE_LOGICAL_NOT, {{"X"}, {"Out"}}}, {G_OP_TYPE_WRITE_TO_ARRAY, {{"X", "I"}, {"Out"}}}, {G_OP_TYPE_READ_FROM_ARRAY, {{"X", "I"}, {"Out"}}}, + {G_OP_TYPE_IS_EMPTY, {{"X"}, {"Out"}}}, + {G_OP_TYPE_INCREMENT, {{"X"}, {"Out"}}}, {G_OP_TYPE_SLICE, {{"Input"}, {"Out"}}}, {G_OP_TYPE_ANCHOR_GENERATOR, {{"Input"}, {"Anchors", "Variances"}}}, {G_OP_TYPE_GENERATE_PROPOSALS, diff --git a/src/common/types.h b/src/common/types.h index 267015539fa7d3ed7868f841ce22a83ed665e972..2f23f17d0d694c29b2c37e03a59f1abe0c3c5324 100644 --- a/src/common/types.h +++ b/src/common/types.h @@ -172,6 +172,8 @@ extern const char *G_OP_TYPE_LOGICAL_NOT; extern const char *G_OP_TYPE_LOGICAL_XOR; extern const char *G_OP_TYPE_WRITE_TO_ARRAY; extern const char *G_OP_TYPE_READ_FROM_ARRAY; +extern const char *G_OP_TYPE_IS_EMPTY; +extern const char *G_OP_TYPE_INCREMENT; extern const char *G_OP_TYPE_QUANTIZE; extern const char *G_OP_TYPE_DEQUANTIZE; diff --git a/src/framework/load_ops.h b/src/framework/load_ops.h index 0727c0cb04ec93047e612863f23dd92cb131cbed..95cf5e1f7fe7084defa7d197b1cac192ab169c1d 100644 --- a/src/framework/load_ops.h +++ b/src/framework/load_ops.h @@ -306,6 +306,12 @@ LOAD_OP1(write_to_array, CPU); #ifdef READ_FROM_ARRAY_OP LOAD_OP1(read_from_array, CPU); #endif +#ifdef IS_EMPTY_OP +LOAD_OP1(is_empty, CPU); +#endif +#ifdef INCREMENT_OP +LOAD_OP1(increment, CPU); +#endif #ifdef ANCHOR_GENERATOR_OP LOAD_OP1(anchor_generator, CPU); #endif diff --git a/src/operators/increment_op.cpp b/src/operators/increment_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4b2cd0462e5392f20deb86231b02745458a83b3e --- /dev/null +++ b/src/operators/increment_op.cpp @@ -0,0 +1,48 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef INCREMENT_OP + +#include "operators/increment_op.h" +#include "framework/op_proto_maker.h" +#include "framework/op_registry.h" + +namespace paddle_mobile { +namespace operators { + +template +void IncrementOp::InferShape() const { + auto input = this->param_.InputX(); + auto out = this->param_.Out(); + PADDLE_MOBILE_ENFORCE(input->numel() == 1, "input's numel should be 1"); + out->Resize(input->dims()); + out->set_lod(input->lod()); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(increment, ops::IncrementOp); +#endif +#ifdef PADDLE_MOBILE_MALI_GPU +#endif +#ifdef PADDLE_MOBILE_FPGA +#endif + +#ifdef PADDLE_MOBILE_CL +#endif + +#endif diff --git a/src/operators/increment_op.h b/src/operators/increment_op.h new file mode 100644 index 0000000000000000000000000000000000000000..5212630cc4d01ca6d432f5e340d3cf0fd89782b5 --- /dev/null +++ b/src/operators/increment_op.h @@ -0,0 +1,49 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef INCREMENT_OP + +#pragma once + +#include +#include "framework/operator.h" +#include "operators/kernel/increment_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { +using std::string; +template +class IncrementOp + : public framework::OperatorWithKernel, + IncrementKernel> { + public: + IncrementOp(const string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel, + IncrementKernel>( + type, inputs, outputs, attrs, scope) {} + + void InferShape() const override; + + protected: +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/is_empty_op.cpp b/src/operators/is_empty_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..654b998ebdfa6f6b0f40401a32ab5968c9dfeee1 --- /dev/null +++ b/src/operators/is_empty_op.cpp @@ -0,0 +1,45 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef IS_EMPTY_OP + +#include "operators/is_empty_op.h" +#include "framework/op_proto_maker.h" +#include "framework/op_registry.h" + +namespace paddle_mobile { +namespace operators { + +template +void IsEmptyOp::InferShape() const { + auto out = this->param_.Out(); + out->Resize({1}); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(is_empty, ops::IsEmptyOp); +#endif +#ifdef PADDLE_MOBILE_MALI_GPU +#endif +#ifdef PADDLE_MOBILE_FPGA +#endif + +#ifdef PADDLE_MOBILE_CL +#endif + +#endif diff --git a/src/operators/is_empty_op.h b/src/operators/is_empty_op.h new file mode 100644 index 0000000000000000000000000000000000000000..45af2646a7a8a2691868ea204a8602a34898412d --- /dev/null +++ b/src/operators/is_empty_op.h @@ -0,0 +1,48 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef IS_EMPTY_OP + +#pragma once + +#include +#include "framework/operator.h" +#include "operators/kernel/is_empty_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { +using std::string; +template +class IsEmptyOp + : public framework::OperatorWithKernel, + IsEmptyKernel> { + public: + IsEmptyOp(const string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, + const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel, + IsEmptyKernel>( + type, inputs, outputs, attrs, scope) {} + + void InferShape() const override; + + protected: +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/arm/increment_kernel.cpp b/src/operators/kernel/arm/increment_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..27fd48d084a1c29289dd6e8755cee860208d12f7 --- /dev/null +++ b/src/operators/kernel/arm/increment_kernel.cpp @@ -0,0 +1,36 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef INCREMENT_OP + +#include "operators/kernel/increment_kernel.h" +#include + +namespace paddle_mobile { +namespace operators { + +template <> +bool IncrementKernel::Init(IncrementParam *param) { + return true; +} + +template <> +void IncrementKernel::Compute(const IncrementParam ¶m) { + IncrementCompute(param); +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/arm/is_empty_kernel.cpp b/src/operators/kernel/arm/is_empty_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..070d3d16d7813584c490c868333d69a9a11afde9 --- /dev/null +++ b/src/operators/kernel/arm/is_empty_kernel.cpp @@ -0,0 +1,37 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef INCREMENT_OP + +#include "operators/kernel/is_empty_kernel.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool IsEmptyKernel::Init(IsEmptyParam *param) { + return true; +} + +template <> +void IsEmptyKernel::Compute(const IsEmptyParam ¶m) { + const framework::Tensor *input = param.InputX(); + framework::Tensor *out = param.Out(); + out->mutable_data()[0] = input->numel() == 0; +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/increment_arm_func.h b/src/operators/kernel/central-arm-func/increment_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..44465b2a2f10ad0ca9cb2b6166d14429197a1e30 --- /dev/null +++ b/src/operators/kernel/central-arm-func/increment_arm_func.h @@ -0,0 +1,39 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef INCREMENT_OP + +#pragma once + +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +void IncrementCompute(const IncrementParam ¶m) { + const framework::Tensor *input = param.InputX(); + framework::Tensor *out = param.Out(); + int step = param.Step(); + + out->mutable_data

<P>(); + const P *input_data = input->data<P>(); + P *out_data = out->data<P>
(); + *out_data = *input_data + step; +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/increment_kernel.h b/src/operators/kernel/increment_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..43a930c1b9be512714d253db0966ad171ba4068c --- /dev/null +++ b/src/operators/kernel/increment_kernel.h @@ -0,0 +1,36 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef INCREMENT_OP + +#pragma once + +#include "framework/operator.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +class IncrementKernel + : public framework::OpKernelBase> { + public: + void Compute(const IncrementParam ¶m); + bool Init(IncrementParam *param); +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/is_empty_kernel.h b/src/operators/kernel/is_empty_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..0a6806d087b2bc06e2de77b9a133ad599ba9c3e5 --- /dev/null +++ b/src/operators/kernel/is_empty_kernel.h @@ -0,0 +1,36 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef IS_EMPTY_OP + +#pragma once + +#include "framework/operator.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +class IsEmptyKernel + : public framework::OpKernelBase> { + public: + void Compute(const IsEmptyParam ¶m); + bool Init(IsEmptyParam *param); +}; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index 00d2a47d1be1edd2b27721c064faececb9672b29..869a61089621e8ed436944c26bc3cffc78159f46 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -415,6 +415,7 @@ void Gemm::PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb, } } +#if __ARM_NEON #if __aarch64__ void Gemm::PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb, float *buffer) { @@ -538,6 +539,7 @@ void Gemm::PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, } } #endif // __aarch64__ +#endif // __ARM_NEON // 分块矩阵乘法 void Gemm::InnerKernel(int mc, int nc, float alpha, const float *a, @@ -688,42 +690,7 @@ void Gemm::InnerKernelWithPRelu(int mc, int nc, const float *a, const float *b, #if __ARM_NEON #if __aarch64__ -void Gemm::AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) { - // init C - float32x4_t cv0 = vdupq_n_f32(0.0); - float32x4_t cv1 = vdupq_n_f32(0.0); - float32x4_t cv2 = vdupq_n_f32(0.0); - float32x4_t cv3 = vdupq_n_f32(0.0); - - float32x4_t av; - float32x4_t bv; - - float32x2_t av01; - float32x2_t av23; - - for (int p = 0; p < k; p += 1) { - av = vld1q_f32(a); - bv = vld1q_f32(b); - - av01 = vget_low_f32(av); - cv0 = vmlaq_lane_f32(cv0, bv, av01, 0); - cv1 = vmlaq_lane_f32(cv1, bv, av01, 1); - av23 = vget_high_f32(av); - cv2 = vmlaq_lane_f32(cv2, bv, av23, 0); - cv3 = vmlaq_lane_f32(cv3, bv, av23, 1); - - a += MR; - b += NR; - } - - vst1q_f32(c, cv0); - vst1q_f32(c + ldc, cv1); - vst1q_f32(c + 2 * ldc, cv2); - vst1q_f32(c + 3 * ldc, cv3); - // float32x4x4_t cv = {cv0, cv1, cv2, cv3}; -} - -void Gemm::AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) { +void Gemm::AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) { // init C float32x4_t cv0 = vdupq_n_f32(0.0); float32x4_t cv1 = vdupq_n_f32(0.0); @@ -733,6 +700,10 @@ void Gemm::AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) { float32x4_t cv5 = vdupq_n_f32(0.0); float32x4_t cv6 = vdupq_n_f32(0.0); float32x4_t cv7 = vdupq_n_f32(0.0); + float32x4_t cv8 = vdupq_n_f32(0.0); + float32x4_t cv9 = vdupq_n_f32(0.0); + float32x4_t cv10 = vdupq_n_f32(0.0); + float32x4_t cv11 = vdupq_n_f32(0.0); float32x4_t av; float32x4_t bv0; @@ -740,23 +711,31 @@ void Gemm::AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) { float32x2_t av01; float32x2_t av23; + float32x2_t av45; for (int p = 0; p < k; p += 1) { av = vld1q_f32(a); + av01 = vget_low_f32(av); + av23 = vget_high_f32(av); + av45 = vld1_f32(a + 4); bv0 = vld1q_f32(b); bv1 = vld1q_f32(b + 4); - av01 = vget_low_f32(av); cv0 = vmlaq_lane_f32(cv0, bv0, av01, 0); cv1 = vmlaq_lane_f32(cv1, bv1, av01, 0); cv2 = vmlaq_lane_f32(cv2, bv0, av01, 1); cv3 = vmlaq_lane_f32(cv3, bv1, av01, 1); - av23 = vget_high_f32(av); + cv4 = vmlaq_lane_f32(cv4, bv0, av23, 0); cv5 = vmlaq_lane_f32(cv5, bv1, av23, 0); cv6 = vmlaq_lane_f32(cv6, bv0, av23, 1); cv7 = vmlaq_lane_f32(cv7, bv1, av23, 1); + cv8 = vmlaq_lane_f32(cv8, bv0, av45, 0); + cv9 = vmlaq_lane_f32(cv9, bv1, av45, 0); + cv10 = vmlaq_lane_f32(cv10, bv0, av45, 1); + cv11 = vmlaq_lane_f32(cv11, 
bv1, av45, 1); + a += MR; b += NR; } @@ -769,310 +748,615 @@ void Gemm::AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) { vst1q_f32(c + 2 * ldc + 4, cv5); vst1q_f32(c + 3 * ldc, cv6); vst1q_f32(c + 3 * ldc + 4, cv7); + vst1q_f32(c + 4 * ldc, cv8); + vst1q_f32(c + 4 * ldc + 4, cv9); + vst1q_f32(c + 5 * ldc, cv10); + vst1q_f32(c + 5 * ldc + 4, cv11); } -// 分块矩阵乘法结果回写 -// C = A * B -void Gemm::WriteBasic(int mc, int nc, float *c, float *C, int ldc) { - int nc1 = nc / 4; - int _nc1 = nc % 4; +void Gemm::AddDot8x12(int k, const float *a, const float *b, float *c, + int ldc) { + const float *a_ptr, *b_ptr; + a_ptr = a; + b_ptr = b; + int kc1 = k; + int step = 4 * ldc; + asm volatile( + "dup v5.4s, wzr \n\t" + "dup v6.4s, wzr \n\t" + "dup v7.4s, wzr \n\t" + "dup v8.4s, wzr \n\t" + "dup v9.4s, wzr \n\t" + "dup v10.4s, wzr \n\t" + "dup v11.4s, wzr \n\t" + "dup v12.4s, wzr \n\t" + "dup v13.4s, wzr \n\t" + "dup v14.4s, wzr \n\t" + "dup v15.4s, wzr \n\t" + "dup v16.4s, wzr \n\t" - float *c_ptr, *C_ptr; - float32x4_t cv; - for (int i = 0; i < mc; ++i) { - c_ptr = c + i * NC; - C_ptr = C + i * ldc; - for (int j = 0; j < nc1; ++j) { - cv = vld1q_f32(c_ptr); - vst1q_f32(C_ptr, cv); - c_ptr += 4; - C_ptr += 4; - } - if (_nc1 != 0) { - cv = vld1q_f32(c_ptr); - if (_nc1 >= 1) { - vst1q_lane_f32(C_ptr, cv, 0); - C_ptr++; - } - if (_nc1 >= 2) { - vst1q_lane_f32(C_ptr, cv, 1); - C_ptr++; - } - if (_nc1 >= 3) { - vst1q_lane_f32(C_ptr, cv, 2); - } - } - } -} + "dup v17.4s, wzr \n\t" + "dup v18.4s, wzr \n\t" + "dup v19.4s, wzr \n\t" + "dup v20.4s, wzr \n\t" + "dup v21.4s, wzr \n\t" + "dup v22.4s, wzr \n\t" + "dup v23.4s, wzr \n\t" + "dup v24.4s, wzr \n\t" + "dup v25.4s, wzr \n\t" + "dup v26.4s, wzr \n\t" + "dup v27.4s, wzr \n\t" + "dup v28.4s, wzr \n\t" -// C = alpha * A * B + beta * C -void Gemm::WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc) {} + "subs %[kc1], %[kc1], #1 \n\t" + "blt 2f \n\t" + "1: \n\t" -// C = A * B + C -void Gemm::WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) { - int nc1 = nc / 4; - int _nc1 = nc % 4; + "prfm pldl1keep, [%[a_ptr], #32] \n\t" + "prfm pldl1keep, [%[b_ptr], #48] \n\t" - float *c_ptr, *C_ptr; - float32x4_t cv; - float32x4_t cv1; - for (int i = 0; i < mc; ++i) { - c_ptr = c + i * NC; - C_ptr = C + i * ldc; - for (int j = 0; j < nc1; ++j) { - cv = vld1q_f32(c_ptr); - cv1 = vld1q_f32(C_ptr); - cv = vaddq_f32(cv, cv1); - vst1q_f32(C_ptr, cv); - c_ptr += 4; - C_ptr += 4; - } - if (_nc1 != 0) { - cv = vld1q_f32(c_ptr); - cv1 = vld1q_f32(C_ptr); - cv = vaddq_f32(cv, cv1); - if (_nc1 >= 1) { - vst1q_lane_f32(C_ptr, cv, 0); - C_ptr++; - } - if (_nc1 >= 2) { - vst1q_lane_f32(C_ptr, cv, 1); - C_ptr++; - } - if (_nc1 >= 3) { - vst1q_lane_f32(C_ptr, cv, 2); - } - } - } -} -// C = A * B + bias -void Gemm::WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, - float *bias) { - int nc1 = nc / 4; - int _nc1 = nc % 4; + "ld1 {v0.4s, v1.4s}, [%[a_ptr]], #32 \n\t" + "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48 \n\t" - float *c_ptr, *C_ptr; - float32x4_t cv; - float32x4_t biasv; - for (int i = 0; i < mc; ++i) { - c_ptr = c + i * NC; - C_ptr = C + i * ldc; - biasv = vld1q_dup_f32(bias + i); - for (int j = 0; j < nc1; ++j) { - cv = vld1q_f32(c_ptr); - cv = vaddq_f32(cv, biasv); - vst1q_f32(C_ptr, cv); - c_ptr += 4; - C_ptr += 4; - } - if (_nc1 != 0) { - cv = vld1q_f32(c_ptr); - cv = vaddq_f32(cv, biasv); - if (_nc1 >= 1) { - vst1q_lane_f32(C_ptr, cv, 0); - C_ptr++; - } - if (_nc1 >= 2) { - vst1q_lane_f32(C_ptr, cv, 1); - C_ptr++; - } - if (_nc1 >= 
3) { - vst1q_lane_f32(C_ptr, cv, 2); - C_ptr++; - } - } - } + "fmla v5.4s, v2.4s, v0.s[0] \n\t" + "fmla v6.4s, v3.4s, v0.s[0] \n\t" + "fmla v7.4s, v4.4s, v0.s[0] \n\t" + "fmla v8.4s, v2.4s, v0.s[1] \n\t" + "fmla v9.4s, v3.4s, v0.s[1] \n\t" + "fmla v10.4s, v4.4s, v0.s[1] \n\t" + "fmla v11.4s, v2.4s, v0.s[2] \n\t" + "fmla v12.4s, v3.4s, v0.s[2] \n\t" + "fmla v13.4s, v4.4s, v0.s[2] \n\t" + "fmla v14.4s, v2.4s, v0.s[3] \n\t" + "fmla v15.4s, v3.4s, v0.s[3] \n\t" + "fmla v16.4s, v4.4s, v0.s[3] \n\t" + + "fmla v17.4s, v2.4s, v1.s[0] \n\t" + "fmla v18.4s, v3.4s, v1.s[0] \n\t" + "fmla v19.4s, v4.4s, v1.s[0] \n\t" + "fmla v20.4s, v2.4s, v1.s[1] \n\t" + "fmla v21.4s, v3.4s, v1.s[1] \n\t" + "fmla v22.4s, v4.4s, v1.s[1] \n\t" + "fmla v23.4s, v2.4s, v1.s[2] \n\t" + "fmla v24.4s, v3.4s, v1.s[2] \n\t" + "fmla v25.4s, v4.4s, v1.s[2] \n\t" + "fmla v26.4s, v2.4s, v1.s[3] \n\t" + "fmla v27.4s, v3.4s, v1.s[3] \n\t" + "fmla v28.4s, v4.4s, v1.s[3] \n\t" + + "subs %[kc1], %[kc1], #1 \n\t" + "bge 1b \n\t" + "2: \n\t" + + "st1 {v5.4s, v6.4s, v7.4s}, [%[c]], %[step] \n\t" + "st1 {v8.4s, v9.4s, v10.4s}, [%[c]], %[step] \n\t" + "st1 {v11.4s, v12.4s, v13.4s}, [%[c]], %[step] \n\t" + "st1 {v14.4s, v15.4s, v16.4s}, [%[c]], %[step] \n\t" + "st1 {v17.4s, v18.4s, v19.4s}, [%[c]], %[step] \n\t" + "st1 {v20.4s, v21.4s, v22.4s}, [%[c]], %[step] \n\t" + "st1 {v23.4s, v24.4s, v25.4s}, [%[c]], %[step] \n\t" + "st1 {v26.4s, v27.4s, v28.4s}, [%[c]], %[step] \n\t" + : + : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), + [step] "r"(step) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28"); } -// C = A * B + C, relu(C) -void Gemm::WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) { - int nc1 = nc / 4; - int _nc1 = nc % 4; +void Gemm::AddDot6x16(int k, const float *a, const float *b, float *c, + int ldc) { + const float *a_ptr, *b_ptr; + a_ptr = a; + b_ptr = b; + int kc1 = k; + int step = 4 * ldc; + int step1 = 4 * 6; + asm volatile( - float *c_ptr, *C_ptr; - float32x4_t cv; - float32x4_t cv1; - float32x4_t zero = vdupq_n_f32(0.0); - for (int i = 0; i < mc; ++i) { - c_ptr = c + i * NC; - C_ptr = C + i * ldc; - for (int j = 0; j < nc1; ++j) { - cv = vld1q_f32(c_ptr); - cv1 = vld1q_f32(C_ptr); - cv = vaddq_f32(cv, cv1); - cv = vmaxq_f32(cv, zero); - vst1q_f32(C_ptr, cv); - c_ptr += 4; - C_ptr += 4; - } - if (_nc1 != 0) { - cv = vld1q_f32(c_ptr); - cv1 = vld1q_f32(C_ptr); - cv = vaddq_f32(cv, cv1); - cv = vmaxq_f32(cv, zero); - if (_nc1 >= 1) { - vst1q_lane_f32(C_ptr, cv, 0); - C_ptr++; - } - if (_nc1 >= 2) { - vst1q_lane_f32(C_ptr, cv, 1); - C_ptr++; - } - if (_nc1 >= 3) { - vst1q_lane_f32(C_ptr, cv, 2); - } - } - } + "dup v6.4s, wzr \n\t" + "dup v7.4s, wzr \n\t" + "dup v8.4s, wzr \n\t" + "dup v9.4s, wzr \n\t" + "dup v10.4s, wzr \n\t" + "dup v11.4s, wzr \n\t" + "dup v12.4s, wzr \n\t" + "dup v13.4s, wzr \n\t" + + "dup v14.4s, wzr \n\t" + "dup v15.4s, wzr \n\t" + "dup v16.4s, wzr \n\t" + "dup v17.4s, wzr \n\t" + "dup v18.4s, wzr \n\t" + "dup v19.4s, wzr \n\t" + "dup v20.4s, wzr \n\t" + "dup v21.4s, wzr \n\t" + + "dup v22.4s, wzr \n\t" + "dup v23.4s, wzr \n\t" + "dup v24.4s, wzr \n\t" + "dup v25.4s, wzr \n\t" + "dup v26.4s, wzr \n\t" + "dup v27.4s, wzr \n\t" + "dup v28.4s, wzr \n\t" + "dup v29.4s, wzr \n\t" + + "subs %[kc1], %[kc1], #1 \n\t" + "blt 2f \n\t" + "1: \n\t" + + "prfm pldl1keep, [%[a_ptr], #24] \n\t" + "prfm pldl1keep, [%[b_ptr], #64] \n\t" 
+ + "ld1 {v0.4s, v1.4s}, [%[a_ptr]], %[step1] \n\t" + "ld1 {v2.4s, v3.4s, v4.4s, v5.4s}, [%[b_ptr]], #64 \n\t" + + "fmla v6.4s, v2.4s, v0.s[0] \n\t" + "fmla v7.4s, v3.4s, v0.s[0] \n\t" + "fmla v8.4s, v4.4s, v0.s[0] \n\t" + "fmla v9.4s, v5.4s, v0.s[0] \n\t" + + "fmla v10.4s, v2.4s, v0.s[1] \n\t" + "fmla v11.4s, v3.4s, v0.s[1] \n\t" + "fmla v12.4s, v4.4s, v0.s[1] \n\t" + "fmla v13.4s, v5.4s, v0.s[1] \n\t" + + "fmla v14.4s, v2.4s, v0.s[2] \n\t" + "fmla v15.4s, v3.4s, v0.s[2] \n\t" + "fmla v16.4s, v4.4s, v0.s[2] \n\t" + "fmla v17.4s, v5.4s, v0.s[2] \n\t" + + "fmla v18.4s, v2.4s, v0.s[3] \n\t" + "fmla v19.4s, v3.4s, v0.s[3] \n\t" + "fmla v20.4s, v4.4s, v0.s[3] \n\t" + "fmla v21.4s, v5.4s, v0.s[3] \n\t" + + "fmla v22.4s, v2.4s, v1.s[0] \n\t" + "fmla v23.4s, v3.4s, v1.s[0] \n\t" + "fmla v24.4s, v4.4s, v1.s[0] \n\t" + "fmla v25.4s, v5.4s, v1.s[0] \n\t" + + "fmla v26.4s, v2.4s, v1.s[1] \n\t" + "fmla v27.4s, v3.4s, v1.s[1] \n\t" + "fmla v28.4s, v4.4s, v1.s[1] \n\t" + "fmla v29.4s, v5.4s, v1.s[1] \n\t" + + "subs %[kc1], %[kc1], #1 \n\t" + "bge 1b \n\t" + "2: \n\t" + + "st1 {v6.4s, v7.4s, v8.4s, v9.4s}, [%[c]], %[step] \n\t" + "st1 {v10.4s, v11.4s, v12.4s, v13.4s}, [%[c]], %[step] \n\t" + "st1 {v14.4s, v15.4s, v16.4s, v17.4s}, [%[c]], %[step] \n\t" + "st1 {v18.4s, v19.4s, v20.4s, v21.4s}, [%[c]], %[step] \n\t" + "st1 {v22.4s, v23.4s, v24.4s, v25.4s}, [%[c]], %[step] \n\t" + "st1 {v26.4s, v27.4s, v28.4s, v29.4s}, [%[c]], %[step] \n\t" + : + : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), + [step] "r"(step), [step1] "r"(step1) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29"); +} + +#else + +void Gemm::AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) { + const float *a_ptr, *b_ptr; + a_ptr = a; + b_ptr = b; + int kc1 = k / 4; + int kc2 = k % 4; + int step = 4 * ldc; + asm volatile( + "pld [%[a_ptr]] \n\t" + "pld [%[b_ptr]] \n\t" + "vmov.f32 q10, #0.0 \n\t" + "vmov.f32 q11, #0.0 \n\t" + "vmov.f32 q12, #0.0 \n\t" + "vmov.f32 q13, #0.0 \n\t" + + "subs %[kc1], %[kc1], #1 \n\t" + "blt end_kc1_%= \n\t" + "loop_kc1_%=: \n\t" + "pld [%[a_ptr], #64] \n\t" + "pld [%[b_ptr], #64] \n\t" + "vld1.32 {q0, q1}, [%[a_ptr]]! \n\t" + "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" + "vmla.f32 q10, q2, d0[0] \n\t" + "vmla.f32 q11, q2, d0[1] \n\t" + "vmla.f32 q12, q2, d1[0] \n\t" + "vmla.f32 q13, q2, d1[1] \n\t" + "vmla.f32 q10, q3, d2[0] \n\t" + "vmla.f32 q11, q3, d2[1] \n\t" + "vmla.f32 q12, q3, d3[0] \n\t" + "vmla.f32 q13, q3, d3[1] \n\t" + "vld1.32 {q4, q5}, [%[a_ptr]]! \n\t" + "vld1.32 {q6, q7}, [%[b_ptr]]! \n\t" + "vmla.f32 q10, q6, d8[0] \n\t" + "vmla.f32 q11, q6, d8[1] \n\t" + "vmla.f32 q12, q6, d9[0] \n\t" + "vmla.f32 q13, q6, d9[1] \n\t" + "vmla.f32 q10, q7, d10[0] \n\t" + "vmla.f32 q11, q7, d10[1] \n\t" + "vmla.f32 q12, q7, d11[0] \n\t" + "vmla.f32 q13, q7, d11[1] \n\t" + "subs %[kc1], %[kc1], #1 \n\t" + "bge loop_kc1_%= \n\t" + "end_kc1_%=: \n\t" + + "subs %[kc2], %[kc2], #1 \n\t" + "blt end_kc2_%= \n\t" + "loop_kc2_%=: \n\t" + "vld1.32 {q0}, [%[a_ptr]]! \n\t" + "vld1.32 {q1}, [%[b_ptr]]! 
\n\t" + "vmla.f32 q10, q1, d0[0] \n\t" + "vmla.f32 q11, q1, d0[1] \n\t" + "vmla.f32 q12, q1, d1[0] \n\t" + "vmla.f32 q13, q1, d1[1] \n\t" + "subs %[kc2], %[kc2], #1 \n\t" + "bge loop_kc2_%= \n\t" + "end_kc2_%=: \n\t" + + "mov r5, %[c] \n\t" + "mov r6, %[step] \n\t" + "vst1.32 {q10}, [r5], r6 \n\t" + "vst1.32 {q11}, [r5], r6 \n\t" + "vst1.32 {q12}, [r5], r6 \n\t" + "vst1.32 {q13}, [r5] \n\t" + : + : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), + [kc2] "r"(kc2), [step] "r"(step) + : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q10", "q11", "q12", "q13"); +} + +void Gemm::AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) { + const float *a_ptr, *b_ptr; + a_ptr = a; + b_ptr = b; + int kc1 = k / 4; + int kc2 = k % 4; + int step = 4 * ldc; + asm volatile( + "pld [%[a_ptr]] \n\t" + "pld [%[b_ptr]] \n\t" + + "vmov.f32 q8, #0.0 \n\t" + "vmov.f32 q9, #0.0 \n\t" + "vmov.f32 q10, #0.0 \n\t" + "vmov.f32 q11, #0.0 \n\t" + "vmov.f32 q12, #0.0 \n\t" + "vmov.f32 q13, #0.0 \n\t" + "vmov.f32 q14, #0.0 \n\t" + "vmov.f32 q15, #0.0 \n\t" + + "subs %[kc1], %[kc1], #1 \n\t" + "blt end_kc1_%= \n\t" + "loop_kc1_%=: \n\t" + + "pld [%[a_ptr], #64] \n\t" + "pld [%[b_ptr], #64] \n\t" + + "vld1.32 {q0, q1}, [%[a_ptr]]! \n\t" + "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" + "vld1.32 {q4, q5}, [%[b_ptr]]! \n\t" + + "vmla.f32 q8, q2, d0[0] \n\t" + "vmla.f32 q9, q3, d0[0] \n\t" + "vmla.f32 q10, q2, d0[1] \n\t" + "vmla.f32 q11, q3, d0[1] \n\t" + "vmla.f32 q12, q2, d1[0] \n\t" + "vmla.f32 q13, q3, d1[0] \n\t" + "vmla.f32 q14, q2, d1[1] \n\t" + "vmla.f32 q15, q3, d1[1] \n\t" + + "vmla.f32 q8, q4, d2[0] \n\t" + "vmla.f32 q9, q5, d2[0] \n\t" + "vmla.f32 q10, q4, d2[1] \n\t" + "vmla.f32 q11, q5, d2[1] \n\t" + "vmla.f32 q12, q4, d3[0] \n\t" + "vmla.f32 q13, q5, d3[0] \n\t" + "vmla.f32 q14, q4, d3[1] \n\t" + "vmla.f32 q15, q5, d3[1] \n\t" + + "pld [%[b_ptr], #64] \n\t" + + "vld1.32 {q0, q1}, [%[a_ptr]]! \n\t" + "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" + "vld1.32 {q4, q5}, [%[b_ptr]]! \n\t" + + "vmla.f32 q8, q2, d0[0] \n\t" + "vmla.f32 q9, q3, d0[0] \n\t" + "vmla.f32 q10, q2, d0[1] \n\t" + "vmla.f32 q11, q3, d0[1] \n\t" + "vmla.f32 q12, q2, d1[0] \n\t" + "vmla.f32 q13, q3, d1[0] \n\t" + "vmla.f32 q14, q2, d1[1] \n\t" + "vmla.f32 q15, q3, d1[1] \n\t" + + "vmla.f32 q8, q4, d2[0] \n\t" + "vmla.f32 q9, q5, d2[0] \n\t" + "vmla.f32 q10, q4, d2[1] \n\t" + "vmla.f32 q11, q5, d2[1] \n\t" + "vmla.f32 q12, q4, d3[0] \n\t" + "vmla.f32 q13, q5, d3[0] \n\t" + "vmla.f32 q14, q4, d3[1] \n\t" + "vmla.f32 q15, q5, d3[1] \n\t" + + "subs %[kc1], %[kc1], #1 \n\t" + "bge loop_kc1_%= \n\t" + "end_kc1_%=: \n\t" + + "subs %[kc2], %[kc2], #1 \n\t" + "blt end_kc2_%= \n\t" + "loop_kc2_%=: \n\t" + "vld1.32 {q0}, [%[a_ptr]]! \n\t" + "vld1.32 {q2, q3}, [%[b_ptr]]! 
\n\t" + "vmla.f32 q8, q2, d0[0] \n\t" + "vmla.f32 q9, q3, d0[0] \n\t" + "vmla.f32 q10, q2, d0[1] \n\t" + "vmla.f32 q11, q3, d0[1] \n\t" + "vmla.f32 q12, q2, d1[0] \n\t" + "vmla.f32 q13, q3, d1[0] \n\t" + "vmla.f32 q14, q2, d1[1] \n\t" + "vmla.f32 q15, q3, d1[1] \n\t" + "subs %[kc2], %[kc2], #1 \n\t" + "bge loop_kc2_%= \n\t" + "end_kc2_%=: \n\t" + + "mov r5, %[c] \n\t" + "mov r6, %[step] \n\t" + "vst1.32 {q8, q9}, [r5], r6 \n\t" + "vst1.32 {q10, q11}, [r5], r6 \n\t" + "vst1.32 {q12, q13}, [r5], r6 \n\t" + "vst1.32 {q14, q15}, [r5] \n\t" + : + : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), + [kc2] "r"(kc2), [step] "r"(step) + : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q8", "q9", + "q10", "q11", "q12", "q13", "q14", "q15"); } -// C = A * B + bias, relu(C) -void Gemm::WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc, - float *bias) { - int nc1 = nc / 4; - int _nc1 = nc % 4; +void Gemm::AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) { + const float *a_ptr, *b_ptr; + a_ptr = a; + b_ptr = b; + int kc1 = k / 8; + int kc2 = k % 8; + int step = sizeof(float) * ldc; + asm volatile( + "pld [%[a_ptr]] \n\t" + "pld [%[a_ptr], #64] \n\t" + "pld [%[b_ptr]] \n\t" + "pld [%[b_ptr], #64] \n\t" + + "vmov.f32 q4, #0.0 \n\t" + "vmov.f32 q5, #0.0 \n\t" + "vmov.f32 q6, #0.0 \n\t" + "vmov.f32 q7, #0.0 \n\t" + "vmov.f32 q8, #0.0 \n\t" + "vmov.f32 q9, #0.0 \n\t" + "vmov.f32 q10, #0.0 \n\t" + "vmov.f32 q11, #0.0 \n\t" + "vmov.f32 q12, #0.0 \n\t" + "vmov.f32 q13, #0.0 \n\t" + "vmov.f32 q14, #0.0 \n\t" + "vmov.f32 q15, #0.0 \n\t" + + "subs %[kc1], %[kc1], #1 \n\t" + "blt 2f \n\t" + "1: \n\t" + + "pld [%[a_ptr], #128] \n\t" + "pld [%[b_ptr], #128] \n\t" + + "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" + "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" + "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "pld [%[a_ptr], #128] \n\t" + "pld [%[b_ptr], #128] \n\t" + + "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" + "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" + "vld1.32 {q2, q3}, [%[b_ptr]]! 
\n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "pld [%[a_ptr], #128] \n\t" + "pld [%[b_ptr], #128] \n\t" - float *c_ptr, *C_ptr; - float32x4_t cv; - float32x4_t biasv; - float32x4_t zero = vdupq_n_f32(0.0); - for (int i = 0; i < mc; ++i) { - c_ptr = c + i * NC; - C_ptr = C + i * ldc; - biasv = vld1q_dup_f32(bias + i); - for (int j = 0; j < nc1; ++j) { - cv = vld1q_f32(c_ptr); - cv = vaddq_f32(cv, biasv); - cv = vmaxq_f32(cv, zero); - vst1q_f32(C_ptr, cv); - c_ptr += 4; - C_ptr += 4; - } - if (_nc1 != 0) { - cv = vld1q_f32(c_ptr); - cv = vaddq_f32(cv, biasv); - cv = vmaxq_f32(cv, zero); - if (_nc1 >= 1) { - vst1q_lane_f32(C_ptr, cv, 0); - C_ptr++; - } - if (_nc1 >= 2) { - vst1q_lane_f32(C_ptr, cv, 1); - C_ptr++; - } - if (_nc1 >= 3) { - vst1q_lane_f32(C_ptr, cv, 2); - C_ptr++; - } - } - } -} + "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" + "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" -// C = A * B + C,prelu(C) -void Gemm::WriteWithAddPRelu(int mc, int nc, float *c, float *C, int ldc, - float *p, std::string mode, float *bias, - float *bias1) { - int nc1 = nc / 4; - int _nc1 = nc % 4; + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" - float *c_ptr, *C_ptr; - float32x4_t cv; - float32x4_t cv1; - float32x4_t biasv; - float32x4_t biasv1; - float32x4_t zero = vdupq_n_f32(0.0); - float32x4_t pv; - float *ptr = p; - for (int i = 0; i < mc; ++i) { - c_ptr = c + i * NC; - C_ptr = C + i * ldc; - biasv = vld1q_dup_f32(bias + i); - if (bias1 == nullptr) { - biasv1 = zero; - } else { - biasv1 = vld1q_dup_f32(bias1 + i); - } + "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" + "vld1.32 {q2, q3}, [%[b_ptr]]! 
\n\t" - for (int j = 0; j < nc1; ++j) { - cv = vld1q_f32(c_ptr); - cv = vaddq_f32(cv, biasv); - cv = vaddq_f32(cv, biasv1); - cv = vmaxq_f32(cv, zero); - cv1 = vminq_f32(cv, zero); - if (mode == "channel") { - cv1 = vmulq_n_f32(cv1, ptr[i]); - } else if (mode == "element") { - pv = vld1q_f32(ptr); - cv1 = vmulq_f32(cv1, pv); - ptr = ptr + 4; - } else { - cv1 = vmulq_n_f32(cv1, ptr[0]); - } - cv = vaddq_f32(cv, cv1); - vst1q_f32(C_ptr, cv); - c_ptr += 4; - C_ptr += 4; - } - if (_nc1 != 0) { - cv = vld1q_f32(c_ptr); - cv = vaddq_f32(cv, biasv); - cv = vaddq_f32(cv, biasv1); - cv = vmaxq_f32(cv, zero); - cv1 = vminq_f32(cv, zero); - if (mode == "channel") { - cv1 = vmulq_n_f32(cv1, ptr[i]); - } else if (mode == "element") { - pv = vld1q_f32(ptr); - cv1 = vmulq_f32(cv1, pv); - ptr = ptr + 4; - } else { - cv1 = vmulq_n_f32(cv1, ptr[0]); - } - cv = vaddq_f32(cv, cv1); - if (_nc1 >= 1) { - vst1q_lane_f32(C_ptr, cv, 0); - C_ptr++; - } - if (_nc1 >= 2) { - vst1q_lane_f32(C_ptr, cv, 1); - C_ptr++; - } - if (_nc1 >= 3) { - vst1q_lane_f32(C_ptr, cv, 2); - C_ptr++; - } - } - } + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "pld [%[a_ptr], #128] \n\t" + "pld [%[b_ptr], #128] \n\t" + + "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" + "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" + "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "subs %[kc1], %[kc1], #1 \n\t" + "bge 1b \n\t" + "2: \n\t" + + "subs %[kc2], %[kc2], #1 \n\t" + "blt 4f \n\t" + "3: \n\t" + + "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" + "vld1.32 {q2, q3}, [%[b_ptr]]! 
\n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "subs %[kc2], %[kc2], #1 \n\t" + "bge 3b \n\t" + "4: \n\t" + + "mov r5, %[c] \n\t" + "mov r6, %[step] \n\t" + "vst1.32 {q4, q5}, [r5], r6 \n\t" + "vst1.32 {q6, q7}, [r5], r6 \n\t" + "vst1.32 {q8, q9}, [r5], r6 \n\t" + "vst1.32 {q10, q11}, [r5], r6 \n\t" + "vst1.32 {q12, q13}, [r5], r6 \n\t" + "vst1.32 {q14, q15}, [r5] \n\t" + + : + : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), + [kc2] "r"(kc2), [step] "r"(step) + : "cc", "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); } -// C = A * B, batchnorm(C) -void Gemm::WriteWithBn(int mc, int nc, float *c, float *C, int ldc, - float *new_scale, float *new_bias) { +#endif // __aarch64__ +#endif // __ARM_NEON + +#if __ARM_NEON +#if __aarch64__ + +// 分块矩阵乘法结果回写 +// C = A * B +void Gemm::WriteBasic(int mc, int nc, float *c, float *C, int ldc) { int nc1 = nc / 4; int _nc1 = nc % 4; float *c_ptr, *C_ptr; float32x4_t cv; - float32x4_t cv1; - float32x4_t bias; - float32x2_t scale; for (int i = 0; i < mc; ++i) { c_ptr = c + i * NC; C_ptr = C + i * ldc; - bias = vld1q_dup_f32(new_bias); - scale = vld1_dup_f32(new_scale); - new_bias++; - new_scale++; - float scale0 = vget_lane_f32(scale, 0); for (int j = 0; j < nc1; ++j) { cv = vld1q_f32(c_ptr); - cv = vmlaq_n_f32(bias, cv, scale0); vst1q_f32(C_ptr, cv); c_ptr += 4; C_ptr += 4; } if (_nc1 != 0) { cv = vld1q_f32(c_ptr); - cv = vmlaq_n_f32(bias, cv, scale0); if (_nc1 >= 1) { vst1q_lane_f32(C_ptr, cv, 0); C_ptr++; @@ -1083,43 +1367,37 @@ void Gemm::WriteWithBn(int mc, int nc, float *c, float *C, int ldc, } if (_nc1 >= 3) { vst1q_lane_f32(C_ptr, cv, 2); - C_ptr++; } } } } -// C = A * B, batchnorm(C), relu(C) -void Gemm::WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, - float *new_scale, float *new_bias) { +// C = alpha * A * B + beta * C +void Gemm::WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc) {} + +// C = A * B + C +void Gemm::WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) { int nc1 = nc / 4; int _nc1 = nc % 4; float *c_ptr, *C_ptr; float32x4_t cv; - float32x4_t bias; - float32x2_t scale; - float32x4_t zero = vdupq_n_f32(0.0); + float32x4_t cv1; for (int i = 0; i < mc; ++i) { c_ptr = c + i * NC; C_ptr = C + i * ldc; - bias = vld1q_dup_f32(new_bias); - scale = vld1_dup_f32(new_scale); - new_bias++; - new_scale++; - float scale0 = vget_lane_f32(scale, 0); for (int j = 0; j < nc1; ++j) { cv = vld1q_f32(c_ptr); - cv = vmlaq_n_f32(bias, cv, scale0); - cv = vmaxq_f32(cv, zero); + cv1 = vld1q_f32(C_ptr); + cv = vaddq_f32(cv, cv1); vst1q_f32(C_ptr, cv); c_ptr += 4; C_ptr += 4; } if (_nc1 != 0) { cv = vld1q_f32(c_ptr); - cv = vmlaq_n_f32(bias, cv, scale0); - cv = vmaxq_f32(cv, zero); + cv1 = vld1q_f32(C_ptr); + cv = vaddq_f32(cv, cv1); if (_nc1 >= 1) { vst1q_lane_f32(C_ptr, cv, 0); C_ptr++; @@ -1134,45 +1412,29 @@ void Gemm::WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, } } } - -// C = A * B, batchnorm(C),C = C + bias; relu(C) -void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, - float *new_scale, float *new_bias, float *bias) { +// C = A * B + 
bias +void Gemm::WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, + float *bias) { int nc1 = nc / 4; int _nc1 = nc % 4; - float *c_ptr, *C_ptr, *bias_ptr; + float *c_ptr, *C_ptr; float32x4_t cv; - float32x4_t nbias; - float32x2_t scale; float32x4_t biasv; - float32x4_t zero = vdupq_n_f32(0.0); for (int i = 0; i < mc; ++i) { c_ptr = c + i * NC; C_ptr = C + i * ldc; - bias_ptr = bias + i * ldc; - nbias = vld1q_dup_f32(new_bias); - scale = vld1_dup_f32(new_scale); - new_bias++; - new_scale++; - float scale0 = vget_lane_f32(scale, 0); + biasv = vld1q_dup_f32(bias + i); for (int j = 0; j < nc1; ++j) { cv = vld1q_f32(c_ptr); - biasv = vld1q_f32(bias_ptr); - cv = vmlaq_n_f32(nbias, cv, scale0); cv = vaddq_f32(cv, biasv); - cv = vmaxq_f32(cv, zero); vst1q_f32(C_ptr, cv); c_ptr += 4; C_ptr += 4; - bias_ptr += 4; } if (_nc1 != 0) { cv = vld1q_f32(c_ptr); - biasv = vld1q_f32(bias_ptr); - cv = vmlaq_n_f32(nbias, cv, scale0); cv = vaddq_f32(cv, biasv); - cv = vmaxq_f32(cv, zero); if (_nc1 >= 1) { vst1q_lane_f32(C_ptr, cv, 0); C_ptr++; @@ -1182,314 +1444,322 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, C_ptr++; } if (_nc1 >= 3) { - vst1q_lane_f32(C_ptr, cv, 2); - } - } - } -} - -void Gemm::VectorKernel(int m, int n, int k, float alpha, const float *A, - int lda, const float *B, int ldb, float beta, float *C, - int ldc, bool relu) {} - -#else - -void Gemm::AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) { - const float *a_ptr, *b_ptr; - a_ptr = a; - b_ptr = b; - int kc1 = k / 4; - int kc2 = k % 4; - int step = 4 * ldc; - asm volatile( - "pld [%[a_ptr]] \n\t" - "pld [%[b_ptr]] \n\t" - "vmov.f32 q10, #0.0 \n\t" - "vmov.f32 q11, #0.0 \n\t" - "vmov.f32 q12, #0.0 \n\t" - "vmov.f32 q13, #0.0 \n\t" - - "subs %[kc1], %[kc1], #1 \n\t" - "blt end_kc1_%= \n\t" - "loop_kc1_%=: \n\t" - "pld [%[a_ptr], #64] \n\t" - "pld [%[b_ptr], #64] \n\t" - "vld1.32 {q0, q1}, [%[a_ptr]]! \n\t" - "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" - "vmla.f32 q10, q2, d0[0] \n\t" - "vmla.f32 q11, q2, d0[1] \n\t" - "vmla.f32 q12, q2, d1[0] \n\t" - "vmla.f32 q13, q2, d1[1] \n\t" - "vmla.f32 q10, q3, d2[0] \n\t" - "vmla.f32 q11, q3, d2[1] \n\t" - "vmla.f32 q12, q3, d3[0] \n\t" - "vmla.f32 q13, q3, d3[1] \n\t" - "vld1.32 {q4, q5}, [%[a_ptr]]! \n\t" - "vld1.32 {q6, q7}, [%[b_ptr]]! \n\t" - "vmla.f32 q10, q6, d8[0] \n\t" - "vmla.f32 q11, q6, d8[1] \n\t" - "vmla.f32 q12, q6, d9[0] \n\t" - "vmla.f32 q13, q6, d9[1] \n\t" - "vmla.f32 q10, q7, d10[0] \n\t" - "vmla.f32 q11, q7, d10[1] \n\t" - "vmla.f32 q12, q7, d11[0] \n\t" - "vmla.f32 q13, q7, d11[1] \n\t" - "subs %[kc1], %[kc1], #1 \n\t" - "bge loop_kc1_%= \n\t" - "end_kc1_%=: \n\t" - - "subs %[kc2], %[kc2], #1 \n\t" - "blt end_kc2_%= \n\t" - "loop_kc2_%=: \n\t" - "vld1.32 {q0}, [%[a_ptr]]! \n\t" - "vld1.32 {q1}, [%[b_ptr]]! 
\n\t" - "vmla.f32 q10, q1, d0[0] \n\t" - "vmla.f32 q11, q1, d0[1] \n\t" - "vmla.f32 q12, q1, d1[0] \n\t" - "vmla.f32 q13, q1, d1[1] \n\t" - "subs %[kc2], %[kc2], #1 \n\t" - "bge loop_kc2_%= \n\t" - "end_kc2_%=: \n\t" - - "mov r5, %[c] \n\t" - "mov r6, %[step] \n\t" - "vst1.32 {q10}, [r5], r6 \n\t" - "vst1.32 {q11}, [r5], r6 \n\t" - "vst1.32 {q12}, [r5], r6 \n\t" - "vst1.32 {q13}, [r5] \n\t" - : - : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), - [kc2] "r"(kc2), [step] "r"(step) - : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q10", "q11", "q12", "q13"); -} - -void Gemm::VectorKernel(int m, int n, int k, float alpha, const float *A, - int lda, const float *B, int ldb, float beta, float *C, - int ldc, bool relu) { - float *bufferC = static_cast(memory::Alloc(sizeof(float) * n)); - - const float *a0, *b0, *b1, *b2, *b3; - float *c0, *C0; - - int volatile kc1 = k / 4; - int volatile kc2 = k % 4; - int volatile nc1 = n / 16; - int _nc1 = n % 16; - int volatile nc2 = _nc1 / 4; - int volatile nc3 = _nc1 % 4; - for (int i = 0; i < kc1; i++) { - a0 = A + i * 4; - b0 = B + i * 4 * ldb; - b1 = b0 + ldb; - b2 = b1 + ldb; - b3 = b2 + ldb; - c0 = bufferC; - asm volatile( - "pld [%[a0], #16] \n\t" - "vld1.32 {q0}, [%[a0]] \n\t" - - "subs %[nc1], %[nc1], #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" - - "cmp %[i], #0 \n\t" - "beq i_eq0_%= \n\t" - "bne i_ne0_%= \n\t" - - "i_eq0_%=: \n\t" - "vmov.f32 q10, #0.0 \n\t" - "vmov.f32 q11, #0.0 \n\t" - "vmov.f32 q12, #0.0 \n\t" - "vmov.f32 q13, #0.0 \n\t" - "b gemm_nc1_%= \n\t" - - "i_ne0_%=: \n\t" - "pld [%[c0], #64] \n\t" - "vld1.32 {q10, q11}, [%[c0]]! \n\t" - "vld1.32 {q12, q13}, [%[c0]] \n\t" - "sub %[c0], %[c0], #32 \n\t" - - "gemm_nc1_%=: \n\t" - "pld [%[b0], #64] \n\t" - "vld1.32 {q2, q3}, [%[b0]]! \n\t" - "vld1.32 {q4, q5}, [%[b0]]! \n\t" - "vmla.f32 q10, q2, d0[0] \n\t" - "vmla.f32 q11, q3, d0[0] \n\t" - "vmla.f32 q12, q4, d0[0] \n\t" - "vmla.f32 q13, q5, d0[0] \n\t" - - "pld [%[b1], #64] \n\t" - "vld1.32 {q2, q3}, [%[b1]]! \n\t" - "vld1.32 {q4, q5}, [%[b1]]! \n\t" - "vmla.f32 q10, q2, d0[1] \n\t" - "vmla.f32 q11, q3, d0[1] \n\t" - "vmla.f32 q12, q4, d0[1] \n\t" - "vmla.f32 q13, q5, d0[1] \n\t" - - "pld [%[b2], #64] \n\t" - "vld1.32 {q2, q3}, [%[b2]]! \n\t" - "vld1.32 {q4, q5}, [%[b2]]! \n\t" - "vmla.f32 q10, q2, d1[0] \n\t" - "vmla.f32 q11, q3, d1[0] \n\t" - "vmla.f32 q12, q4, d1[0] \n\t" - "vmla.f32 q13, q5, d1[0] \n\t" - - "pld [%[b3], #64] \n\t" - "vld1.32 {q2, q3}, [%[b3]]! \n\t" - "vld1.32 {q4, q5}, [%[b3]]! \n\t" - "vmla.f32 q10, q2, d1[1] \n\t" - "vmla.f32 q11, q3, d1[1] \n\t" - "vmla.f32 q12, q4, d1[1] \n\t" - "vmla.f32 q13, q5, d1[1] \n\t" - - "vst1.32 {q10, q11}, [%[c0]]! \n\t" - "vst1.32 {q12, q13}, [%[c0]]! \n\t" - - "subs %[nc1], %[nc1], #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" - - "subs %[nc2], %[nc2], #1 \n\t" - "blt end_nc2_%= \n\t" - "loop_nc2_%=: \n\t" - - "cmp %[i], #0 \n\t" - "beq ii_eq0_%= \n\t" - "bne ii_ne0_%= \n\t" - - "ii_eq0_%=: \n\t" - "vmov.f32 q10, #0.0 \n\t" - "b gemm_nc2_%= \n\t" - - "ii_ne0_%=: \n\t" - "pld [%[c0], #16] \n\t" - "vld1.32 {q10}, [%[c0]] \n\t" - - "gemm_nc2_%=: \n\t" - "pld [%[b0], #16] \n\t" - "vld1.32 {q2}, [%[b0]]! \n\t" - "vmla.f32 q10, q2, d0[0] \n\t" - - "pld [%[b1], #16] \n\t" - "vld1.32 {q3}, [%[b1]]! \n\t" - "vmla.f32 q10, q3, d0[1] \n\t" - - "pld [%[b2], #16] \n\t" - "vld1.32 {q4}, [%[b2]]! \n\t" - "vmla.f32 q10, q4, d1[0] \n\t" - - "pld [%[b3], #16] \n\t" - "vld1.32 {q5}, [%[b3]]! 
\n\t" - "vmla.f32 q10, q5, d1[1] \n\t" - - "vst1.32 {q10}, [%[c0]]! \n\t" - - "subs %[nc2], %[nc2], #1 \n\t" - "bge loop_nc2_%= \n\t" - "end_nc2_%=: \n\t" - - : [b0] "+r"(b0), [b1] "+r"(b1), [b2] "+r"(b2), [b3] "+r"(b3), - [c0] "+r"(c0) - : [a0] "r"(a0), [i] "r"(i), [nc1] "r"(nc1), [nc2] "r"(nc2) - : "memory", "q0", "q2", "q3", "q4", "q5", "q10", "q11", "q12", "q13"); - - for (int j = 0; j < nc3; j++) { - if (i == 0) { - *c0 = (*a0) * (*b0++); - } else { - *c0 += (*a0) * (*b0++); + vst1q_lane_f32(C_ptr, cv, 2); + C_ptr++; } - *c0 += (*(a0 + 1)) * (*b1++); - *c0 += (*(a0 + 2)) * (*b2++); - *c0 += (*(a0 + 3)) * (*b3++); - c0++; } } +} - for (int i = 0; i < kc2; ++i) { - a0 = A + 4 * kc1 + i; - b0 = B + (4 * kc1 + i) * ldb; - c0 = bufferC; - asm volatile( - "pld [%[a0], #16] \n\t" - "vld1.32 {d0}, [%[a0]] \n\t" - - "subs %[nc1], %[nc1], #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" - - "pld [%[c0], #64] \n\t" - "vld1.32 {q10, q11}, [%[c0]]! \n\t" - "vld1.32 {q12, q13}, [%[c0]] \n\t" - "sub %[c0], %[c0], #32 \n\t" +// C = A * B + C, relu(C) +void Gemm::WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) { + int nc1 = nc / 4; + int _nc1 = nc % 4; - "gemm_nc1_%=: \n\t" - "pld [%[b0], #64] \n\t" - "vld1.32 {q2, q3}, [%[b0]]! \n\t" - "vld1.32 {q4, q5}, [%[b0]]! \n\t" - "vmla.f32 q10, q2, d0[0] \n\t" - "vmla.f32 q11, q3, d0[0] \n\t" - "vmla.f32 q12, q4, d0[0] \n\t" - "vmla.f32 q13, q5, d0[0] \n\t" + float *c_ptr, *C_ptr; + float32x4_t cv; + float32x4_t cv1; + float32x4_t zero = vdupq_n_f32(0.0); + for (int i = 0; i < mc; ++i) { + c_ptr = c + i * NC; + C_ptr = C + i * ldc; + for (int j = 0; j < nc1; ++j) { + cv = vld1q_f32(c_ptr); + cv1 = vld1q_f32(C_ptr); + cv = vaddq_f32(cv, cv1); + cv = vmaxq_f32(cv, zero); + vst1q_f32(C_ptr, cv); + c_ptr += 4; + C_ptr += 4; + } + if (_nc1 != 0) { + cv = vld1q_f32(c_ptr); + cv1 = vld1q_f32(C_ptr); + cv = vaddq_f32(cv, cv1); + cv = vmaxq_f32(cv, zero); + if (_nc1 >= 1) { + vst1q_lane_f32(C_ptr, cv, 0); + C_ptr++; + } + if (_nc1 >= 2) { + vst1q_lane_f32(C_ptr, cv, 1); + C_ptr++; + } + if (_nc1 >= 3) { + vst1q_lane_f32(C_ptr, cv, 2); + } + } + } +} - "vst1.32 {q10, q11}, [%[c0]]! \n\t" - "vst1.32 {q12, q13}, [%[c0]]! 
\n\t" +// C = A * B + bias, relu(C) +void Gemm::WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc, + float *bias) { + int nc1 = nc / 4; + int _nc1 = nc % 4; - "subs %[nc1], %[nc1], #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" + float *c_ptr, *C_ptr; + float32x4_t cv; + float32x4_t biasv; + float32x4_t zero = vdupq_n_f32(0.0); + for (int i = 0; i < mc; ++i) { + c_ptr = c + i * NC; + C_ptr = C + i * ldc; + biasv = vld1q_dup_f32(bias + i); + for (int j = 0; j < nc1; ++j) { + cv = vld1q_f32(c_ptr); + cv = vaddq_f32(cv, biasv); + cv = vmaxq_f32(cv, zero); + vst1q_f32(C_ptr, cv); + c_ptr += 4; + C_ptr += 4; + } + if (_nc1 != 0) { + cv = vld1q_f32(c_ptr); + cv = vaddq_f32(cv, biasv); + cv = vmaxq_f32(cv, zero); + if (_nc1 >= 1) { + vst1q_lane_f32(C_ptr, cv, 0); + C_ptr++; + } + if (_nc1 >= 2) { + vst1q_lane_f32(C_ptr, cv, 1); + C_ptr++; + } + if (_nc1 >= 3) { + vst1q_lane_f32(C_ptr, cv, 2); + C_ptr++; + } + } + } +} - "subs %[nc2], %[nc2], #1 \n\t" - "blt end_nc2_%= \n\t" - "loop_nc2_%=: \n\t" +// C = A * B + C,prelu(C) +void Gemm::WriteWithAddPRelu(int mc, int nc, float *c, float *C, int ldc, + float *p, std::string mode, float *bias, + float *bias1) { + int nc1 = nc / 4; + int _nc1 = nc % 4; - "pld [%[c0], #16] \n\t" - "vld1.32 {q10}, [%[c0]] \n\t" + float *c_ptr, *C_ptr; + float32x4_t cv; + float32x4_t cv1; + float32x4_t biasv; + float32x4_t biasv1; + float32x4_t zero = vdupq_n_f32(0.0); + float32x4_t pv; + float *ptr = p; + for (int i = 0; i < mc; ++i) { + c_ptr = c + i * NC; + C_ptr = C + i * ldc; + biasv = vld1q_dup_f32(bias + i); + if (bias1 == nullptr) { + biasv1 = zero; + } else { + biasv1 = vld1q_dup_f32(bias1 + i); + } - "gemm_nc2_%=: \n\t" - "vld1.32 {q2}, [%[b0]]! \n\t" - "vmla.f32 q10, q2, d0[0] \n\t" + for (int j = 0; j < nc1; ++j) { + cv = vld1q_f32(c_ptr); + cv = vaddq_f32(cv, biasv); + cv = vaddq_f32(cv, biasv1); + cv = vmaxq_f32(cv, zero); + cv1 = vminq_f32(cv, zero); + if (mode == "channel") { + cv1 = vmulq_n_f32(cv1, ptr[i]); + } else if (mode == "element") { + pv = vld1q_f32(ptr); + cv1 = vmulq_f32(cv1, pv); + ptr = ptr + 4; + } else { + cv1 = vmulq_n_f32(cv1, ptr[0]); + } + cv = vaddq_f32(cv, cv1); + vst1q_f32(C_ptr, cv); + c_ptr += 4; + C_ptr += 4; + } + if (_nc1 != 0) { + cv = vld1q_f32(c_ptr); + cv = vaddq_f32(cv, biasv); + cv = vaddq_f32(cv, biasv1); + cv = vmaxq_f32(cv, zero); + cv1 = vminq_f32(cv, zero); + if (mode == "channel") { + cv1 = vmulq_n_f32(cv1, ptr[i]); + } else if (mode == "element") { + pv = vld1q_f32(ptr); + cv1 = vmulq_f32(cv1, pv); + ptr = ptr + 4; + } else { + cv1 = vmulq_n_f32(cv1, ptr[0]); + } + cv = vaddq_f32(cv, cv1); + if (_nc1 >= 1) { + vst1q_lane_f32(C_ptr, cv, 0); + C_ptr++; + } + if (_nc1 >= 2) { + vst1q_lane_f32(C_ptr, cv, 1); + C_ptr++; + } + if (_nc1 >= 3) { + vst1q_lane_f32(C_ptr, cv, 2); + C_ptr++; + } + } + } +} - "vst1.32 {q10}, [%[c0]]! 
\n\t" +// C = A * B, batchnorm(C) +void Gemm::WriteWithBn(int mc, int nc, float *c, float *C, int ldc, + float *new_scale, float *new_bias) { + int nc1 = nc / 4; + int _nc1 = nc % 4; - "subs %[nc2], %[nc2], #1 \n\t" - "bge loop_nc2_%= \n\t" - "end_nc2_%=: \n\t" + float *c_ptr, *C_ptr; + float32x4_t cv; + float32x4_t cv1; + float32x4_t bias; + float32x2_t scale; + for (int i = 0; i < mc; ++i) { + c_ptr = c + i * NC; + C_ptr = C + i * ldc; + bias = vld1q_dup_f32(new_bias); + scale = vld1_dup_f32(new_scale); + new_bias++; + new_scale++; + float scale0 = vget_lane_f32(scale, 0); + for (int j = 0; j < nc1; ++j) { + cv = vld1q_f32(c_ptr); + cv = vmlaq_n_f32(bias, cv, scale0); + vst1q_f32(C_ptr, cv); + c_ptr += 4; + C_ptr += 4; + } + if (_nc1 != 0) { + cv = vld1q_f32(c_ptr); + cv = vmlaq_n_f32(bias, cv, scale0); + if (_nc1 >= 1) { + vst1q_lane_f32(C_ptr, cv, 0); + C_ptr++; + } + if (_nc1 >= 2) { + vst1q_lane_f32(C_ptr, cv, 1); + C_ptr++; + } + if (_nc1 >= 3) { + vst1q_lane_f32(C_ptr, cv, 2); + C_ptr++; + } + } + } +} - : [b0] "+r"(b0), [b1] "+r"(b1), [b2] "+r"(b2), [b3] "+r"(b3), - [c0] "+r"(c0) - : [a0] "r"(a0), [nc1] "r"(nc1), [nc2] "r"(nc2) - : "memory", "q0", "q2", "q3", "q4", "q5", "q10", "q11", "q12", "q13"); +// C = A * B, batchnorm(C), relu(C) +void Gemm::WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, + float *new_scale, float *new_bias) { + int nc1 = nc / 4; + int _nc1 = nc % 4; - for (int j = 0; j < nc3; j++) { - *c0 += (*a0) * (*b0++); - c0++; + float *c_ptr, *C_ptr; + float32x4_t cv; + float32x4_t bias; + float32x2_t scale; + float32x4_t zero = vdupq_n_f32(0.0); + for (int i = 0; i < mc; ++i) { + c_ptr = c + i * NC; + C_ptr = C + i * ldc; + bias = vld1q_dup_f32(new_bias); + scale = vld1_dup_f32(new_scale); + new_bias++; + new_scale++; + float scale0 = vget_lane_f32(scale, 0); + for (int j = 0; j < nc1; ++j) { + cv = vld1q_f32(c_ptr); + cv = vmlaq_n_f32(bias, cv, scale0); + cv = vmaxq_f32(cv, zero); + vst1q_f32(C_ptr, cv); + c_ptr += 4; + C_ptr += 4; + } + if (_nc1 != 0) { + cv = vld1q_f32(c_ptr); + cv = vmlaq_n_f32(bias, cv, scale0); + cv = vmaxq_f32(cv, zero); + if (_nc1 >= 1) { + vst1q_lane_f32(C_ptr, cv, 0); + C_ptr++; + } + if (_nc1 >= 2) { + vst1q_lane_f32(C_ptr, cv, 1); + C_ptr++; + } + if (_nc1 >= 3) { + vst1q_lane_f32(C_ptr, cv, 2); + } } } +} - if (alpha != 1) { - VecWriteWithAlphaBeta(n, bufferC, C, ldc); - return; - } - if (beta == 0) { - VecWriteBasic(n, bufferC, C, ldc); - return; - } - if (beta == 1 && !relu) { - VecWriteWithAdd(n, bufferC, C, ldc); - return; - } - if (beta == 1 && relu) { - VecWriteWithAddRelu(n, bufferC, C, ldc); - return; +// C = A * B, batchnorm(C),C = C + bias; relu(C) +void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, + float *new_scale, float *new_bias, float *bias) { + int nc1 = nc / 4; + int _nc1 = nc % 4; + + float *c_ptr, *C_ptr, *bias_ptr; + float32x4_t cv; + float32x4_t nbias; + float32x2_t scale; + float32x4_t biasv; + float32x4_t zero = vdupq_n_f32(0.0); + for (int i = 0; i < mc; ++i) { + c_ptr = c + i * NC; + C_ptr = C + i * ldc; + bias_ptr = bias + i * ldc; + nbias = vld1q_dup_f32(new_bias); + scale = vld1_dup_f32(new_scale); + new_bias++; + new_scale++; + float scale0 = vget_lane_f32(scale, 0); + for (int j = 0; j < nc1; ++j) { + cv = vld1q_f32(c_ptr); + biasv = vld1q_f32(bias_ptr); + cv = vmlaq_n_f32(nbias, cv, scale0); + cv = vaddq_f32(cv, biasv); + cv = vmaxq_f32(cv, zero); + vst1q_f32(C_ptr, cv); + c_ptr += 4; + C_ptr += 4; + bias_ptr += 4; + } + if (_nc1 != 0) { + cv = vld1q_f32(c_ptr); + 
biasv = vld1q_f32(bias_ptr); + cv = vmlaq_n_f32(nbias, cv, scale0); + cv = vaddq_f32(cv, biasv); + cv = vmaxq_f32(cv, zero); + if (_nc1 >= 1) { + vst1q_lane_f32(C_ptr, cv, 0); + C_ptr++; + } + if (_nc1 >= 2) { + vst1q_lane_f32(C_ptr, cv, 1); + C_ptr++; + } + if (_nc1 >= 3) { + vst1q_lane_f32(C_ptr, cv, 2); + } + } } } -/* -void Gemm::VectorKernelWithBn(int m, int n, int k, float alpha, const float *A, +#else + +void Gemm::VectorKernel(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, - int ldc, bool relu, float *new_scale, float *new_bias) { + int ldc, bool relu) { float *bufferC = static_cast(memory::Alloc(sizeof(float) * n)); const float *a0, *b0, *b1, *b2, *b3; @@ -1642,931 +1912,546 @@ void Gemm::VectorKernelWithBn(int m, int n, int k, float alpha, const float *A, "blt end_nc1_%= \n\t" "loop_nc1_%=: \n\t" - "pld [%[c0], #64] \n\t" - "vld1.32 {q10, q11}, [%[c0]]! \n\t" - "vld1.32 {q12, q13}, [%[c0]] \n\t" - "sub %[c0], %[c0], #32 \n\t" - - "gemm_nc1_%=: \n\t" - "pld [%[b0], #64] \n\t" - "vld1.32 {q2, q3}, [%[b0]]! \n\t" - "vld1.32 {q4, q5}, [%[b0]]! \n\t" - "vmla.f32 q10, q2, d0[0] \n\t" - "vmla.f32 q11, q3, d0[0] \n\t" - "vmla.f32 q12, q4, d0[0] \n\t" - "vmla.f32 q13, q5, d0[0] \n\t" - - "vst1.32 {q10, q11}, [%[c0]]! \n\t" - "vst1.32 {q12, q13}, [%[c0]]! \n\t" - - "subs %[nc1], %[nc1], #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" - - "subs %[nc2], %[nc2], #1 \n\t" - "blt end_nc2_%= \n\t" - "loop_nc2_%=: \n\t" - - "pld [%[c0], #16] \n\t" - "vld1.32 {q10}, [%[c0]] \n\t" - - "gemm_nc2_%=: \n\t" - "vld1.32 {q2}, [%[b0]]! \n\t" - "vmla.f32 q10, q2, d0[0] \n\t" - - "vst1.32 {q10}, [%[c0]]! \n\t" - - "subs %[nc2], %[nc2], #1 \n\t" - "bge loop_nc2_%= \n\t" - "end_nc2_%=: \n\t" - - : [b0] "+r"(b0), [b1] "+r"(b1), [b2] "+r"(b2), [b3] "+r"(b3), - [c0] "+r"(c0) - : [a0] "r"(a0), [nc1] "r"(nc1), [nc2] "r"(nc2) - : "memory", "q0", "q2", "q3", "q4", "q5", "q10", "q11", "q12", "q13"); - - for (int j = 0; j < nc3; j++) { - *c0 += (*a0) * (*b0++); - c0++; - } - } - - if (relu) { - VecWriteWithBnRelu(n, bufferC, C, ldc, new_scale, new_bias); - } else { - VecWriteWithBn(n, bufferC, C, ldc, new_scale, new_bias); - } -} -*/ - -void Gemm::AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) { - const float *a_ptr, *b_ptr; - a_ptr = a; - b_ptr = b; - int kc1 = k / 4; - int kc2 = k % 4; - int step = 4 * ldc; - asm volatile( - "pld [%[a_ptr]] \n\t" - "pld [%[b_ptr]] \n\t" - - "vmov.f32 q8, #0.0 \n\t" - "vmov.f32 q9, #0.0 \n\t" - "vmov.f32 q10, #0.0 \n\t" - "vmov.f32 q11, #0.0 \n\t" - "vmov.f32 q12, #0.0 \n\t" - "vmov.f32 q13, #0.0 \n\t" - "vmov.f32 q14, #0.0 \n\t" - "vmov.f32 q15, #0.0 \n\t" - - "subs %[kc1], %[kc1], #1 \n\t" - "blt end_kc1_%= \n\t" - "loop_kc1_%=: \n\t" - - "pld [%[a_ptr], #64] \n\t" - "pld [%[b_ptr], #64] \n\t" - - "vld1.32 {q0, q1}, [%[a_ptr]]! \n\t" - "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" - "vld1.32 {q4, q5}, [%[b_ptr]]! \n\t" - - "vmla.f32 q8, q2, d0[0] \n\t" - "vmla.f32 q9, q3, d0[0] \n\t" - "vmla.f32 q10, q2, d0[1] \n\t" - "vmla.f32 q11, q3, d0[1] \n\t" - "vmla.f32 q12, q2, d1[0] \n\t" - "vmla.f32 q13, q3, d1[0] \n\t" - "vmla.f32 q14, q2, d1[1] \n\t" - "vmla.f32 q15, q3, d1[1] \n\t" - - "vmla.f32 q8, q4, d2[0] \n\t" - "vmla.f32 q9, q5, d2[0] \n\t" - "vmla.f32 q10, q4, d2[1] \n\t" - "vmla.f32 q11, q5, d2[1] \n\t" - "vmla.f32 q12, q4, d3[0] \n\t" - "vmla.f32 q13, q5, d3[0] \n\t" - "vmla.f32 q14, q4, d3[1] \n\t" - "vmla.f32 q15, q5, d3[1] \n\t" - - "pld [%[b_ptr], #64] \n\t" - - "vld1.32 {q0, q1}, [%[a_ptr]]! 
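// The `_nc1 != 0` branches in the write-back routines above compute a full
// 4-lane result and then store only the valid lanes with vst1q_lane_f32, so the
// final (nc % 4) columns of a row never write past the end of C. A minimal
// sketch of that tail-store pattern (the helper name is illustrative).
#include <arm_neon.h>

static inline void store_tail(float *dst, float32x4_t v, int remaining) {
  // remaining is nc % 4, i.e. 1, 2 or 3 leftover columns
  if (remaining >= 1) vst1q_lane_f32(dst + 0, v, 0);
  if (remaining >= 2) vst1q_lane_f32(dst + 1, v, 1);
  if (remaining >= 3) vst1q_lane_f32(dst + 2, v, 2);
}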
\n\t" - "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" - "vld1.32 {q4, q5}, [%[b_ptr]]! \n\t" - - "vmla.f32 q8, q2, d0[0] \n\t" - "vmla.f32 q9, q3, d0[0] \n\t" - "vmla.f32 q10, q2, d0[1] \n\t" - "vmla.f32 q11, q3, d0[1] \n\t" - "vmla.f32 q12, q2, d1[0] \n\t" - "vmla.f32 q13, q3, d1[0] \n\t" - "vmla.f32 q14, q2, d1[1] \n\t" - "vmla.f32 q15, q3, d1[1] \n\t" - - "vmla.f32 q8, q4, d2[0] \n\t" - "vmla.f32 q9, q5, d2[0] \n\t" - "vmla.f32 q10, q4, d2[1] \n\t" - "vmla.f32 q11, q5, d2[1] \n\t" - "vmla.f32 q12, q4, d3[0] \n\t" - "vmla.f32 q13, q5, d3[0] \n\t" - "vmla.f32 q14, q4, d3[1] \n\t" - "vmla.f32 q15, q5, d3[1] \n\t" - - "subs %[kc1], %[kc1], #1 \n\t" - "bge loop_kc1_%= \n\t" - "end_kc1_%=: \n\t" - - "subs %[kc2], %[kc2], #1 \n\t" - "blt end_kc2_%= \n\t" - "loop_kc2_%=: \n\t" - "vld1.32 {q0}, [%[a_ptr]]! \n\t" - "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" - "vmla.f32 q8, q2, d0[0] \n\t" - "vmla.f32 q9, q3, d0[0] \n\t" - "vmla.f32 q10, q2, d0[1] \n\t" - "vmla.f32 q11, q3, d0[1] \n\t" - "vmla.f32 q12, q2, d1[0] \n\t" - "vmla.f32 q13, q3, d1[0] \n\t" - "vmla.f32 q14, q2, d1[1] \n\t" - "vmla.f32 q15, q3, d1[1] \n\t" - "subs %[kc2], %[kc2], #1 \n\t" - "bge loop_kc2_%= \n\t" - "end_kc2_%=: \n\t" + "pld [%[c0], #64] \n\t" + "vld1.32 {q10, q11}, [%[c0]]! \n\t" + "vld1.32 {q12, q13}, [%[c0]] \n\t" + "sub %[c0], %[c0], #32 \n\t" - "mov r5, %[c] \n\t" - "mov r6, %[step] \n\t" - "vst1.32 {q8, q9}, [r5], r6 \n\t" - "vst1.32 {q10, q11}, [r5], r6 \n\t" - "vst1.32 {q12, q13}, [r5], r6 \n\t" - "vst1.32 {q14, q15}, [r5] \n\t" - : - : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), - [kc2] "r"(kc2), [step] "r"(step) - : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q8", "q9", - "q10", "q11", "q12", "q13", "q14", "q15"); -} + "gemm_nc1_%=: \n\t" + "pld [%[b0], #64] \n\t" + "vld1.32 {q2, q3}, [%[b0]]! \n\t" + "vld1.32 {q4, q5}, [%[b0]]! \n\t" + "vmla.f32 q10, q2, d0[0] \n\t" + "vmla.f32 q11, q3, d0[0] \n\t" + "vmla.f32 q12, q4, d0[0] \n\t" + "vmla.f32 q13, q5, d0[0] \n\t" -// C = A * B -void Gemm::WriteBasic(int mc, int nc, float *c, float *C, int ldc) { - int nc1 = nc / 16; - int _nc1 = nc % 16; - int step = 4 * ldc; - int step1 = 4 * (NC - 16 * nc1); - int volatile m = mc; + "vst1.32 {q10, q11}, [%[c0]]! \n\t" + "vst1.32 {q12, q13}, [%[c0]]! \n\t" - float *volatile c_ptr, *volatile C_ptr; - float *C0, *c0; - c_ptr = c; - C_ptr = C; - if (nc1 > 0) { - asm volatile( - "subs %[mc], %[mc], #1 \n\t" - "blt end_mc_%= \n\t" - "loop_mc_%=: \n\t" + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" - "mov r6, %[C_ptr] \n\t" - "mov r5, %[nc1] \n\t" - "subs r5, r5, #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" + "subs %[nc2], %[nc2], #1 \n\t" + "blt end_nc2_%= \n\t" + "loop_nc2_%=: \n\t" - "vld1.32 {q0, q1}, [%[c_ptr]]! \n\t" - "vst1.32 {q0, q1}, [r6]! \n\t" + "pld [%[c0], #16] \n\t" + "vld1.32 {q10}, [%[c0]] \n\t" - "vld1.32 {q2, q3}, [%[c_ptr]]! \n\t" - "vst1.32 {q2, q3}, [r6]! \n\t" + "gemm_nc2_%=: \n\t" + "vld1.32 {q2}, [%[b0]]! \n\t" + "vmla.f32 q10, q2, d0[0] \n\t" - "subs r5, r5, #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" + "vst1.32 {q10}, [%[c0]]! 
\n\t" - "add %[C_ptr], %[C_ptr], %[step] \n\t" - "add %[c_ptr], %[c_ptr], %[step1] \n\t" - "subs %[mc], %[mc], #1 \n\t" - "bge loop_mc_%= \n\t" - "end_mc_%=: \n\t" + "subs %[nc2], %[nc2], #1 \n\t" + "bge loop_nc2_%= \n\t" + "end_nc2_%=: \n\t" - : - : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(nc1), - [step] "r"(step), [step1] "r"(step1) - : "memory", "r5", "r6", "q0", "q1", "q2", "q3"); - } + : [b0] "+r"(b0), [b1] "+r"(b1), [b2] "+r"(b2), [b3] "+r"(b3), + [c0] "+r"(c0) + : [a0] "r"(a0), [nc1] "r"(nc1), [nc2] "r"(nc2) + : "memory", "q0", "q2", "q3", "q4", "q5", "q10", "q11", "q12", "q13"); - if (_nc1 != 0) { - for (int i = 0; i < mc; i++) { - C0 = C_ptr + nc1 * 16 + i * ldc; - c0 = c_ptr + nc1 * 16 + i * NC; - for (int j = 0; j < _nc1; j++) { - *C0++ = *c0++; - } + for (int j = 0; j < nc3; j++) { + *c0 += (*a0) * (*b0++); + c0++; } } + + if (alpha != 1) { + VecWriteWithAlphaBeta(n, bufferC, C, ldc); + return; + } + if (beta == 0) { + VecWriteBasic(n, bufferC, C, ldc); + return; + } + if (beta == 1 && !relu) { + VecWriteWithAdd(n, bufferC, C, ldc); + return; + } + if (beta == 1 && relu) { + VecWriteWithAddRelu(n, bufferC, C, ldc); + return; + } } -// C = alpha * A * B + beta * C -void Gemm::WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc) {} +void Gemm::VectorKernelWithBn(int m, int n, int k, float alpha, const float *A, + int lda, const float *B, int ldb, float beta, + float *C, int ldc, bool relu, float *new_scale, + float *new_bias) { + float *bufferC = static_cast(memory::Alloc(sizeof(float) * n)); -// C = A * B + C -void Gemm::WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) { - int nc1 = nc / 16; - int _nc1 = nc % 16; - int step = 4 * ldc; - int step1 = 4 * (NC - 16 * nc1); - int volatile m = mc; + const float *a0, *b0, *b1, *b2, *b3; + float *c0, *C0; - float *volatile c_ptr, *volatile C_ptr; - float *C0, *c0; - c_ptr = c; - C_ptr = C; - if (nc1 > 0) { + int volatile kc1 = k / 4; + int volatile kc2 = k % 4; + int volatile nc1 = n / 16; + int _nc1 = n % 16; + int volatile nc2 = _nc1 / 4; + int volatile nc3 = _nc1 % 4; + for (int i = 0; i < kc1; i++) { + a0 = A + i * 4; + b0 = B + i * 4 * ldb; + b1 = b0 + ldb; + b2 = b1 + ldb; + b3 = b2 + ldb; + c0 = bufferC; asm volatile( - "subs %[mc], %[mc], #1 \n\t" - "blt end_mc_%= \n\t" - "loop_mc_%=: \n\t" - - "mov r6, %[C_ptr] \n\t" - "mov r5, %[nc1] \n\t" - "subs r5, r5, #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" - - "vld1.32 {q0, q1}, [r6] \n\t" - "vld1.32 {q2, q3}, [%[c_ptr]]! \n\t" - "vadd.f32 q10, q0, q2 \n\t" - "vadd.f32 q11, q1, q3 \n\t" - "vst1.32 {q10, q11}, [r6]! \n\t" - - "vld1.32 {q4, q5}, [r6] \n\t" - "vld1.32 {q6, q7}, [%[c_ptr]]! \n\t" - "vadd.f32 q12, q4, q6 \n\t" - "vadd.f32 q13, q5, q7 \n\t" - "vst1.32 {q12, q13}, [r6]! 
\n\t" + "pld [%[a0], #16] \n\t" + "vld1.32 {q0}, [%[a0]] \n\t" - "subs r5, r5, #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" + "subs %[nc1], %[nc1], #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" - "add %[C_ptr], %[C_ptr], %[step] \n\t" - "add %[c_ptr], %[c_ptr], %[step1] \n\t" - "subs %[mc], %[mc], #1 \n\t" - "bge loop_mc_%= \n\t" - "end_mc_%=: \n\t" + "cmp %[i], #0 \n\t" + "beq i_eq0_%= \n\t" + "bne i_ne0_%= \n\t" - : - : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(nc1), - [step] "r"(step), [step1] "r"(step1) - : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q10", "q11", "q12", "q13"); - } + "i_eq0_%=: \n\t" + "vmov.f32 q10, #0.0 \n\t" + "vmov.f32 q11, #0.0 \n\t" + "vmov.f32 q12, #0.0 \n\t" + "vmov.f32 q13, #0.0 \n\t" + "b gemm_nc1_%= \n\t" - if (_nc1 != 0) { - for (int i = 0; i < mc; i++) { - C0 = C_ptr + nc1 * 16 + i * ldc; - c0 = c_ptr + nc1 * 16 + i * NC; - for (int j = 0; j < _nc1; j++) { - *C0++ += *c0++; - } - } - } -} + "i_ne0_%=: \n\t" + "pld [%[c0], #64] \n\t" + "vld1.32 {q10, q11}, [%[c0]]! \n\t" + "vld1.32 {q12, q13}, [%[c0]] \n\t" + "sub %[c0], %[c0], #32 \n\t" -// C = A * B + bias -void Gemm::WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, - float *bias) { - int nc1 = nc / 4; - int _nc1 = nc % 4; + "gemm_nc1_%=: \n\t" + "pld [%[b0], #64] \n\t" + "vld1.32 {q2, q3}, [%[b0]]! \n\t" + "vld1.32 {q4, q5}, [%[b0]]! \n\t" + "vmla.f32 q10, q2, d0[0] \n\t" + "vmla.f32 q11, q3, d0[0] \n\t" + "vmla.f32 q12, q4, d0[0] \n\t" + "vmla.f32 q13, q5, d0[0] \n\t" - float *c_ptr, *C_ptr; - float32x4_t cv; - float32x4_t biasv; - for (int i = 0; i < mc; ++i) { - c_ptr = c + i * NC; - C_ptr = C + i * ldc; - biasv = vld1q_dup_f32(bias + i); - for (int j = 0; j < nc1; ++j) { - cv = vld1q_f32(c_ptr); - cv = vaddq_f32(cv, biasv); - vst1q_f32(C_ptr, cv); - c_ptr += 4; - C_ptr += 4; - } - if (_nc1 != 0) { - cv = vld1q_f32(c_ptr); - cv = vaddq_f32(cv, biasv); - if (_nc1 >= 1) { - vst1q_lane_f32(C_ptr, cv, 0); - C_ptr++; - } - if (_nc1 >= 2) { - vst1q_lane_f32(C_ptr, cv, 1); - C_ptr++; - } - if (_nc1 >= 3) { - vst1q_lane_f32(C_ptr, cv, 2); - C_ptr++; - } - } - } -} + "pld [%[b1], #64] \n\t" + "vld1.32 {q2, q3}, [%[b1]]! \n\t" + "vld1.32 {q4, q5}, [%[b1]]! \n\t" + "vmla.f32 q10, q2, d0[1] \n\t" + "vmla.f32 q11, q3, d0[1] \n\t" + "vmla.f32 q12, q4, d0[1] \n\t" + "vmla.f32 q13, q5, d0[1] \n\t" -// C = A * B + C, relu(C) -void Gemm::WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) { - int nc1 = nc / 16; - int _nc1 = nc % 16; - int step = 4 * ldc; - int step1 = 4 * (NC - 16 * nc1); - int volatile m = mc; + "pld [%[b2], #64] \n\t" + "vld1.32 {q2, q3}, [%[b2]]! \n\t" + "vld1.32 {q4, q5}, [%[b2]]! \n\t" + "vmla.f32 q10, q2, d1[0] \n\t" + "vmla.f32 q11, q3, d1[0] \n\t" + "vmla.f32 q12, q4, d1[0] \n\t" + "vmla.f32 q13, q5, d1[0] \n\t" - float *volatile c_ptr, *volatile C_ptr; - float *C0, *c0; - c_ptr = c; - C_ptr = C; - if (nc1 > 0) { - asm volatile( - "vmov.f32 q14, #0.0 \n\t" - "subs %[mc], %[mc], #1 \n\t" - "blt end_mc_%= \n\t" - "loop_mc_%=: \n\t" + "pld [%[b3], #64] \n\t" + "vld1.32 {q2, q3}, [%[b3]]! \n\t" + "vld1.32 {q4, q5}, [%[b3]]! \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q4, d1[1] \n\t" + "vmla.f32 q13, q5, d1[1] \n\t" - "mov r6, %[C_ptr] \n\t" - "mov r5, %[nc1] \n\t" - "subs r5, r5, #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" + "vst1.32 {q10, q11}, [%[c0]]! \n\t" + "vst1.32 {q12, q13}, [%[c0]]! 
\n\t" - "vld1.32 {q0, q1}, [r6] \n\t" - "vld1.32 {q2, q3}, [%[c_ptr]]! \n\t" - "vadd.f32 q10, q0, q2 \n\t" - "vadd.f32 q11, q1, q3 \n\t" - "vmax.f32 q10, q10, q14 \n\t" - "vmax.f32 q11, q11, q14 \n\t" - "vst1.32 {q10, q11}, [r6]! \n\t" + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" - "vld1.32 {q4, q5}, [r6] \n\t" - "vld1.32 {q6, q7}, [%[c_ptr]]! \n\t" - "vadd.f32 q12, q4, q6 \n\t" - "vadd.f32 q13, q5, q7 \n\t" - "vmax.f32 q12, q12, q14 \n\t" - "vmax.f32 q13, q13, q14 \n\t" - "vst1.32 {q12, q13}, [r6]! \n\t" + "subs %[nc2], %[nc2], #1 \n\t" + "blt end_nc2_%= \n\t" + "loop_nc2_%=: \n\t" - "subs r5, r5, #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" + "cmp %[i], #0 \n\t" + "beq ii_eq0_%= \n\t" + "bne ii_ne0_%= \n\t" - "add %[C_ptr], %[C_ptr], %[step] \n\t" - "add %[c_ptr], %[c_ptr], %[step1] \n\t" - "subs %[mc], %[mc], #1 \n\t" - "bge loop_mc_%= \n\t" - "end_mc_%=: \n\t" + "ii_eq0_%=: \n\t" + "vmov.f32 q10, #0.0 \n\t" + "b gemm_nc2_%= \n\t" - : - : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(nc1), - [step] "r"(step), [step1] "r"(step1) - : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q10", "q11", "q12", "q13"); - } + "ii_ne0_%=: \n\t" + "pld [%[c0], #16] \n\t" + "vld1.32 {q10}, [%[c0]] \n\t" - if (_nc1 != 0) { - for (int i = 0; i < mc; i++) { - C0 = C_ptr + nc1 * 16 + i * ldc; - c0 = c_ptr + nc1 * 16 + i * NC; - for (int j = 0; j < _nc1; j++) { - *C0 += *c0; - if (*C0 < 0) { - *C0 = 0; - } - C0++; - c0++; - } - } - } -} + "gemm_nc2_%=: \n\t" + "pld [%[b0], #16] \n\t" + "vld1.32 {q2}, [%[b0]]! \n\t" + "vmla.f32 q10, q2, d0[0] \n\t" -// C = A * B + bias, relu(C) -void Gemm::WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc, - float *bias) { - int nc1 = nc / 4; - int _nc1 = nc % 4; + "pld [%[b1], #16] \n\t" + "vld1.32 {q3}, [%[b1]]! \n\t" + "vmla.f32 q10, q3, d0[1] \n\t" - float *c_ptr, *C_ptr; - float32x4_t cv; - float32x4_t biasv; - float32x4_t zero = vdupq_n_f32(0.0); - for (int i = 0; i < mc; ++i) { - c_ptr = c + i * NC; - C_ptr = C + i * ldc; - biasv = vld1q_dup_f32(bias + i); - for (int j = 0; j < nc1; ++j) { - cv = vld1q_f32(c_ptr); - cv = vaddq_f32(cv, biasv); - cv = vmaxq_f32(cv, zero); - vst1q_f32(C_ptr, cv); - c_ptr += 4; - C_ptr += 4; - } - if (_nc1 != 0) { - cv = vld1q_f32(c_ptr); - cv = vaddq_f32(cv, biasv); - cv = vmaxq_f32(cv, zero); - if (_nc1 >= 1) { - vst1q_lane_f32(C_ptr, cv, 0); - C_ptr++; - } - if (_nc1 >= 2) { - vst1q_lane_f32(C_ptr, cv, 1); - C_ptr++; - } - if (_nc1 >= 3) { - vst1q_lane_f32(C_ptr, cv, 2); - C_ptr++; - } - } - } -} + "pld [%[b2], #16] \n\t" + "vld1.32 {q4}, [%[b2]]! \n\t" + "vmla.f32 q10, q4, d1[0] \n\t" -void Gemm::WriteWithAddPRelu(int mc, int nc, float *c, float *C, int ldc, - float *p, std::string mode, float *bias, - float *bias1) { - if (nc < 4) { - if (bias1 == nullptr) { - for (int i = 0; i < mc; ++i) { - for (int j = 0; j < nc; ++j) { - float r = c[i * NC + j] + bias[i]; - if (r < 0) { - r *= p[i]; - } - C[i * ldc + j] = r; - } - } - } else { - for (int i = 0; i < mc; ++i) { - for (int j = 0; j < nc; ++j) { - float r = c[i * NC + j] + bias[i]; - r += bias1[i * ldc + j]; - if (r < 0) { - r *= p[i]; - } - C[i * ldc + j] = r; - } - } - } - return; - } + "pld [%[b3], #16] \n\t" + "vld1.32 {q5}, [%[b3]]! \n\t" + "vmla.f32 q10, q5, d1[1] \n\t" - int nc1 = nc / 16; - int _nc1 = nc % 16; - int nc2 = _nc1 / 4; - int nc3 = 16 - 4 * (_nc1 % 4); - int step = 4 * (ldc - nc); - int step1 = 4 * (NC - nc); + "vst1.32 {q10}, [%[c0]]! 
\n\t" - if (bias1 == nullptr) { - asm volatile( - "vmov.f32 q14, #0.0 \n\t" - "subs %[mc], %[mc], #1 \n\t" - "blt end_mc_%= \n\t" - "loop_mc_%=: \n\t" + "subs %[nc2], %[nc2], #1 \n\t" + "bge loop_nc2_%= \n\t" + "end_nc2_%=: \n\t" - "mov r5, %[nc1] \n\t" - "mov r6, %[nc2] \n\t" - "vld1.32 {d0}, [%[bias]] \n\t" - "vld1.32 {d1}, [%[p]] \n\t" - "vdup.32 q1, d0[0] \n\t" - "vdup.32 q2, d1[0] \n\t" + : [b0] "+r"(b0), [b1] "+r"(b1), [b2] "+r"(b2), [b3] "+r"(b3), + [c0] "+r"(c0) + : [a0] "r"(a0), [i] "r"(i), [nc1] "r"(nc1), [nc2] "r"(nc2) + : "memory", "q0", "q2", "q3", "q4", "q5", "q10", "q11", "q12", "q13"); - "subs r5, r5, #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" + for (int j = 0; j < nc3; j++) { + if (i == 0) { + *c0 = (*a0) * (*b0++); + } else { + *c0 += (*a0) * (*b0++); + } + *c0 += (*(a0 + 1)) * (*b1++); + *c0 += (*(a0 + 2)) * (*b2++); + *c0 += (*(a0 + 3)) * (*b3++); + c0++; + } + } - "pld [%[c], #32] \n\t" - "vld1.32 {q3, q4}, [%[c]]! \n\t" - "vld1.32 {q9, q10}, [%[c]]! \n\t" + for (int i = 0; i < kc2; ++i) { + a0 = A + 4 * kc1 + i; + b0 = B + (4 * kc1 + i) * ldb; + c0 = bufferC; + asm volatile( + "pld [%[a0], #16] \n\t" + "vld1.32 {d0}, [%[a0]] \n\t" - "vadd.f32 q3, q3, q1 \n\t" - "vadd.f32 q4, q4, q1 \n\t" - "vadd.f32 q9, q9, q1 \n\t" - "vadd.f32 q10, q10, q1 \n\t" + "subs %[nc1], %[nc1], #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" - "vmax.f32 q5, q3, q14 \n\t" - "vmin.f32 q7, q3, q14 \n\t" - "vmax.f32 q6, q4, q14 \n\t" - "vmin.f32 q8, q4, q14 \n\t" + "pld [%[c0], #64] \n\t" + "vld1.32 {q10, q11}, [%[c0]]! \n\t" + "vld1.32 {q12, q13}, [%[c0]] \n\t" + "sub %[c0], %[c0], #32 \n\t" - "vmax.f32 q11, q9, q14 \n\t" - "vmin.f32 q13, q9, q14 \n\t" - "vmax.f32 q12, q10, q14 \n\t" - "vmin.f32 q15, q10, q14 \n\t" + "gemm_nc1_%=: \n\t" + "pld [%[b0], #64] \n\t" + "vld1.32 {q2, q3}, [%[b0]]! \n\t" + "vld1.32 {q4, q5}, [%[b0]]! \n\t" + "vmla.f32 q10, q2, d0[0] \n\t" + "vmla.f32 q11, q3, d0[0] \n\t" + "vmla.f32 q12, q4, d0[0] \n\t" + "vmla.f32 q13, q5, d0[0] \n\t" - "vmla.f32 q5, q7, q2 \n\t" - "vmla.f32 q6, q8, q2 \n\t" - "vmla.f32 q11, q13, q2 \n\t" - "vmla.f32 q12, q15, q2 \n\t" + "vst1.32 {q10, q11}, [%[c0]]! \n\t" + "vst1.32 {q12, q13}, [%[c0]]! \n\t" - "vst1.32 {q5, q6}, [%[C]]! \n\t" - "vst1.32 {q11, q12}, [%[C]]! \n\t" + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" - "subs r5, r5, #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" + "subs %[nc2], %[nc2], #1 \n\t" + "blt end_nc2_%= \n\t" + "loop_nc2_%=: \n\t" - "subs r6, r6, #1 \n\t" - "blt end_nc2_%= \n\t" - "loop_nc2_%=: \n\t" + "pld [%[c0], #16] \n\t" + "vld1.32 {q10}, [%[c0]] \n\t" - "vld1.32 {q3}, [%[c]]! \n\t" - "vadd.f32 q3, q3, q1 \n\t" - "vmax.f32 q5, q3, q14 \n\t" - "vmin.f32 q7, q3, q14 \n\t" - "vmla.f32 q5, q7, q2 \n\t" - "vst1.32 {q5}, [%[C]]! \n\t" + "gemm_nc2_%=: \n\t" + "vld1.32 {q2}, [%[b0]]! \n\t" + "vmla.f32 q10, q2, d0[0] \n\t" - "subs r6, r6, #1 \n\t" - "bge loop_nc2_%= \n\t" - "end_nc2_%=: \n\t" + "vst1.32 {q10}, [%[c0]]! \n\t" - "cmp %[nc3], #16 \n\t" - "beq end_nc3_%= \n\t" + "subs %[nc2], %[nc2], #1 \n\t" + "bge loop_nc2_%= \n\t" + "end_nc2_%=: \n\t" - "sub %[c], %[c], %[nc3] \n\t" - "sub %[C], %[C], %[nc3] \n\t" + : [b0] "+r"(b0), [b1] "+r"(b1), [b2] "+r"(b2), [b3] "+r"(b3), + [c0] "+r"(c0) + : [a0] "r"(a0), [nc1] "r"(nc1), [nc2] "r"(nc2) + : "memory", "q0", "q2", "q3", "q4", "q5", "q10", "q11", "q12", "q13"); - "vld1.32 {q4}, [%[c]]! 
\n\t" - "vadd.f32 q4, q4, q1 \n\t" - "vmax.f32 q6, q4, q14 \n\t" - "vmin.f32 q8, q4, q14 \n\t" - "vmla.f32 q6, q8, q2 \n\t" - "vst1.32 {q6}, [%[C]]! \n\t" - "end_nc3_%=: \n\t" + for (int j = 0; j < nc3; j++) { + *c0 += (*a0) * (*b0++); + c0++; + } + } - "add %[p], %[p], #4 \n\t" - "add %[bias], %[bias], #4 \n\t" - "add %[c], %[c], %[step1] \n\t" - "add %[C], %[C], %[step] \n\t" + if (relu) { + VecWriteWithBnRelu(n, bufferC, C, ldc, new_scale, new_bias); + } else { + VecWriteWithBn(n, bufferC, C, ldc, new_scale, new_bias); + } +} - "subs %[mc], %[mc], #1 \n\t" - "bge loop_mc_%= \n\t" - "end_mc_%=: \n\t" +// C = A * B +void Gemm::WriteBasic(int mc, int nc, float *c, float *C, int ldc) { + int nc1 = nc / 16; + int _nc1 = nc % 16; + int step = 4 * ldc; + int step1 = 4 * (NC - 16 * nc1); + int volatile m = mc; - : - : [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1), [nc2] "r"(nc2), - [nc3] "r"(nc3), [step] "r"(step), [step1] "r"(step1), [p] "r"(p), - [bias] "r"(bias), [bias1] "r"(bias1) - : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8"); - } else { + float *volatile c_ptr, *volatile C_ptr; + float *C0, *c0; + c_ptr = c; + C_ptr = C; + if (nc1 > 0) { asm volatile( - "vmov.f32 q14, #0.0 \n\t" "subs %[mc], %[mc], #1 \n\t" "blt end_mc_%= \n\t" "loop_mc_%=: \n\t" - "mov r5, %[nc1] \n\t" - "mov r6, %[nc2] \n\t" - "vld1.32 {d0}, [%[bias]] \n\t" - "vld1.32 {d1}, [%[p]] \n\t" - "vdup.32 q1, d0[0] \n\t" - "vdup.32 q2, d1[0] \n\t" - + "mov r6, %[C_ptr] \n\t" + "mov r5, %[nc1] \n\t" "subs r5, r5, #1 \n\t" "blt end_nc1_%= \n\t" "loop_nc1_%=: \n\t" - "pld [%[c], #32] \n\t" - "pld [%[bias1], #32] \n\t" - "vld1.32 {q3, q4}, [%[c]]! \n\t" - "vld1.32 {q9, q10}, [%[bias1]]! \n\t" - "vadd.f32 q3, q3, q1 \n\t" - "vadd.f32 q4, q4, q1 \n\t" - "vadd.f32 q3, q3, q9 \n\t" - "vadd.f32 q4, q4, q10 \n\t" - "vmax.f32 q5, q3, q14 \n\t" - "vmin.f32 q7, q3, q14 \n\t" - "vmax.f32 q6, q4, q14 \n\t" - "vmin.f32 q8, q4, q14 \n\t" - "vmla.f32 q5, q7, q2 \n\t" - "vmla.f32 q6, q8, q2 \n\t" - "vst1.32 {q5, q6}, [%[C]]! \n\t" + "vld1.32 {q0, q1}, [%[c_ptr]]! \n\t" + "vst1.32 {q0, q1}, [r6]! \n\t" - "vld1.32 {q3, q4}, [%[c]]! \n\t" - "vld1.32 {q9, q10}, [%[bias1]]! \n\t" - "vadd.f32 q3, q3, q1 \n\t" - "vadd.f32 q4, q4, q1 \n\t" - "vadd.f32 q3, q3, q9 \n\t" - "vadd.f32 q4, q4, q10 \n\t" - "vmax.f32 q5, q3, q14 \n\t" - "vmin.f32 q7, q3, q14 \n\t" - "vmax.f32 q6, q4, q14 \n\t" - "vmin.f32 q8, q4, q14 \n\t" - "vmla.f32 q5, q7, q2 \n\t" - "vmla.f32 q6, q8, q2 \n\t" - "vst1.32 {q5, q6}, [%[C]]! \n\t" + "vld1.32 {q2, q3}, [%[c_ptr]]! \n\t" + "vst1.32 {q2, q3}, [r6]! \n\t" "subs r5, r5, #1 \n\t" "bge loop_nc1_%= \n\t" "end_nc1_%=: \n\t" - "subs r6, r6, #1 \n\t" - "blt end_nc2_%= \n\t" - "loop_nc2_%=: \n\t" - - "vld1.32 {q3}, [%[c]]! \n\t" - "vld1.32 {q9}, [%[bias1]]! \n\t" - "vadd.f32 q3, q3, q1 \n\t" - "vadd.f32 q3, q3, q9 \n\t" - "vmax.f32 q5, q3, q14 \n\t" - "vmin.f32 q7, q3, q14 \n\t" - "vmla.f32 q5, q7, q2 \n\t" - "vst1.32 {q5}, [%[C]]! \n\t" - - "subs r6, r6, #1 \n\t" - "bge loop_nc2_%= \n\t" - "end_nc2_%=: \n\t" - - "cmp %[nc3], #16 \n\t" - "beq end_nc3_%= \n\t" - - "sub %[c], %[c], %[nc3] \n\t" - "sub %[C], %[C], %[nc3] \n\t" - "sub %[bias1], %[bias1], %[nc3] \n\t" - - "vld1.32 {q4}, [%[c]]! \n\t" - "vld1.32 {q10}, [%[bias1]]! \n\t" - "vadd.f32 q4, q4, q1 \n\t" - "vadd.f32 q4, q4, q10 \n\t" - "vmax.f32 q6, q4, q14 \n\t" - "vmin.f32 q8, q4, q14 \n\t" - "vmla.f32 q6, q8, q2 \n\t" - "vst1.32 {q6}, [%[C]]! 
\n\t" - "end_nc3_%=: \n\t" - - "add %[p], %[p], #4 \n\t" - "add %[bias], %[bias], #4 \n\t" - "add %[c], %[c], %[step1] \n\t" - "add %[C], %[C], %[step] \n\t" - "add %[bias1], %[bias1], %[step] \n\t" - + "add %[C_ptr], %[C_ptr], %[step] \n\t" + "add %[c_ptr], %[c_ptr], %[step1] \n\t" "subs %[mc], %[mc], #1 \n\t" "bge loop_mc_%= \n\t" "end_mc_%=: \n\t" : - : [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1), [nc2] "r"(nc2), - [nc3] "r"(nc3), [step] "r"(step), [step1] "r"(step1), [p] "r"(p), - [bias] "r"(bias), [bias1] "r"(bias1) - : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q9", "q10"); + : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(nc1), + [step] "r"(step), [step1] "r"(step1) + : "memory", "r5", "r6", "q0", "q1", "q2", "q3"); } -} -// C = A * B, batchnorm(C) -void Gemm::WriteWithBn(int mc, int nc, float *c, float *C, int ldc, - float *scale, float *bias) { - if (nc < 4) { - for (int i = 0; i < mc; ++i) { - for (int j = 0; j < nc; ++j) { - *C = (*c) * (*scale) + (*bias); - C++; - c++; + if (_nc1 != 0) { + for (int i = 0; i < mc; i++) { + C0 = C_ptr + nc1 * 16 + i * ldc; + c0 = c_ptr + nc1 * 16 + i * NC; + for (int j = 0; j < _nc1; j++) { + *C0++ = *c0++; } - C += (ldc - nc); - c += (NC - nc); - scale++; - bias++; } - return; } +} - int volatile nc1 = nc / 16; - int _nc1 = nc % 16; - int volatile nc2 = _nc1 / 4; - int volatile nc3 = 16 - 4 * (_nc1 % 4); - int volatile step = 4 * (ldc - nc); - int volatile step1 = 4 * (NC - nc); - - asm volatile( - "subs %[mc], %[mc], #1 \n\t" - "blt end_mc_%= \n\t" - "loop_mc_%=: \n\t" - - "mov r5, %[nc1] \n\t" - "mov r6, %[nc2] \n\t" - "vld1.32 {d0}, [%[scale]] \n\t" - "vld1.32 {d1}, [%[bias]] \n\t" - "vdup.32 q1, d0[0] \n\t" - "vdup.32 q2, d1[0] \n\t" - - "subs r5, r5, #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" - - "vld1.32 {q3, q4}, [%[c]]! \n\t" - "vmul.f32 q10, q3, q1 \n\t" - "vmul.f32 q11, q4, q1 \n\t" - "vadd.f32 q10, q10, q2 \n\t" - "vadd.f32 q11, q11, q2 \n\t" - "vst1.32 {q10, q11}, [%[C]]! \n\t" - - "vld1.32 {q5, q6}, [%[c]]! \n\t" - "vmul.f32 q12, q5, q1 \n\t" - "vmul.f32 q13, q6, q1 \n\t" - "vadd.f32 q12, q12, q2 \n\t" - "vadd.f32 q13, q13, q2 \n\t" - "vst1.32 {q12, q13}, [%[C]]! \n\t" - - "subs r5, r5, #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" +// C = alpha * A * B + beta * C +void Gemm::WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc) {} - "subs r6, r6, #1 \n\t" - "blt end_nc2_%= \n\t" - "loop_nc2_%=: \n\t" +// C = A * B + C +void Gemm::WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) { + int nc1 = nc / 16; + int _nc1 = nc % 16; + int step = 4 * ldc; + int step1 = 4 * (NC - 16 * nc1); + int volatile m = mc; - "vld1.32 {q7}, [%[c]]! \n\t" - "vmul.f32 q10, q7, q1 \n\t" - "vadd.f32 q10, q10, q2 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" + float *volatile c_ptr, *volatile C_ptr; + float *C0, *c0; + c_ptr = c; + C_ptr = C; + if (nc1 > 0) { + asm volatile( + "subs %[mc], %[mc], #1 \n\t" + "blt end_mc_%= \n\t" + "loop_mc_%=: \n\t" - "subs r6, r6, #1 \n\t" - "bge loop_nc2_%= \n\t" - "end_nc2_%=: \n\t" + "mov r6, %[C_ptr] \n\t" + "mov r5, %[nc1] \n\t" + "subs r5, r5, #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" - "cmp %[nc3], #16 \n\t" - "beq end_nc3_%= \n\t" + "vld1.32 {q0, q1}, [r6] \n\t" + "vld1.32 {q2, q3}, [%[c_ptr]]! \n\t" + "vadd.f32 q10, q0, q2 \n\t" + "vadd.f32 q11, q1, q3 \n\t" + "vst1.32 {q10, q11}, [r6]! 
\n\t" - "sub %[c], %[c], %[nc3] \n\t" - "sub %[C], %[C], %[nc3] \n\t" + "vld1.32 {q4, q5}, [r6] \n\t" + "vld1.32 {q6, q7}, [%[c_ptr]]! \n\t" + "vadd.f32 q12, q4, q6 \n\t" + "vadd.f32 q13, q5, q7 \n\t" + "vst1.32 {q12, q13}, [r6]! \n\t" - "vld1.32 {q8}, [%[c]]! \n\t" - "vmul.f32 q11, q8, q1 \n\t" - "vadd.f32 q11, q11, q2 \n\t" - "vst1.32 {q11}, [%[C]]! \n\t" - "end_nc3_%=: \n\t" + "subs r5, r5, #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" - "add %[scale], %[scale], #4 \n\t" - "add %[bias], %[bias], #4 \n\t" - "add %[c], %[c], %[step1] \n\t" - "add %[C], %[C], %[step] \n\t" + "add %[C_ptr], %[C_ptr], %[step] \n\t" + "add %[c_ptr], %[c_ptr], %[step1] \n\t" + "subs %[mc], %[mc], #1 \n\t" + "bge loop_mc_%= \n\t" + "end_mc_%=: \n\t" - "subs %[mc], %[mc], #1 \n\t" - "bge loop_mc_%= \n\t" - "end_mc_%=: \n\t" + : + : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(nc1), + [step] "r"(step), [step1] "r"(step1) + : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q10", "q11", "q12", "q13"); + } - : - : [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1), [nc2] "r"(nc2), - [nc3] "r"(nc3), [step] "r"(step), [step1] "r"(step1), - [scale] "r"(scale), [bias] "r"(bias) - : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q10", "q11", "q12", "q13"); + if (_nc1 != 0) { + for (int i = 0; i < mc; i++) { + C0 = C_ptr + nc1 * 16 + i * ldc; + c0 = c_ptr + nc1 * 16 + i * NC; + for (int j = 0; j < _nc1; j++) { + *C0++ += *c0++; + } + } + } } -// C = A * B, batchnorm(C), relu(C) -void Gemm::WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, - float *scale, float *bias) { - if (nc < 4) { - for (int i = 0; i < mc; ++i) { - for (int j = 0; j < nc; ++j) { - *C = (*c) * (*scale) + (*bias); - if (*C < 0) { - *C = 0; - } - C++; - c++; +// C = A * B + bias +void Gemm::WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, + float *bias) { + int nc1 = nc / 4; + int _nc1 = nc % 4; + + float *c_ptr, *C_ptr; + float32x4_t cv; + float32x4_t biasv; + for (int i = 0; i < mc; ++i) { + c_ptr = c + i * NC; + C_ptr = C + i * ldc; + biasv = vld1q_dup_f32(bias + i); + for (int j = 0; j < nc1; ++j) { + cv = vld1q_f32(c_ptr); + cv = vaddq_f32(cv, biasv); + vst1q_f32(C_ptr, cv); + c_ptr += 4; + C_ptr += 4; + } + if (_nc1 != 0) { + cv = vld1q_f32(c_ptr); + cv = vaddq_f32(cv, biasv); + if (_nc1 >= 1) { + vst1q_lane_f32(C_ptr, cv, 0); + C_ptr++; + } + if (_nc1 >= 2) { + vst1q_lane_f32(C_ptr, cv, 1); + C_ptr++; + } + if (_nc1 >= 3) { + vst1q_lane_f32(C_ptr, cv, 2); + C_ptr++; } - C += (ldc - nc); - c += (NC - nc); - scale++; - bias++; } - return; } +} +// C = A * B + C, relu(C) +void Gemm::WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) { int nc1 = nc / 16; int _nc1 = nc % 16; - int nc2 = _nc1 / 4; - int nc3 = 16 - 4 * (_nc1 % 4); - int step = 4 * (ldc - nc); - int step1 = 4 * (NC - nc); - - asm volatile( - "vmov.f32 q14, #0.0 \n\t" - "subs %[mc], %[mc], #1 \n\t" - "blt end_mc_%= \n\t" - "loop_mc_%=: \n\t" - - "mov r5, %[nc1] \n\t" - "mov r6, %[nc2] \n\t" - "vld1.32 {d0}, [%[scale]] \n\t" - "vld1.32 {d1}, [%[bias]] \n\t" - "vdup.32 q1, d0[0] \n\t" - "vdup.32 q2, d1[0] \n\t" - - "subs r5, r5, #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" - - "vld1.32 {q3, q4}, [%[c]]! \n\t" - "vmul.f32 q10, q3, q1 \n\t" - "vmul.f32 q11, q4, q1 \n\t" - "vadd.f32 q10, q10, q2 \n\t" - "vadd.f32 q11, q11, q2 \n\t" - "vmax.f32 q10, q10, q14 \n\t" - "vmax.f32 q11, q11, q14 \n\t" - "vst1.32 {q10, q11}, [%[C]]! \n\t" - - "vld1.32 {q5, q6}, [%[c]]! 
\n\t" - "vmul.f32 q12, q5, q1 \n\t" - "vmul.f32 q13, q6, q1 \n\t" - "vadd.f32 q12, q12, q2 \n\t" - "vadd.f32 q13, q13, q2 \n\t" - "vmax.f32 q12, q12, q14 \n\t" - "vmax.f32 q13, q13, q14 \n\t" - "vst1.32 {q12, q13}, [%[C]]! \n\t" - - "subs r5, r5, #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" - - "subs r6, r6, #1 \n\t" - "blt end_nc2_%= \n\t" - "loop_nc2_%=: \n\t" + int step = 4 * ldc; + int step1 = 4 * (NC - 16 * nc1); + int volatile m = mc; - "vld1.32 {q7}, [%[c]]! \n\t" - "vmul.f32 q10, q7, q1 \n\t" - "vadd.f32 q10, q10, q2 \n\t" - "vmax.f32 q10, q10, q14 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" + float *volatile c_ptr, *volatile C_ptr; + float *C0, *c0; + c_ptr = c; + C_ptr = C; + if (nc1 > 0) { + asm volatile( + "vmov.f32 q14, #0.0 \n\t" + "subs %[mc], %[mc], #1 \n\t" + "blt end_mc_%= \n\t" + "loop_mc_%=: \n\t" - "subs r6, r6, #1 \n\t" - "bge loop_nc2_%= \n\t" - "end_nc2_%=: \n\t" + "mov r6, %[C_ptr] \n\t" + "mov r5, %[nc1] \n\t" + "subs r5, r5, #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" - "cmp %[nc3], #16 \n\t" - "beq end_nc3_%= \n\t" + "vld1.32 {q0, q1}, [r6] \n\t" + "vld1.32 {q2, q3}, [%[c_ptr]]! \n\t" + "vadd.f32 q10, q0, q2 \n\t" + "vadd.f32 q11, q1, q3 \n\t" + "vmax.f32 q10, q10, q14 \n\t" + "vmax.f32 q11, q11, q14 \n\t" + "vst1.32 {q10, q11}, [r6]! \n\t" - "sub %[c], %[c], %[nc3] \n\t" - "sub %[C], %[C], %[nc3] \n\t" + "vld1.32 {q4, q5}, [r6] \n\t" + "vld1.32 {q6, q7}, [%[c_ptr]]! \n\t" + "vadd.f32 q12, q4, q6 \n\t" + "vadd.f32 q13, q5, q7 \n\t" + "vmax.f32 q12, q12, q14 \n\t" + "vmax.f32 q13, q13, q14 \n\t" + "vst1.32 {q12, q13}, [r6]! \n\t" - "vld1.32 {q8}, [%[c]]! \n\t" - "vmul.f32 q11, q8, q1 \n\t" - "vadd.f32 q11, q11, q2 \n\t" - "vmax.f32 q11, q11, q14 \n\t" - "vst1.32 {q11}, [%[C]]! \n\t" - "end_nc3_%=: \n\t" + "subs r5, r5, #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" - "add %[scale], %[scale], #4 \n\t" - "add %[bias], %[bias], #4 \n\t" - "add %[c], %[c], %[step1] \n\t" - "add %[C], %[C], %[step] \n\t" + "add %[C_ptr], %[C_ptr], %[step] \n\t" + "add %[c_ptr], %[c_ptr], %[step1] \n\t" + "subs %[mc], %[mc], #1 \n\t" + "bge loop_mc_%= \n\t" + "end_mc_%=: \n\t" - "subs %[mc], %[mc], #1 \n\t" - "bge loop_mc_%= \n\t" - "end_mc_%=: \n\t" + : + : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(nc1), + [step] "r"(step), [step1] "r"(step1) + : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q10", "q11", "q12", "q13"); + } - : - : [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1), [nc2] "r"(nc2), - [nc3] "r"(nc3), [step] "r"(step), [step1] "r"(step1), - [scale] "r"(scale), [bias] "r"(bias) - : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q10", "q11", "q12", "q13", "q14"); + if (_nc1 != 0) { + for (int i = 0; i < mc; i++) { + C0 = C_ptr + nc1 * 16 + i * ldc; + c0 = c_ptr + nc1 * 16 + i * NC; + for (int j = 0; j < _nc1; j++) { + *C0 += *c0; + if (*C0 < 0) { + *C0 = 0; + } + C0++; + c0++; + } + } + } } -// C = A * B, batchnorm(C),C = C + bias; relu(C) -void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, - float *new_scale, float *new_bias, float *bias) { +// C = A * B + bias, relu(C) +void Gemm::WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc, + float *bias) { int nc1 = nc / 4; int _nc1 = nc % 4; - float *c_ptr, *C_ptr, *bias_ptr; + float *c_ptr, *C_ptr; float32x4_t cv; - float32x4_t nbias; - float32x2_t scale; float32x4_t biasv; float32x4_t zero = vdupq_n_f32(0.0); for (int i = 0; i < mc; ++i) { c_ptr = c + i * NC; C_ptr = C + i * ldc; - 
bias_ptr = bias + i * ldc; - nbias = vld1q_dup_f32(new_bias); - scale = vld1_dup_f32(new_scale); - new_bias++; - new_scale++; - float scale0 = vget_lane_f32(scale, 0); + biasv = vld1q_dup_f32(bias + i); for (int j = 0; j < nc1; ++j) { cv = vld1q_f32(c_ptr); - biasv = vld1q_f32(bias_ptr); - cv = vmlaq_n_f32(nbias, cv, scale0); cv = vaddq_f32(cv, biasv); cv = vmaxq_f32(cv, zero); vst1q_f32(C_ptr, cv); c_ptr += 4; C_ptr += 4; - bias_ptr += 4; } if (_nc1 != 0) { cv = vld1q_f32(c_ptr); - biasv = vld1q_f32(bias_ptr); - cv = vmlaq_n_f32(nbias, cv, scale0); cv = vaddq_f32(cv, biasv); cv = vmaxq_f32(cv, zero); if (_nc1 >= 1) { @@ -2579,1433 +2464,1395 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, } if (_nc1 >= 3) { vst1q_lane_f32(C_ptr, cv, 2); + C_ptr++; } } } } -// C = A * B -void Gemm::VecWriteBasic(int n, float *c, float *C, int ldc) { - int nc1 = n / 16; - int _nc1 = n % 16; +void Gemm::WriteWithAddPRelu(int mc, int nc, float *c, float *C, int ldc, + float *p, std::string mode, float *bias, + float *bias1) { + if (nc < 4) { + if (bias1 == nullptr) { + for (int i = 0; i < mc; ++i) { + for (int j = 0; j < nc; ++j) { + float r = c[i * NC + j] + bias[i]; + if (r < 0) { + r *= p[i]; + } + C[i * ldc + j] = r; + } + } + } else { + for (int i = 0; i < mc; ++i) { + for (int j = 0; j < nc; ++j) { + float r = c[i * NC + j] + bias[i]; + r += bias1[i * ldc + j]; + if (r < 0) { + r *= p[i]; + } + C[i * ldc + j] = r; + } + } + } + return; + } + + int nc1 = nc / 16; + int _nc1 = nc % 16; int nc2 = _nc1 / 4; int nc3 = 16 - 4 * (_nc1 % 4); + int step = 4 * (ldc - nc); + int step1 = 4 * (NC - nc); - asm volatile( - "subs %[nc1], %[nc1], #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" - - "vld1.32 {q0, q1}, [%[c]]! \n\t" - "vst1.32 {q0, q1}, [%[C]]! \n\t" - - "vld1.32 {q2, q3}, [%[c]]! \n\t" - "vst1.32 {q2, q3}, [%[C]]! \n\t" - - "subs %[nc1], %[nc1], #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" - - "subs %[nc2], %[nc2], #1 \n\t" - "blt end_nc2_%= \n\t" - "loop_nc2_%=: \n\t" - - "vld1.32 {q4}, [%[c]]! \n\t" - "vst1.32 {q4}, [%[C]]! \n\t" - - "subs %[nc2], %[nc2], #1 \n\t" - "bge loop_nc2_%= \n\t" - "end_nc2_%=: \n\t" - - "cmp %[nc3], #16 \n\t" - "beq end_nc3_%= \n\t" - "sub %[c], %[c], %[nc3] \n\t" - "sub %[C], %[C], %[nc3] \n\t" - "vld1.32 {q5}, [%[c]]! \n\t" - "vst1.32 {q5}, [%[C]]! \n\t" - "end_nc3_%=: \n\t" - - : - : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3) - : "memory", "q0", "q1", "q2", "q3", "q4", "q5"); -} - -// C = alpha * A * B + beta * C -void Gemm::VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc) {} - -// C = A * B + C -void Gemm::VecWriteWithAdd(int n, float *c, float *C, int ldc) { - int nc1 = n / 16; - int _nc1 = n % 16; - - asm volatile( - "subs %[nc1], %[nc1], #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" - - "vld1.32 {q0, q1}, [%[c]]! \n\t" - "vld1.32 {q2, q3}, [%[C]] \n\t" - "vadd.f32 q10, q0, q2 \n\t" - "vadd.f32 q11, q1, q3 \n\t" - "vst1.32 {q10, q11}, [%[C]]! \n\t" - - "vld1.32 {q4, q5}, [%[c]]! \n\t" - "vld1.32 {q6, q7}, [%[C]] \n\t" - "vadd.f32 q12, q4, q6 \n\t" - "vadd.f32 q13, q5, q7 \n\t" - "vst1.32 {q12, q13}, [%[C]]! 
\n\t" - - "subs %[nc1], %[nc1], #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" - - : [C] "+r"(C), [c] "+r"(c) - : [nc1] "r"(nc1) - : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11", - "q12", "q13"); - - if (_nc1 != 0) { - for (int j = 0; j < _nc1; j++) { - *C++ += *c++; - } - } -} + if (bias1 == nullptr) { + asm volatile( + "vmov.f32 q14, #0.0 \n\t" + "subs %[mc], %[mc], #1 \n\t" + "blt end_mc_%= \n\t" + "loop_mc_%=: \n\t" -// C = A * B + C, relu(C) -void Gemm::VecWriteWithAddRelu(int n, float *c, float *C, int ldc) { - int nc1 = n / 16; - int _nc1 = n % 16; + "mov r5, %[nc1] \n\t" + "mov r6, %[nc2] \n\t" + "vld1.32 {d0}, [%[bias]] \n\t" + "vld1.32 {d1}, [%[p]] \n\t" + "vdup.32 q1, d0[0] \n\t" + "vdup.32 q2, d1[0] \n\t" - asm volatile( - "vmov.f32 q14, #0.0 \n\t" - "subs %[nc1], %[nc1], #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" + "subs r5, r5, #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" - "vld1.32 {q0, q1}, [%[c]]! \n\t" - "vld1.32 {q2, q3}, [%[C]] \n\t" - "vadd.f32 q10, q0, q2 \n\t" - "vadd.f32 q11, q1, q3 \n\t" - "vmax.f32 q10, q10, q14 \n\t" - "vmax.f32 q11, q11, q14 \n\t" - "vst1.32 {q10, q11}, [%[C]]! \n\t" + "pld [%[c], #32] \n\t" + "vld1.32 {q3, q4}, [%[c]]! \n\t" + "vld1.32 {q9, q10}, [%[c]]! \n\t" - "vld1.32 {q4, q5}, [%[c]]! \n\t" - "vld1.32 {q6, q7}, [%[C]] \n\t" - "vadd.f32 q12, q4, q6 \n\t" - "vadd.f32 q13, q5, q7 \n\t" - "vmax.f32 q12, q12, q14 \n\t" - "vmax.f32 q13, q13, q14 \n\t" - "vst1.32 {q12, q13}, [%[C]]! \n\t" + "vadd.f32 q3, q3, q1 \n\t" + "vadd.f32 q4, q4, q1 \n\t" + "vadd.f32 q9, q9, q1 \n\t" + "vadd.f32 q10, q10, q1 \n\t" - "subs %[nc1], %[nc1], #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" + "vmax.f32 q5, q3, q14 \n\t" + "vmin.f32 q7, q3, q14 \n\t" + "vmax.f32 q6, q4, q14 \n\t" + "vmin.f32 q8, q4, q14 \n\t" - : [C] "+r"(C), [c] "+r"(c) - : [nc1] "r"(nc1) - : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11", - "q12", "q13"); + "vmax.f32 q11, q9, q14 \n\t" + "vmin.f32 q13, q9, q14 \n\t" + "vmax.f32 q12, q10, q14 \n\t" + "vmin.f32 q15, q10, q14 \n\t" - if (_nc1 != 0) { - for (int j = 0; j < _nc1; j++) { - *C += *c; - if (*C < 0) { - *C = 0; - } - C++; - c++; - } - } -} + "vmla.f32 q5, q7, q2 \n\t" + "vmla.f32 q6, q8, q2 \n\t" + "vmla.f32 q11, q13, q2 \n\t" + "vmla.f32 q12, q15, q2 \n\t" - /* - // C = A * B, batchnorm(C) - void Gemm::VecWriteWithBn(int n, float *c, float *C, int ldc, float *scale, - float *bias) { - int nc1 = n / 16; - int _nc1 = n % 16; - int nc2 = _nc1 / 4; - int nc3 = 16 - 4 * (_nc1 % 4); + "vst1.32 {q5, q6}, [%[C]]! \n\t" + "vst1.32 {q11, q12}, [%[C]]! \n\t" - asm volatile( - "subs %[nc1], %[nc1], #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" - - "vld1.32 {q0, q1}, [%[c]]! \n\t" - "vld1.32 {q2, q3}, [%[scale]]! \n\t" - "vld1.32 {q10, q11}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q2 \n\t" - "vmla.f32 q11, q1, q3 \n\t" - "vst1.32 {q10, q11}, [%[C]]! \n\t" - - "vld1.32 {q4, q5}, [%[c]]! \n\t" - "vld1.32 {q6, q7}, [%[scale]]! \n\t" - "vld1.32 {q12, q13}, [%[bias]]! \n\t" - "vmla.f32 q12, q4, q6 \n\t" - "vmla.f32 q13, q5, q7 \n\t" - "vst1.32 {q12, q13}, [%[C]]! \n\t" - - "subs %[nc1], %[nc1], #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" - - "subs %[nc2], %[nc2], #1 \n\t" - "blt end_nc2_%= \n\t" - "loop_nc2_%=: \n\t" - - "vld1.32 {q0}, [%[c]]! \n\t" - "vld1.32 {q1}, [%[scale]]! \n\t" - "vld1.32 {q10}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q1 \n\t" - "vst1.32 {q10}, [%[C]]! 
\n\t" - - "subs %[nc2], %[nc2], #1 \n\t" - "bge loop_nc2_%= \n\t" - "end_nc2_%=: \n\t" - - "cmp %[nc3], #16 \n\t" - "beq end_nc3_%= \n\t" - - "sub %[c], %[c], %[nc3] \n\t" - "sub %[scale], %[scale], %[nc3] \n\t" - "sub %[bias], %[bias], %[nc3] \n\t" - "sub %[C], %[C], %[nc3] \n\t" - - "vld1.32 {q0}, [%[c]]! \n\t" - "vld1.32 {q1}, [%[scale]]! \n\t" - "vld1.32 {q10}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q1 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" - "end_nc3_%=: \n\t" - - : - : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] - "r"(nc3), [scale] "r"(scale), [bias] "r"(bias) : "memory", "q0", "q1", "q2", - "q3", "q4", "q5", "q6", "q7", "q10", "q11", "q12", "q13"); - } - - // C = A * B, batchnorm(C), relu(C) - void Gemm::VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float - *scale, float *bias) { int nc1 = n / 16; int _nc1 = n % 16; int nc2 = _nc1 / - 4; int nc3 = 16 - 4 * (_nc1 % 4); + "subs r5, r5, #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" - asm volatile( - "vmov.f32 q14, #0.0 \n\t" - "subs %[nc1], %[nc1], #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" - - "vld1.32 {q0, q1}, [%[c]]! \n\t" - "vld1.32 {q2, q3}, [%[scale]]! \n\t" - "vld1.32 {q10, q11}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q2 \n\t" - "vmla.f32 q11, q1, q3 \n\t" - "vmax.f32 q10, q10, q14 \n\t" - "vmax.f32 q11, q11, q14 \n\t" - "vst1.32 {q10, q11}, [%[C]]! \n\t" - - "vld1.32 {q4, q5}, [%[c]]! \n\t" - "vld1.32 {q6, q7}, [%[scale]]! \n\t" - "vld1.32 {q12, q13}, [%[bias]]! \n\t" - "vmla.f32 q12, q4, q6 \n\t" - "vmla.f32 q13, q5, q7 \n\t" - "vmax.f32 q12, q12, q14 \n\t" - "vmax.f32 q13, q13, q14 \n\t" - "vst1.32 {q12, q13}, [%[C]]! \n\t" - - "subs %[nc1], %[nc1], #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" - - "subs %[nc2], %[nc2], #1 \n\t" - "blt end_nc2_%= \n\t" - "loop_nc2_%=: \n\t" - - "vld1.32 {q0}, [%[c]]! \n\t" - "vld1.32 {q1}, [%[scale]]! \n\t" - "vld1.32 {q10}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q1 \n\t" - "vmax.f32 q10, q10, q14 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" - - "subs %[nc2], %[nc2], #1 \n\t" - "bge loop_nc2_%= \n\t" - "end_nc2_%=: \n\t" - - "cmp %[nc3], #16 \n\t" - "beq end_nc3_%= \n\t" - - "sub %[c], %[c], %[nc3] \n\t" - "sub %[scale], %[scale], %[nc3] \n\t" - "sub %[bias], %[bias], %[nc3] \n\t" - "sub %[C], %[C], %[nc3] \n\t" - - "vld1.32 {q0}, [%[c]]! \n\t" - "vld1.32 {q1}, [%[scale]]! \n\t" - "vld1.32 {q10}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q1 \n\t" - "vmax.f32 q10, q10, q14 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" - "end_nc3_%=: \n\t" - - : - : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] - "r"(nc3), [scale] "r"(scale), [bias] "r"(bias) : "memory", "q0", "q1", "q2", - "q3", "q4", "q5", "q6", "q7", "q10", "q11", "q12", "q13", "q14"); - } - */ + "subs r6, r6, #1 \n\t" + "blt end_nc2_%= \n\t" + "loop_nc2_%=: \n\t" -#endif // __aarch64__ -#else + "vld1.32 {q3}, [%[c]]! \n\t" + "vadd.f32 q3, q3, q1 \n\t" + "vmax.f32 q5, q3, q14 \n\t" + "vmin.f32 q7, q3, q14 \n\t" + "vmla.f32 q5, q7, q2 \n\t" + "vst1.32 {q5}, [%[C]]! 
\n\t" -void Gemm::AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) { - float *c0, *c1, *c2, *c3; - c0 = c; - c1 = c + ldc; - c2 = c + 2 * ldc; - c3 = c + 3 * ldc; - for (int p = 0; p < k; p += 1) { - // first row - c0[0] += a[0] * b[0]; - c0[1] += a[0] * b[1]; - c0[2] += a[0] * b[2]; - c0[3] += a[0] * b[3]; - - // second row - c1[0] += a[1] * b[0]; - c1[1] += a[1] * b[1]; - c1[2] += a[1] * b[2]; - c1[3] += a[1] * b[3]; - - // third row - c2[0] += a[2] * b[0]; - c2[1] += a[2] * b[1]; - c2[2] += a[2] * b[2]; - c2[3] += a[2] * b[3]; - - // fourth row - c3[0] += a[3] * b[0]; - c3[1] += a[3] * b[1]; - c3[2] += a[3] * b[2]; - c3[3] += a[3] * b[3]; - - a += 4; - b += 4; - } -} + "subs r6, r6, #1 \n\t" + "bge loop_nc2_%= \n\t" + "end_nc2_%=: \n\t" -void Gemm::AddDot4x8(int k, const float *a, const float *b, float *c, int ldc) { -} + "cmp %[nc3], #16 \n\t" + "beq end_nc3_%= \n\t" -void Gemm::WriteBasic(int mc, int nc, float *c, float *C, int ldc) {} + "sub %[c], %[c], %[nc3] \n\t" + "sub %[C], %[C], %[nc3] \n\t" -void Gemm::WriteWithAlphaBeta(int mc, int nc, float *c, float *C, int ldc) {} + "vld1.32 {q4}, [%[c]]! \n\t" + "vadd.f32 q4, q4, q1 \n\t" + "vmax.f32 q6, q4, q14 \n\t" + "vmin.f32 q8, q4, q14 \n\t" + "vmla.f32 q6, q8, q2 \n\t" + "vst1.32 {q6}, [%[C]]! \n\t" + "end_nc3_%=: \n\t" -void Gemm::WriteWithAdd(int mc, int nc, float *c, float *C, int ldc) {} + "add %[p], %[p], #4 \n\t" + "add %[bias], %[bias], #4 \n\t" + "add %[c], %[c], %[step1] \n\t" + "add %[C], %[C], %[step] \n\t" -void Gemm::WriteWithAddV1(int mc, int nc, float *c, float *C, int ldc, - float *bias) {} + "subs %[mc], %[mc], #1 \n\t" + "bge loop_mc_%= \n\t" + "end_mc_%=: \n\t" -void Gemm::WriteWithAddRelu(int mc, int nc, float *c, float *C, int ldc) {} + : + : [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1), [nc2] "r"(nc2), + [nc3] "r"(nc3), [step] "r"(step), [step1] "r"(step1), [p] "r"(p), + [bias] "r"(bias), [bias1] "r"(bias1) + : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8"); + } else { + asm volatile( + "vmov.f32 q14, #0.0 \n\t" + "subs %[mc], %[mc], #1 \n\t" + "blt end_mc_%= \n\t" + "loop_mc_%=: \n\t" -void Gemm::WriteWithAddReluV1(int mc, int nc, float *c, float *C, int ldc, - float *bias) {} + "mov r5, %[nc1] \n\t" + "mov r6, %[nc2] \n\t" + "vld1.32 {d0}, [%[bias]] \n\t" + "vld1.32 {d1}, [%[p]] \n\t" + "vdup.32 q1, d0[0] \n\t" + "vdup.32 q2, d1[0] \n\t" -void Gemm::WriteWithAddPRelu(int mc, int nc, float *c, float *C, int ldc, - float *p, std::string mode, float *bias, - float *bias1) {} + "subs r5, r5, #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" -void Gemm::WriteWithBn(int mc, int nc, float *c, float *C, int ldc, - float *new_scale, float *new_bias) {} + "pld [%[c], #32] \n\t" + "pld [%[bias1], #32] \n\t" + "vld1.32 {q3, q4}, [%[c]]! \n\t" + "vld1.32 {q9, q10}, [%[bias1]]! \n\t" + "vadd.f32 q3, q3, q1 \n\t" + "vadd.f32 q4, q4, q1 \n\t" + "vadd.f32 q3, q3, q9 \n\t" + "vadd.f32 q4, q4, q10 \n\t" + "vmax.f32 q5, q3, q14 \n\t" + "vmin.f32 q7, q3, q14 \n\t" + "vmax.f32 q6, q4, q14 \n\t" + "vmin.f32 q8, q4, q14 \n\t" + "vmla.f32 q5, q7, q2 \n\t" + "vmla.f32 q6, q8, q2 \n\t" + "vst1.32 {q5, q6}, [%[C]]! \n\t" -void Gemm::WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, - float *new_scale, float *new_bias) {} -void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, - float *new_scale, float *new_bias, float *bias1) { -} + "vld1.32 {q3, q4}, [%[c]]! \n\t" + "vld1.32 {q9, q10}, [%[bias1]]! 
\n\t" + "vadd.f32 q3, q3, q1 \n\t" + "vadd.f32 q4, q4, q1 \n\t" + "vadd.f32 q3, q3, q9 \n\t" + "vadd.f32 q4, q4, q10 \n\t" + "vmax.f32 q5, q3, q14 \n\t" + "vmin.f32 q7, q3, q14 \n\t" + "vmax.f32 q6, q4, q14 \n\t" + "vmin.f32 q8, q4, q14 \n\t" + "vmla.f32 q5, q7, q2 \n\t" + "vmla.f32 q6, q8, q2 \n\t" + "vst1.32 {q5, q6}, [%[C]]! \n\t" -#endif // __ARM_NEON + "subs r5, r5, #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" -// 32位 float 矩阵乘法 -void Gemm::Sgemm(int m, int n, int k, float alpha, const float *A, int lda, - const float *B, int ldb, float beta, float *C, int ldc, - bool relu, float *bias) { - // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73) - // L2 cache is 0.5~4 Mib (Contex-A72 cluster) - int L1 = 32 * 1024; - int L2 = 512 * 1024; + "subs r6, r6, #1 \n\t" + "blt end_nc2_%= \n\t" + "loop_nc2_%=: \n\t" - KC = k; - MC = L1 / (KC * sizeof(float)); - NC = L2 / (KC * sizeof(float)); + "vld1.32 {q3}, [%[c]]! \n\t" + "vld1.32 {q9}, [%[bias1]]! \n\t" + "vadd.f32 q3, q3, q1 \n\t" + "vadd.f32 q3, q3, q9 \n\t" + "vmax.f32 q5, q3, q14 \n\t" + "vmin.f32 q7, q3, q14 \n\t" + "vmla.f32 q5, q7, q2 \n\t" + "vst1.32 {q5}, [%[C]]! \n\t" - // make sure MC is multiple of MR, and NC is multiple of NR - if (MC == 0) { - MC = MR; - } else { - int mblock_num = (m + MC - 1) / MC; - MC = (m + mblock_num - 1) / mblock_num; - MC = (MC + MR - 1) / MR * MR; - } - // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n"; - if (NC == 0) { - NC = NR; - } else { - int nblock_num = (n + NC - 1) / NC; - NC = (n + nblock_num - 1) / nblock_num; - NC = (NC + NR - 1) / NR * NR; - } - // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n"; + "subs r6, r6, #1 \n\t" + "bge loop_nc2_%= \n\t" + "end_nc2_%=: \n\t" - packedA = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); - packedB = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); - packedC = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * MC * NC)); - zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - memset(static_cast(zero), 0, sizeof(float) * KC); + "cmp %[nc3], #16 \n\t" + "beq end_nc3_%= \n\t" - int mc, nc; - for (int j = 0; j < n; j += NC) { - nc = s_min(n - j, NC); -#if __aarch64__ - // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB); - PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB); -#else - PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB); -#endif - for (int i = 0; i < m; i += MC) { - mc = s_min(m - i, MC); -#if __aarch64__ - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); - // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA); -#else - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); -#endif - if (bias == nullptr) { - InnerKernelWithBias(mc, nc, alpha, packedA, packedB, beta, packedC, - &C(i, j), ldc, relu, nullptr); - } else { - InnerKernelWithBias(mc, nc, alpha, packedA, packedB, beta, packedC, - &C(i, j), ldc, relu, bias + i); - } - } - } + "sub %[c], %[c], %[nc3] \n\t" + "sub %[C], %[C], %[nc3] \n\t" + "sub %[bias1], %[bias1], %[nc3] \n\t" - paddle_mobile::memory::Free(packedA); - paddle_mobile::memory::Free(packedB); - paddle_mobile::memory::Free(packedC); - paddle_mobile::memory::Free(zero); -} + "vld1.32 {q4}, [%[c]]! \n\t" + "vld1.32 {q10}, [%[bias1]]! \n\t" + "vadd.f32 q4, q4, q1 \n\t" + "vadd.f32 q4, q4, q10 \n\t" + "vmax.f32 q6, q4, q14 \n\t" + "vmin.f32 q8, q4, q14 \n\t" + "vmla.f32 q6, q8, q2 \n\t" + "vst1.32 {q6}, [%[C]]! 
\n\t" + "end_nc3_%=: \n\t" -void Gemm::SgemmWithBn(int m, int n, int k, float alpha, const float *A, - int lda, const float *B, int ldb, float beta, float *C, - int ldc, bool relu, float *new_scale, float *new_bias, - float *bias) { - // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73) - // L2 cache is 0.5~4 Mib (Contex-A72 cluster) - int L1 = 32 * 1024; - int L2 = 512 * 1024; + "add %[p], %[p], #4 \n\t" + "add %[bias], %[bias], #4 \n\t" + "add %[c], %[c], %[step1] \n\t" + "add %[C], %[C], %[step] \n\t" + "add %[bias1], %[bias1], %[step] \n\t" - KC = k; - MC = L1 / (KC * sizeof(float)); - NC = L2 / (KC * sizeof(float)); + "subs %[mc], %[mc], #1 \n\t" + "bge loop_mc_%= \n\t" + "end_mc_%=: \n\t" - // make sure MC is multiple of MR, and NC is multiple of NR - if (MC == 0) { - MC = MR; - } else { - int mblock_num = (m + MC - 1) / MC; - MC = (m + mblock_num - 1) / mblock_num; - MC = (MC + MR - 1) / MR * MR; - } - // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n"; - if (NC == 0) { - NC = NR; - } else { - int nblock_num = (n + NC - 1) / NC; - NC = (n + nblock_num - 1) / nblock_num; - NC = (NC + NR - 1) / NR * NR; + : + : [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1), [nc2] "r"(nc2), + [nc3] "r"(nc3), [step] "r"(step), [step1] "r"(step1), [p] "r"(p), + [bias] "r"(bias), [bias1] "r"(bias1) + : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10"); } - // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n"; - - packedA = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); - packedB = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); - packedC = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * MC * NC)); - zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - memset(static_cast(zero), 0, sizeof(float) * KC); +} - int mc, nc; - for (int j = 0; j < n; j += NC) { - nc = s_min(n - j, NC); -#if __aarch64__ - // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB); - PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB); -#else - PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB); -#endif - for (int i = 0; i < m; i += MC) { - mc = s_min(m - i, MC); -#if __aarch64__ - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); - // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA); -#else - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); -#endif - if (bias == nullptr) { - InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC, - &C(i, j), ldc, relu, new_scale + i, new_bias + i); - } else { - InnerKernelWithBnAdd(mc, nc, alpha, packedA, packedB, beta, packedC, - &C(i, j), ldc, relu, new_scale + i, new_bias + i, - bias + i * ldc + j); +// C = A * B, batchnorm(C) +void Gemm::WriteWithBn(int mc, int nc, float *c, float *C, int ldc, + float *scale, float *bias) { + if (nc < 4) { + for (int i = 0; i < mc; ++i) { + for (int j = 0; j < nc; ++j) { + *C = (*c) * (*scale) + (*bias); + C++; + c++; } + C += (ldc - nc); + c += (NC - nc); + scale++; + bias++; } + return; } - paddle_mobile::memory::Free(packedA); - paddle_mobile::memory::Free(packedB); - paddle_mobile::memory::Free(packedC); - paddle_mobile::memory::Free(zero); -} + int volatile nc1 = nc / 16; + int _nc1 = nc % 16; + int volatile nc2 = _nc1 / 4; + int volatile nc3 = 16 - 4 * (_nc1 % 4); + int volatile step = 4 * (ldc - nc); + int volatile step1 = 4 * (NC - nc); -void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda, - const float *B, int ldb, float 
*C, int ldc, float *p, - std::string mode, float *bias, float *bias1) { - // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73) - // L2 cache is 0.5~4 Mib (Contex-A72 cluster) - int L1 = 32 * 1024; - int L2 = 0.5 * 1024 * 1024; + asm volatile( + "subs %[mc], %[mc], #1 \n\t" + "blt end_mc_%= \n\t" + "loop_mc_%=: \n\t" - KC = k; - MC = L1 / (KC * sizeof(float)); - NC = L2 / (KC * sizeof(float)); + "mov r5, %[nc1] \n\t" + "mov r6, %[nc2] \n\t" + "vld1.32 {d0}, [%[scale]] \n\t" + "vld1.32 {d1}, [%[bias]] \n\t" + "vdup.32 q1, d0[0] \n\t" + "vdup.32 q2, d1[0] \n\t" - // make sure MC is multiple of MR, and NC is multiple of NR - if (MC == 0) { - MC = MR; - } else { - int mblock_num = (m + MC - 1) / MC; - MC = (m + mblock_num - 1) / mblock_num; - MC = (MC + MR - 1) / MR * MR; - } - // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n"; - if (NC == 0) { - NC = NR; - } else { - int nblock_num = (n + NC - 1) / NC; - NC = (n + nblock_num - 1) / nblock_num; - NC = (NC + NR - 1) / NR * NR; - } - // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n"; + "subs r5, r5, #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" - packedA = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); - packedB = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); - packedC = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * MC * NC)); - zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); + "vld1.32 {q3, q4}, [%[c]]! \n\t" + "vmul.f32 q10, q3, q1 \n\t" + "vmul.f32 q11, q4, q1 \n\t" + "vadd.f32 q10, q10, q2 \n\t" + "vadd.f32 q11, q11, q2 \n\t" + "vst1.32 {q10, q11}, [%[C]]! \n\t" - for (int l = 0; l < KC; ++l) { - zero[l] = 0; - } + "vld1.32 {q5, q6}, [%[c]]! \n\t" + "vmul.f32 q12, q5, q1 \n\t" + "vmul.f32 q13, q6, q1 \n\t" + "vadd.f32 q12, q12, q2 \n\t" + "vadd.f32 q13, q13, q2 \n\t" + "vst1.32 {q12, q13}, [%[C]]! \n\t" - int mc, nc; - for (int j = 0; j < n; j += NC) { - nc = s_min(n - j, NC); -#if __aarch64__ - // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB); - PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB); -#else - PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB); -#endif - for (int i = 0; i < m; i += MC) { - mc = s_min(m - i, MC); -#if __aarch64__ - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); - // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA); -#else - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); -#endif - if (bias1 == nullptr) { - InnerKernelWithPRelu(mc, nc, packedA, packedB, packedC, &C(i, j), ldc, - p + i, mode, bias + i, nullptr); - } else { - InnerKernelWithPRelu(mc, nc, packedA, packedB, packedC, &C(i, j), ldc, - p + i, mode, bias + i, bias1 + i * ldc + j); - } - } - } + "subs r5, r5, #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" - paddle_mobile::memory::Free(packedA); - paddle_mobile::memory::Free(packedB); - paddle_mobile::memory::Free(packedC); - paddle_mobile::memory::Free(zero); + "subs r6, r6, #1 \n\t" + "blt end_nc2_%= \n\t" + "loop_nc2_%=: \n\t" + + "vld1.32 {q7}, [%[c]]! \n\t" + "vmul.f32 q10, q7, q1 \n\t" + "vadd.f32 q10, q10, q2 \n\t" + "vst1.32 {q10}, [%[C]]! \n\t" + + "subs r6, r6, #1 \n\t" + "bge loop_nc2_%= \n\t" + "end_nc2_%=: \n\t" + + "cmp %[nc3], #16 \n\t" + "beq end_nc3_%= \n\t" + + "sub %[c], %[c], %[nc3] \n\t" + "sub %[C], %[C], %[nc3] \n\t" + + "vld1.32 {q8}, [%[c]]! \n\t" + "vmul.f32 q11, q8, q1 \n\t" + "vadd.f32 q11, q11, q2 \n\t" + "vst1.32 {q11}, [%[C]]! 
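// Sgemm, SgemmWithBn and SgemmWithPRelu above all size their panels from the
// cache budget: roughly MC = L1 / (KC * sizeof(float)) rows of packed A and
// NC = L2 / (KC * sizeof(float)) columns of packed B, evened out across blocks
// and rounded up to multiples of MR / NR. A minimal sketch of that computation,
// using the L1/L2 figures quoted in this file (the helper is illustrative).
static void compute_block_sizes(int m, int n, int k, int MR, int NR,
                                int *MC, int *NC) {
  const int L1 = 32 * 1024;   // L1 data cache budget per core
  const int L2 = 512 * 1024;  // shared L2 budget
  const int KC = k;
  *MC = L1 / (KC * static_cast<int>(sizeof(float)));
  *NC = L2 / (KC * static_cast<int>(sizeof(float)));
  if (*MC == 0) {
    *MC = MR;
  } else {
    int mblock_num = (m + *MC - 1) / *MC;     // how many row blocks are needed
    *MC = (m + mblock_num - 1) / mblock_num;  // even out the block sizes
    *MC = (*MC + MR - 1) / MR * MR;           // round up to a multiple of MR
  }
  if (*NC == 0) {
    *NC = NR;
  } else {
    int nblock_num = (n + *NC - 1) / *NC;
    *NC = (n + nblock_num - 1) / nblock_num;
    *NC = (*NC + NR - 1) / NR * NR;           // round up to a multiple of NR
  }
}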
\n\t" + "end_nc3_%=: \n\t" + + "add %[scale], %[scale], #4 \n\t" + "add %[bias], %[bias], #4 \n\t" + "add %[c], %[c], %[step1] \n\t" + "add %[C], %[C], %[step] \n\t" + + "subs %[mc], %[mc], #1 \n\t" + "bge loop_mc_%= \n\t" + "end_mc_%=: \n\t" + + : + : [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1), [nc2] "r"(nc2), + [nc3] "r"(nc3), [step] "r"(step), [step1] "r"(step1), + [scale] "r"(scale), [bias] "r"(bias) + : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q10", "q11", "q12", "q13"); } -// 32位 float 矩阵乘法 -void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, - const float *B, int ldb, float beta, float *C, int ldc, - bool relu, float *bias) { -#ifndef __aarch64__ - if (m == 1 && bias == nullptr) { - return VectorKernel(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, relu); +// C = A * B, batchnorm(C), relu(C) +void Gemm::WriteWithBnRelu(int mc, int nc, float *c, float *C, int ldc, + float *scale, float *bias) { + if (nc < 4) { + for (int i = 0; i < mc; ++i) { + for (int j = 0; j < nc; ++j) { + *C = (*c) * (*scale) + (*bias); + if (*C < 0) { + *C = 0; + } + C++; + c++; + } + C += (ldc - nc); + c += (NC - nc); + scale++; + bias++; + } + return; } -#endif // __aarch64__ -#ifdef _OPENMP - int max_threads = omp_get_max_threads(); -#else - int max_threads = 1; -#endif - // int L1 = 64 / max_threads * 1024; - int L = (max_threads > 2) ? 64 : 32; - int L1 = L / max_threads * 1024; - KC = k; - zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - memset(static_cast(zero), 0, sizeof(float) * KC); - if (m > n) { - // 对 A 分块 - MC = L1 / (KC * sizeof(float)); - if (MC == 0) { - MC = MR; - } else { - int mblock_num = (m + MC - 1) / MC; - MC = (m + mblock_num - 1) / mblock_num; - MC = (MC + MR - 1) / MR * MR; - } - // 补齐 B - NC = (n + NR - 1) / NR * NR; + int nc1 = nc / 16; + int _nc1 = nc % 16; + int nc2 = _nc1 / 4; + int nc3 = 16 - 4 * (_nc1 % 4); + int step = 4 * (ldc - nc); + int step1 = 4 * (NC - nc); -#if __aarch64__ - procPackA = &Gemm::PackMatrixA_6r; - procPackB = &Gemm::PackMatrixB_omp_16c; - procAddDot = &Gemm::AddDot6x16; -#else - procPackA = &Gemm::PackMatrixA_6r; - procPackB = &Gemm::PackMatrixB_omp_8c; - procAddDot = &Gemm::AddDot6x8; -#endif + asm volatile( + "vmov.f32 q14, #0.0 \n\t" + "subs %[mc], %[mc], #1 \n\t" + "blt end_mc_%= \n\t" + "loop_mc_%=: \n\t" - packedB = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); - (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB); - packedA = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); - } else { - // 对 B 分块 - NC = L1 / (KC * sizeof(float)); - if (NC == 0) { - NC = NR; - } else { - int nblock_num = (n + NC - 1) / NC; - NC = (n + nblock_num - 1) / nblock_num; - NC = (NC + NR - 1) / NR * NR; - } - // 补齐 A - MC = (m + MR - 1) / MR * MR; + "mov r5, %[nc1] \n\t" + "mov r6, %[nc2] \n\t" + "vld1.32 {d0}, [%[scale]] \n\t" + "vld1.32 {d1}, [%[bias]] \n\t" + "vdup.32 q1, d0[0] \n\t" + "vdup.32 q2, d1[0] \n\t" -#if __aarch64__ - procPackA = &Gemm::PackMatrixA_omp_6r; - procPackB = &Gemm::PackMatrixB_16c; - procAddDot = &Gemm::AddDot6x16; -#else + "subs r5, r5, #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" - procPackA = &Gemm::PackMatrixA_omp_6r; - procPackB = &Gemm::PackMatrixB_8c; - procAddDot = &Gemm::AddDot6x8; -#endif + "vld1.32 {q3, q4}, [%[c]]! 
\n\t" + "vmul.f32 q10, q3, q1 \n\t" + "vmul.f32 q11, q4, q1 \n\t" + "vadd.f32 q10, q10, q2 \n\t" + "vadd.f32 q11, q11, q2 \n\t" + "vmax.f32 q10, q10, q14 \n\t" + "vmax.f32 q11, q11, q14 \n\t" + "vst1.32 {q10, q11}, [%[C]]! \n\t" - packedA = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); - (*this.*procPackA)(m, KC, m % MR, A, lda, packedA); - packedB = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); - } - packedC = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads)); + "vld1.32 {q5, q6}, [%[c]]! \n\t" + "vmul.f32 q12, q5, q1 \n\t" + "vmul.f32 q13, q6, q1 \n\t" + "vadd.f32 q12, q12, q2 \n\t" + "vadd.f32 q13, q13, q2 \n\t" + "vmax.f32 q12, q12, q14 \n\t" + "vmax.f32 q13, q13, q14 \n\t" + "vst1.32 {q12, q13}, [%[C]]! \n\t" - if (m > n) { -#pragma omp parallel for - for (int i = 0; i < m; i += MC) { -#ifdef _OPENMP - int local_threads = omp_get_thread_num(); -#else - int local_threads = 0; -#endif + "subs r5, r5, #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" - int mc; - mc = s_min(m - i, MC); - float *local_A = packedA + MC * KC * local_threads; - float *local_C = packedC + MC * NC * local_threads; - (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A); - if (bias == nullptr) { - InnerKernelWithBias(mc, n, alpha, local_A, packedB, beta, local_C, - &C(i, 0), ldc, relu, nullptr); - } else { - InnerKernelWithBias(mc, n, alpha, local_A, packedB, beta, local_C, - &C(i, 0), ldc, relu, bias + i); - } - } - } else { -#pragma omp parallel for - for (int j = 0; j < n; j += NC) { -#ifdef _OPENMP - int local_threads = omp_get_thread_num(); -#else - int local_threads = 0; -#endif + "subs r6, r6, #1 \n\t" + "blt end_nc2_%= \n\t" + "loop_nc2_%=: \n\t" - int nc; - nc = s_min(n - j, NC); - float *local_B = packedB + KC * NC * local_threads; - float *local_C = packedC + MC * NC * local_threads; - (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B); - InnerKernelWithBias(m, nc, alpha, packedA, local_B, beta, local_C, - &C(0, j), ldc, relu, bias); - } - } + "vld1.32 {q7}, [%[c]]! \n\t" + "vmul.f32 q10, q7, q1 \n\t" + "vadd.f32 q10, q10, q2 \n\t" + "vmax.f32 q10, q10, q14 \n\t" + "vst1.32 {q10}, [%[C]]! 
\n\t" - paddle_mobile::memory::Free(packedA); - paddle_mobile::memory::Free(packedB); - paddle_mobile::memory::Free(packedC); - paddle_mobile::memory::Free(zero); -} + "subs r6, r6, #1 \n\t" + "bge loop_nc2_%= \n\t" + "end_nc2_%=: \n\t" -void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, - int lda, const float *B, int ldb, float beta, - float *C, int ldc, bool relu, float *new_scale, - float *new_bias, float *bias) { -#ifdef _OPENMP - int max_threads = omp_get_max_threads(); -#else - int max_threads = 1; -#endif + "cmp %[nc3], #16 \n\t" + "beq end_nc3_%= \n\t" - int L1 = 64 / max_threads * 1024; - KC = k; - zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - memset(static_cast(zero), 0, sizeof(float) * KC); - if (m > n) { - // 对 A 分块 - MC = L1 / (KC * sizeof(float)); - if (MC == 0) { - MC = MR; - } else { - int mblock_num = (m + MC - 1) / MC; - MC = (m + mblock_num - 1) / mblock_num; - MC = (MC + MR - 1) / MR * MR; - } - // 补齐 B - NC = (n + NR - 1) / NR * NR; + "sub %[c], %[c], %[nc3] \n\t" + "sub %[C], %[C], %[nc3] \n\t" -#if __aarch64__ - procPackA = &Gemm::PackMatrixA_6r; - procPackB = &Gemm::PackMatrixB_omp_16c; - procAddDot = &Gemm::AddDot6x16; -#else - procPackA = &Gemm::PackMatrixA_6r; - procPackB = &Gemm::PackMatrixB_omp_8c; - procAddDot = &Gemm::AddDot6x8; -#endif + "vld1.32 {q8}, [%[c]]! \n\t" + "vmul.f32 q11, q8, q1 \n\t" + "vadd.f32 q11, q11, q2 \n\t" + "vmax.f32 q11, q11, q14 \n\t" + "vst1.32 {q11}, [%[C]]! \n\t" + "end_nc3_%=: \n\t" - packedB = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); - (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB); - packedA = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); - } else { - // 对 B 分块 - NC = L1 / (KC * sizeof(float)); - if (NC == 0) { - NC = NR; - } else { - int nblock_num = (n + NC - 1) / NC; - NC = (n + nblock_num - 1) / nblock_num; - NC = (NC + NR - 1) / NR * NR; - } - // 补齐 A - MC = (m + MR - 1) / MR * MR; + "add %[scale], %[scale], #4 \n\t" + "add %[bias], %[bias], #4 \n\t" + "add %[c], %[c], %[step1] \n\t" + "add %[C], %[C], %[step] \n\t" -#if __aarch64__ - procPackA = &Gemm::PackMatrixA_omp_6r; - procPackB = &Gemm::PackMatrixB_16c; - procAddDot = &Gemm::AddDot6x16; -#else - procPackA = &Gemm::PackMatrixA_omp_6r; - procPackB = &Gemm::PackMatrixB_8c; - procAddDot = &Gemm::AddDot6x8; -#endif + "subs %[mc], %[mc], #1 \n\t" + "bge loop_mc_%= \n\t" + "end_mc_%=: \n\t" - packedA = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); - (*this.*procPackA)(m, KC, m % MR, A, lda, packedA); - packedB = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); - } - packedC = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads)); + : + : [C] "r"(C), [c] "r"(c), [mc] "r"(mc), [nc1] "r"(nc1), [nc2] "r"(nc2), + [nc3] "r"(nc3), [step] "r"(step), [step1] "r"(step1), + [scale] "r"(scale), [bias] "r"(bias) + : "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q10", "q11", "q12", "q13", "q14"); +} - if (m > n) { -#pragma omp parallel for - for (int i = 0; i < m; i += MC) { -#ifdef _OPENMP - int local_threads = omp_get_thread_num(); -#else - int local_threads = 0; -#endif +// C = A * B, batchnorm(C),C = C + bias; relu(C) +void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, + float *new_scale, float *new_bias, float *bias) { + int nc1 = nc / 4; + int _nc1 = nc % 4; - int mc; - mc = s_min(m - i, MC); - float *local_A = 
packedA + MC * KC * local_threads; - float *local_C = packedC + MC * NC * local_threads; - (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A); - if (bias == nullptr) { - InnerKernelWithBn(mc, n, alpha, local_A, packedB, beta, local_C, - &C(i, 0), ldc, relu, new_scale + i, new_bias + i); - } else { - InnerKernelWithBnAdd(mc, n, alpha, local_A, packedB, beta, local_C, - &C(i, 0), ldc, relu, new_scale + i, new_bias + i, - bias + i * ldc); - } + float *c_ptr, *C_ptr, *bias_ptr; + float32x4_t cv; + float32x4_t nbias; + float32x2_t scale; + float32x4_t biasv; + float32x4_t zero = vdupq_n_f32(0.0); + for (int i = 0; i < mc; ++i) { + c_ptr = c + i * NC; + C_ptr = C + i * ldc; + bias_ptr = bias + i * ldc; + nbias = vld1q_dup_f32(new_bias); + scale = vld1_dup_f32(new_scale); + new_bias++; + new_scale++; + float scale0 = vget_lane_f32(scale, 0); + for (int j = 0; j < nc1; ++j) { + cv = vld1q_f32(c_ptr); + biasv = vld1q_f32(bias_ptr); + cv = vmlaq_n_f32(nbias, cv, scale0); + cv = vaddq_f32(cv, biasv); + cv = vmaxq_f32(cv, zero); + vst1q_f32(C_ptr, cv); + c_ptr += 4; + C_ptr += 4; + bias_ptr += 4; } - } else { -#pragma omp parallel for - for (int j = 0; j < n; j += NC) { -#ifdef _OPENMP - int local_threads = omp_get_thread_num(); -#else - int local_threads = 0; -#endif - - int nc; - nc = s_min(n - j, NC); - float *local_B = packedB + KC * NC * local_threads; - float *local_C = packedC + MC * NC * local_threads; - (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B); - if (bias == nullptr) { - InnerKernelWithBn(m, nc, alpha, packedA, local_B, beta, local_C, - &C(0, j), ldc, relu, new_scale, new_bias); - } else { - InnerKernelWithBnAdd(m, nc, alpha, packedA, local_B, beta, local_C, - &C(0, j), ldc, relu, new_scale, new_bias, - bias + j); + if (_nc1 != 0) { + cv = vld1q_f32(c_ptr); + biasv = vld1q_f32(bias_ptr); + cv = vmlaq_n_f32(nbias, cv, scale0); + cv = vaddq_f32(cv, biasv); + cv = vmaxq_f32(cv, zero); + if (_nc1 >= 1) { + vst1q_lane_f32(C_ptr, cv, 0); + C_ptr++; + } + if (_nc1 >= 2) { + vst1q_lane_f32(C_ptr, cv, 1); + C_ptr++; + } + if (_nc1 >= 3) { + vst1q_lane_f32(C_ptr, cv, 2); } } } +} - paddle_mobile::memory::Free(packedA); - paddle_mobile::memory::Free(packedB); - paddle_mobile::memory::Free(packedC); - paddle_mobile::memory::Free(zero); +// C = A * B +void Gemm::VecWriteBasic(int n, float *c, float *C, int ldc) { + int nc1 = n / 16; + int _nc1 = n % 16; + int nc2 = _nc1 / 4; + int nc3 = 16 - 4 * (_nc1 % 4); + + asm volatile( + "subs %[nc1], %[nc1], #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" + + "vld1.32 {q0, q1}, [%[c]]! \n\t" + "vst1.32 {q0, q1}, [%[C]]! \n\t" + + "vld1.32 {q2, q3}, [%[c]]! \n\t" + "vst1.32 {q2, q3}, [%[C]]! \n\t" + + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + "subs %[nc2], %[nc2], #1 \n\t" + "blt end_nc2_%= \n\t" + "loop_nc2_%=: \n\t" + + "vld1.32 {q4}, [%[c]]! \n\t" + "vst1.32 {q4}, [%[C]]! \n\t" + + "subs %[nc2], %[nc2], #1 \n\t" + "bge loop_nc2_%= \n\t" + "end_nc2_%=: \n\t" + + "cmp %[nc3], #16 \n\t" + "beq end_nc3_%= \n\t" + "sub %[c], %[c], %[nc3] \n\t" + "sub %[C], %[C], %[nc3] \n\t" + "vld1.32 {q5}, [%[c]]! \n\t" + "vst1.32 {q5}, [%[C]]! 
\n\t" + "end_nc3_%=: \n\t" + + : + : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5"); } -void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, - const float *B, int ldb, float *C, int ldc, - float *p, std::string mode, float *bias, - float *bias1) { -#ifdef _OPENMP - int max_threads = omp_get_max_threads(); -#else - int max_threads = 1; -#endif +// C = alpha * A * B + beta * C +void Gemm::VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc) {} - int L1 = 8 * 1024; - KC = k; - zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - memset(static_cast(zero), 0, sizeof(float) * KC); - if (m > n) { - // 对 A 分块 - MC = L1 / (KC * sizeof(float)); - if (MC == 0) { - MC = MR; - } else { - int mblock_num = (m + MC - 1) / MC; - MC = (m + mblock_num - 1) / mblock_num; - MC = (MC + MR - 1) / MR * MR; - } - // 补齐 B - NC = (n + NR - 1) / NR * NR; +// C = A * B + C +void Gemm::VecWriteWithAdd(int n, float *c, float *C, int ldc) { + int nc1 = n / 16; + int _nc1 = n % 16; -#if __aarch64__ - procPackA = &Gemm::PackMatrixA_6r; - procPackB = &Gemm::PackMatrixB_omp_16c; - procAddDot = &Gemm::AddDot6x16; -#else - procPackA = &Gemm::PackMatrixA_6r; - procPackB = &Gemm::PackMatrixB_omp_8c; - procAddDot = &Gemm::AddDot6x8; -#endif + asm volatile( + "subs %[nc1], %[nc1], #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" - packedB = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); - (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB); - packedA = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); - } else { - // 对 B 分块 - NC = L1 / (KC * sizeof(float)); - if (NC == 0) { - NC = NR; - } else { - int nblock_num = (n + NC - 1) / NC; - NC = (n + nblock_num - 1) / nblock_num; - NC = (NC + NR - 1) / NR * NR; - } - // 补齐 A - MC = (m + MR - 1) / MR * MR; + "vld1.32 {q0, q1}, [%[c]]! \n\t" + "vld1.32 {q2, q3}, [%[C]] \n\t" + "vadd.f32 q10, q0, q2 \n\t" + "vadd.f32 q11, q1, q3 \n\t" + "vst1.32 {q10, q11}, [%[C]]! \n\t" -#if __aarch64__ - procPackA = &Gemm::PackMatrixA_omp_6r; - procPackB = &Gemm::PackMatrixB_16c; - procAddDot = &Gemm::AddDot6x16; -#else - procPackA = &Gemm::PackMatrixA_omp_6r; - procPackB = &Gemm::PackMatrixB_8c; - procAddDot = &Gemm::AddDot6x8; -#endif + "vld1.32 {q4, q5}, [%[c]]! \n\t" + "vld1.32 {q6, q7}, [%[C]] \n\t" + "vadd.f32 q12, q4, q6 \n\t" + "vadd.f32 q13, q5, q7 \n\t" + "vst1.32 {q12, q13}, [%[C]]! 
\n\t" - packedA = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); - (*this.*procPackA)(m, KC, m % MR, A, lda, packedA); - packedB = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); - } - packedC = static_cast( - paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads)); + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" - if (m > n) { -#pragma omp parallel for - for (int i = 0; i < m; i += MC) { -#ifdef _OPENMP - int local_threads = omp_get_thread_num(); -#else - int local_threads = 0; -#endif + : [C] "+r"(C), [c] "+r"(c) + : [nc1] "r"(nc1) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11", + "q12", "q13"); - int mc; - mc = s_min(m - i, MC); - float *local_A = packedA + MC * KC * local_threads; - float *local_C = packedC + MC * NC * local_threads; - (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A); - if (bias1 == nullptr) { - InnerKernelWithPRelu(mc, n, local_A, packedB, local_C, &C(i, 0), ldc, - p + i, mode, bias + i, nullptr); - } else { - InnerKernelWithPRelu(mc, n, local_A, packedB, local_C, &C(i, 0), ldc, - p + i, mode, bias + i, bias1 + i * ldc); - } + if (_nc1 != 0) { + for (int j = 0; j < _nc1; j++) { + *C++ += *c++; } - } else { -#pragma omp parallel for - for (int j = 0; j < n; j += NC) { -#ifdef _OPENMP - int local_threads = omp_get_thread_num(); -#else - int local_threads = 0; -#endif + } +} - int nc; - nc = s_min(n - j, NC); - float *local_B = packedB + KC * NC * local_threads; - float *local_C = packedC + MC * NC * local_threads; - (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B); - if (bias1 == nullptr) { - InnerKernelWithPRelu(m, nc, packedA, local_B, local_C, &C(0, j), ldc, p, - mode, bias, nullptr); - } else { - InnerKernelWithPRelu(m, nc, packedA, local_B, local_C, &C(0, j), ldc, p, - mode, bias, bias1 + j); +// C = A * B + C, relu(C) +void Gemm::VecWriteWithAddRelu(int n, float *c, float *C, int ldc) { + int nc1 = n / 16; + int _nc1 = n % 16; + + asm volatile( + "vmov.f32 q14, #0.0 \n\t" + "subs %[nc1], %[nc1], #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" + + "vld1.32 {q0, q1}, [%[c]]! \n\t" + "vld1.32 {q2, q3}, [%[C]] \n\t" + "vadd.f32 q10, q0, q2 \n\t" + "vadd.f32 q11, q1, q3 \n\t" + "vmax.f32 q10, q10, q14 \n\t" + "vmax.f32 q11, q11, q14 \n\t" + "vst1.32 {q10, q11}, [%[C]]! \n\t" + + "vld1.32 {q4, q5}, [%[c]]! \n\t" + "vld1.32 {q6, q7}, [%[C]] \n\t" + "vadd.f32 q12, q4, q6 \n\t" + "vadd.f32 q13, q5, q7 \n\t" + "vmax.f32 q12, q12, q14 \n\t" + "vmax.f32 q13, q13, q14 \n\t" + "vst1.32 {q12, q13}, [%[C]]! \n\t" + + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + : [C] "+r"(C), [c] "+r"(c) + : [nc1] "r"(nc1) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11", + "q12", "q13"); + + if (_nc1 != 0) { + for (int j = 0; j < _nc1; j++) { + *C += *c; + if (*C < 0) { + *C = 0; } + C++; + c++; } } +} + +// C = A * B, batchnorm(C) +void Gemm::VecWriteWithBn(int n, float *c, float *C, int ldc, float *scale, + float *bias) { + int nc1 = n / 16; + int _nc1 = n % 16; + int nc2 = _nc1 / 4; + int nc3 = 16 - 4 * (_nc1 % 4); + + asm volatile( + "subs %[nc1], %[nc1], #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" + + "vld1.32 {q0, q1}, [%[c]]! \n\t" + "vld1.32 {q2, q3}, [%[scale]]! \n\t" + "vld1.32 {q10, q11}, [%[bias]]! \n\t" + "vmla.f32 q10, q0, q2 \n\t" + "vmla.f32 q11, q1, q3 \n\t" + "vst1.32 {q10, q11}, [%[C]]! \n\t" + + "vld1.32 {q4, q5}, [%[c]]! 
\n\t" + "vld1.32 {q6, q7}, [%[scale]]! \n\t" + "vld1.32 {q12, q13}, [%[bias]]! \n\t" + "vmla.f32 q12, q4, q6 \n\t" + "vmla.f32 q13, q5, q7 \n\t" + "vst1.32 {q12, q13}, [%[C]]! \n\t" + + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" - paddle_mobile::memory::Free(packedA); - paddle_mobile::memory::Free(packedB); - paddle_mobile::memory::Free(packedC); - paddle_mobile::memory::Free(zero); -} + "subs %[nc2], %[nc2], #1 \n\t" + "blt end_nc2_%= \n\t" + "loop_nc2_%=: \n\t" -void Gemm::AddDot6x8(int k, const float *a, const float *b, float *c, int ldc) { -#if __ARM_NEON -#if __aarch64__ + "vld1.32 {q0}, [%[c]]! \n\t" + "vld1.32 {q1}, [%[scale]]! \n\t" + "vld1.32 {q10}, [%[bias]]! \n\t" + "vmla.f32 q10, q0, q1 \n\t" + "vst1.32 {q10}, [%[C]]! \n\t" - // init C - float32x4_t cv0 = vdupq_n_f32(0.0); - float32x4_t cv1 = vdupq_n_f32(0.0); - float32x4_t cv2 = vdupq_n_f32(0.0); - float32x4_t cv3 = vdupq_n_f32(0.0); - float32x4_t cv4 = vdupq_n_f32(0.0); - float32x4_t cv5 = vdupq_n_f32(0.0); - float32x4_t cv6 = vdupq_n_f32(0.0); - float32x4_t cv7 = vdupq_n_f32(0.0); - float32x4_t cv8 = vdupq_n_f32(0.0); - float32x4_t cv9 = vdupq_n_f32(0.0); - float32x4_t cv10 = vdupq_n_f32(0.0); - float32x4_t cv11 = vdupq_n_f32(0.0); + "subs %[nc2], %[nc2], #1 \n\t" + "bge loop_nc2_%= \n\t" + "end_nc2_%=: \n\t" - float32x4_t av; - float32x4_t bv0; - float32x4_t bv1; + "cmp %[nc3], #16 \n\t" + "beq end_nc3_%= \n\t" - float32x2_t av01; - float32x2_t av23; - float32x2_t av45; + "sub %[c], %[c], %[nc3] \n\t" + "sub %[scale], %[scale], %[nc3] \n\t" + "sub %[bias], %[bias], %[nc3] \n\t" + "sub %[C], %[C], %[nc3] \n\t" - for (int p = 0; p < k; p += 1) { - av = vld1q_f32(a); - av01 = vget_low_f32(av); - av23 = vget_high_f32(av); - av45 = vld1_f32(a + 4); - bv0 = vld1q_f32(b); - bv1 = vld1q_f32(b + 4); + "vld1.32 {q0}, [%[c]]! \n\t" + "vld1.32 {q1}, [%[scale]]! \n\t" + "vld1.32 {q10}, [%[bias]]! \n\t" + "vmla.f32 q10, q0, q1 \n\t" + "vst1.32 {q10}, [%[C]]! \n\t" + "end_nc3_%=: \n\t" - cv0 = vmlaq_lane_f32(cv0, bv0, av01, 0); - cv1 = vmlaq_lane_f32(cv1, bv1, av01, 0); - cv2 = vmlaq_lane_f32(cv2, bv0, av01, 1); - cv3 = vmlaq_lane_f32(cv3, bv1, av01, 1); + : + : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3), + [scale] "r"(scale), [bias] "r"(bias) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11", + "q12", "q13"); +} - cv4 = vmlaq_lane_f32(cv4, bv0, av23, 0); - cv5 = vmlaq_lane_f32(cv5, bv1, av23, 0); - cv6 = vmlaq_lane_f32(cv6, bv0, av23, 1); - cv7 = vmlaq_lane_f32(cv7, bv1, av23, 1); +// C = A * B, batchnorm(C), relu(C) +void Gemm::VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *scale, + float *bias) { + int nc1 = n / 16; + int _nc1 = n % 16; + int nc2 = _nc1 / 4; + int nc3 = 16 - 4 * (_nc1 % 4); - cv8 = vmlaq_lane_f32(cv8, bv0, av45, 0); - cv9 = vmlaq_lane_f32(cv9, bv1, av45, 0); - cv10 = vmlaq_lane_f32(cv10, bv0, av45, 1); - cv11 = vmlaq_lane_f32(cv11, bv1, av45, 1); + asm volatile( + "vmov.f32 q14, #0.0 \n\t" + "subs %[nc1], %[nc1], #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" - a += MR; - b += NR; - } + "vld1.32 {q0, q1}, [%[c]]! \n\t" + "vld1.32 {q2, q3}, [%[scale]]! \n\t" + "vld1.32 {q10, q11}, [%[bias]]! \n\t" + "vmla.f32 q10, q0, q2 \n\t" + "vmla.f32 q11, q1, q3 \n\t" + "vmax.f32 q10, q10, q14 \n\t" + "vmax.f32 q11, q11, q14 \n\t" + "vst1.32 {q10, q11}, [%[C]]! 
\n\t" - vst1q_f32(c, cv0); - vst1q_f32(c + 4, cv1); - vst1q_f32(c + ldc, cv2); - vst1q_f32(c + ldc + 4, cv3); - vst1q_f32(c + 2 * ldc, cv4); - vst1q_f32(c + 2 * ldc + 4, cv5); - vst1q_f32(c + 3 * ldc, cv6); - vst1q_f32(c + 3 * ldc + 4, cv7); - vst1q_f32(c + 4 * ldc, cv8); - vst1q_f32(c + 4 * ldc + 4, cv9); - vst1q_f32(c + 5 * ldc, cv10); - vst1q_f32(c + 5 * ldc + 4, cv11); + "vld1.32 {q4, q5}, [%[c]]! \n\t" + "vld1.32 {q6, q7}, [%[scale]]! \n\t" + "vld1.32 {q12, q13}, [%[bias]]! \n\t" + "vmla.f32 q12, q4, q6 \n\t" + "vmla.f32 q13, q5, q7 \n\t" + "vmax.f32 q12, q12, q14 \n\t" + "vmax.f32 q13, q13, q14 \n\t" + "vst1.32 {q12, q13}, [%[C]]! \n\t" -#else + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" - const float *a_ptr, *b_ptr; - a_ptr = a; - b_ptr = b; - int kc1 = k / 8; - int kc2 = k % 8; - int step = sizeof(float) * ldc; - asm volatile( - "pld [%[a_ptr]] \n\t" - "pld [%[a_ptr], #64] \n\t" - "pld [%[b_ptr]] \n\t" - "pld [%[b_ptr], #64] \n\t" + "subs %[nc2], %[nc2], #1 \n\t" + "blt end_nc2_%= \n\t" + "loop_nc2_%=: \n\t" - "vmov.f32 q4, #0.0 \n\t" - "vmov.f32 q5, #0.0 \n\t" - "vmov.f32 q6, #0.0 \n\t" - "vmov.f32 q7, #0.0 \n\t" - "vmov.f32 q8, #0.0 \n\t" - "vmov.f32 q9, #0.0 \n\t" - "vmov.f32 q10, #0.0 \n\t" - "vmov.f32 q11, #0.0 \n\t" - "vmov.f32 q12, #0.0 \n\t" - "vmov.f32 q13, #0.0 \n\t" - "vmov.f32 q14, #0.0 \n\t" - "vmov.f32 q15, #0.0 \n\t" + "vld1.32 {q0}, [%[c]]! \n\t" + "vld1.32 {q1}, [%[scale]]! \n\t" + "vld1.32 {q10}, [%[bias]]! \n\t" + "vmla.f32 q10, q0, q1 \n\t" + "vmax.f32 q10, q10, q14 \n\t" + "vst1.32 {q10}, [%[C]]! \n\t" - "subs %[kc1], %[kc1], #1 \n\t" - "blt 2f \n\t" - "1: \n\t" + "subs %[nc2], %[nc2], #1 \n\t" + "bge loop_nc2_%= \n\t" + "end_nc2_%=: \n\t" - "pld [%[a_ptr], #128] \n\t" - "pld [%[b_ptr], #128] \n\t" + "cmp %[nc3], #16 \n\t" + "beq end_nc3_%= \n\t" - "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" - "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" + "sub %[c], %[c], %[nc3] \n\t" + "sub %[scale], %[scale], %[nc3] \n\t" + "sub %[bias], %[bias], %[nc3] \n\t" + "sub %[C], %[C], %[nc3] \n\t" - "vmla.f32 q4, q2, d0[0] \n\t" - "vmla.f32 q5, q3, d0[0] \n\t" - "vmla.f32 q6, q2, d0[1] \n\t" - "vmla.f32 q7, q3, d0[1] \n\t" - "vmla.f32 q8, q2, d1[0] \n\t" - "vmla.f32 q9, q3, d1[0] \n\t" - "vmla.f32 q10, q2, d1[1] \n\t" - "vmla.f32 q11, q3, d1[1] \n\t" - "vmla.f32 q12, q2, d2[0] \n\t" - "vmla.f32 q13, q3, d2[0] \n\t" - "vmla.f32 q14, q2, d2[1] \n\t" - "vmla.f32 q15, q3, d2[1] \n\t" + "vld1.32 {q0}, [%[c]]! \n\t" + "vld1.32 {q1}, [%[scale]]! \n\t" + "vld1.32 {q10}, [%[bias]]! \n\t" + "vmla.f32 q10, q0, q1 \n\t" + "vmax.f32 q10, q10, q14 \n\t" + "vst1.32 {q10}, [%[C]]! \n\t" + "end_nc3_%=: \n\t" - "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" - "vld1.32 {q2, q3}, [%[b_ptr]]! 
\n\t" + : + : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3), + [scale] "r"(scale), [bias] "r"(bias) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11", + "q12", "q13", "q14"); +} - "vmla.f32 q4, q2, d0[0] \n\t" - "vmla.f32 q5, q3, d0[0] \n\t" - "vmla.f32 q6, q2, d0[1] \n\t" - "vmla.f32 q7, q3, d0[1] \n\t" - "vmla.f32 q8, q2, d1[0] \n\t" - "vmla.f32 q9, q3, d1[0] \n\t" - "vmla.f32 q10, q2, d1[1] \n\t" - "vmla.f32 q11, q3, d1[1] \n\t" - "vmla.f32 q12, q2, d2[0] \n\t" - "vmla.f32 q13, q3, d2[0] \n\t" - "vmla.f32 q14, q2, d2[1] \n\t" - "vmla.f32 q15, q3, d2[1] \n\t" +#endif // __aarch64__ +#endif // __ARM_NEON - "pld [%[a_ptr], #128] \n\t" - "pld [%[b_ptr], #128] \n\t" +// 32位 float 矩阵乘法 +void Gemm::Sgemm(int m, int n, int k, float alpha, const float *A, int lda, + const float *B, int ldb, float beta, float *C, int ldc, + bool relu, float *bias) { + // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73) + // L2 cache is 0.5~4 Mib (Contex-A72 cluster) + int L1 = 32 * 1024; + int L2 = 512 * 1024; - "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" - "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" + KC = k; + MC = L1 / (KC * sizeof(float)); + NC = L2 / (KC * sizeof(float)); - "vmla.f32 q4, q2, d0[0] \n\t" - "vmla.f32 q5, q3, d0[0] \n\t" - "vmla.f32 q6, q2, d0[1] \n\t" - "vmla.f32 q7, q3, d0[1] \n\t" - "vmla.f32 q8, q2, d1[0] \n\t" - "vmla.f32 q9, q3, d1[0] \n\t" - "vmla.f32 q10, q2, d1[1] \n\t" - "vmla.f32 q11, q3, d1[1] \n\t" - "vmla.f32 q12, q2, d2[0] \n\t" - "vmla.f32 q13, q3, d2[0] \n\t" - "vmla.f32 q14, q2, d2[1] \n\t" - "vmla.f32 q15, q3, d2[1] \n\t" + // make sure MC is multiple of MR, and NC is multiple of NR + if (MC == 0) { + MC = MR; + } else { + int mblock_num = (m + MC - 1) / MC; + MC = (m + mblock_num - 1) / mblock_num; + MC = (MC + MR - 1) / MR * MR; + } + // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n"; + if (NC == 0) { + NC = NR; + } else { + int nblock_num = (n + NC - 1) / NC; + NC = (n + nblock_num - 1) / nblock_num; + NC = (NC + NR - 1) / NR * NR; + } + // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n"; + + packedA = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); + packedB = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); + packedC = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * NC)); + zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); + memset(static_cast(zero), 0, sizeof(float) * KC); + + int mc, nc; + for (int j = 0; j < n; j += NC) { + nc = s_min(n - j, NC); +#if __aarch64__ + // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB); + PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB); +#else + PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB); +#endif + for (int i = 0; i < m; i += MC) { + mc = s_min(m - i, MC); +#if __aarch64__ + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); + // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA); +#else + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); +#endif + if (bias == nullptr) { + InnerKernelWithBias(mc, nc, alpha, packedA, packedB, beta, packedC, + &C(i, j), ldc, relu, nullptr); + } else { + InnerKernelWithBias(mc, nc, alpha, packedA, packedB, beta, packedC, + &C(i, j), ldc, relu, bias + i); + } + } + } - "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" - "vld1.32 {q2, q3}, [%[b_ptr]]! 
\n\t" + paddle_mobile::memory::Free(packedA); + paddle_mobile::memory::Free(packedB); + paddle_mobile::memory::Free(packedC); + paddle_mobile::memory::Free(zero); +} - "vmla.f32 q4, q2, d0[0] \n\t" - "vmla.f32 q5, q3, d0[0] \n\t" - "vmla.f32 q6, q2, d0[1] \n\t" - "vmla.f32 q7, q3, d0[1] \n\t" - "vmla.f32 q8, q2, d1[0] \n\t" - "vmla.f32 q9, q3, d1[0] \n\t" - "vmla.f32 q10, q2, d1[1] \n\t" - "vmla.f32 q11, q3, d1[1] \n\t" - "vmla.f32 q12, q2, d2[0] \n\t" - "vmla.f32 q13, q3, d2[0] \n\t" - "vmla.f32 q14, q2, d2[1] \n\t" - "vmla.f32 q15, q3, d2[1] \n\t" +void Gemm::SgemmWithBn(int m, int n, int k, float alpha, const float *A, + int lda, const float *B, int ldb, float beta, float *C, + int ldc, bool relu, float *new_scale, float *new_bias, + float *bias) { + // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73) + // L2 cache is 0.5~4 Mib (Contex-A72 cluster) + int L1 = 32 * 1024; + int L2 = 512 * 1024; - "pld [%[a_ptr], #128] \n\t" - "pld [%[b_ptr], #128] \n\t" + KC = k; + MC = L1 / (KC * sizeof(float)); + NC = L2 / (KC * sizeof(float)); - "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" - "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" + // make sure MC is multiple of MR, and NC is multiple of NR + if (MC == 0) { + MC = MR; + } else { + int mblock_num = (m + MC - 1) / MC; + MC = (m + mblock_num - 1) / mblock_num; + MC = (MC + MR - 1) / MR * MR; + } + // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n"; + if (NC == 0) { + NC = NR; + } else { + int nblock_num = (n + NC - 1) / NC; + NC = (n + nblock_num - 1) / nblock_num; + NC = (NC + NR - 1) / NR * NR; + } + // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n"; - "vmla.f32 q4, q2, d0[0] \n\t" - "vmla.f32 q5, q3, d0[0] \n\t" - "vmla.f32 q6, q2, d0[1] \n\t" - "vmla.f32 q7, q3, d0[1] \n\t" - "vmla.f32 q8, q2, d1[0] \n\t" - "vmla.f32 q9, q3, d1[0] \n\t" - "vmla.f32 q10, q2, d1[1] \n\t" - "vmla.f32 q11, q3, d1[1] \n\t" - "vmla.f32 q12, q2, d2[0] \n\t" - "vmla.f32 q13, q3, d2[0] \n\t" - "vmla.f32 q14, q2, d2[1] \n\t" - "vmla.f32 q15, q3, d2[1] \n\t" + packedA = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); + packedB = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); + packedC = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * NC)); + zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); + memset(static_cast(zero), 0, sizeof(float) * KC); - "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" - "vld1.32 {q2, q3}, [%[b_ptr]]! 
\n\t" + int mc, nc; + for (int j = 0; j < n; j += NC) { + nc = s_min(n - j, NC); +#if __aarch64__ + // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB); + PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB); +#else + PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB); +#endif + for (int i = 0; i < m; i += MC) { + mc = s_min(m - i, MC); +#if __aarch64__ + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); + // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA); +#else + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); +#endif + if (bias == nullptr) { + InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC, + &C(i, j), ldc, relu, new_scale + i, new_bias + i); + } else { + InnerKernelWithBnAdd(mc, nc, alpha, packedA, packedB, beta, packedC, + &C(i, j), ldc, relu, new_scale + i, new_bias + i, + bias + i * ldc + j); + } + } + } - "vmla.f32 q4, q2, d0[0] \n\t" - "vmla.f32 q5, q3, d0[0] \n\t" - "vmla.f32 q6, q2, d0[1] \n\t" - "vmla.f32 q7, q3, d0[1] \n\t" - "vmla.f32 q8, q2, d1[0] \n\t" - "vmla.f32 q9, q3, d1[0] \n\t" - "vmla.f32 q10, q2, d1[1] \n\t" - "vmla.f32 q11, q3, d1[1] \n\t" - "vmla.f32 q12, q2, d2[0] \n\t" - "vmla.f32 q13, q3, d2[0] \n\t" - "vmla.f32 q14, q2, d2[1] \n\t" - "vmla.f32 q15, q3, d2[1] \n\t" + paddle_mobile::memory::Free(packedA); + paddle_mobile::memory::Free(packedB); + paddle_mobile::memory::Free(packedC); + paddle_mobile::memory::Free(zero); +} - "pld [%[a_ptr], #128] \n\t" - "pld [%[b_ptr], #128] \n\t" +void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda, + const float *B, int ldb, float *C, int ldc, float *p, + std::string mode, float *bias, float *bias1) { + // L1 data cache is 32 kib (Per Contex-A57, Contex-A72, Contex-A73) + // L2 cache is 0.5~4 Mib (Contex-A72 cluster) + int L1 = 32 * 1024; + int L2 = 0.5 * 1024 * 1024; - "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" - "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" + KC = k; + MC = L1 / (KC * sizeof(float)); + NC = L2 / (KC * sizeof(float)); - "vmla.f32 q4, q2, d0[0] \n\t" - "vmla.f32 q5, q3, d0[0] \n\t" - "vmla.f32 q6, q2, d0[1] \n\t" - "vmla.f32 q7, q3, d0[1] \n\t" - "vmla.f32 q8, q2, d1[0] \n\t" - "vmla.f32 q9, q3, d1[0] \n\t" - "vmla.f32 q10, q2, d1[1] \n\t" - "vmla.f32 q11, q3, d1[1] \n\t" - "vmla.f32 q12, q2, d2[0] \n\t" - "vmla.f32 q13, q3, d2[0] \n\t" - "vmla.f32 q14, q2, d2[1] \n\t" - "vmla.f32 q15, q3, d2[1] \n\t" + // make sure MC is multiple of MR, and NC is multiple of NR + if (MC == 0) { + MC = MR; + } else { + int mblock_num = (m + MC - 1) / MC; + MC = (m + mblock_num - 1) / mblock_num; + MC = (MC + MR - 1) / MR * MR; + } + // DLOG << "mblock_num = " << mblock_num << ", MC = " << MC << "\n"; + if (NC == 0) { + NC = NR; + } else { + int nblock_num = (n + NC - 1) / NC; + NC = (n + nblock_num - 1) / nblock_num; + NC = (NC + NR - 1) / NR * NR; + } + // DLOG << "nblock_num = " << nblock_num << ", NC = " << NC << "\n"; - "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" - "vld1.32 {q2, q3}, [%[b_ptr]]! 
\n\t" + packedA = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); + packedB = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); + packedC = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * NC)); + zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - "vmla.f32 q4, q2, d0[0] \n\t" - "vmla.f32 q5, q3, d0[0] \n\t" - "vmla.f32 q6, q2, d0[1] \n\t" - "vmla.f32 q7, q3, d0[1] \n\t" - "vmla.f32 q8, q2, d1[0] \n\t" - "vmla.f32 q9, q3, d1[0] \n\t" - "vmla.f32 q10, q2, d1[1] \n\t" - "vmla.f32 q11, q3, d1[1] \n\t" - "vmla.f32 q12, q2, d2[0] \n\t" - "vmla.f32 q13, q3, d2[0] \n\t" - "vmla.f32 q14, q2, d2[1] \n\t" - "vmla.f32 q15, q3, d2[1] \n\t" + for (int l = 0; l < KC; ++l) { + zero[l] = 0; + } - "subs %[kc1], %[kc1], #1 \n\t" - "bge 1b \n\t" - "2: \n\t" + int mc, nc; + for (int j = 0; j < n; j += NC) { + nc = s_min(n - j, NC); +#if __aarch64__ + // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB); + PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB); +#else + PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB); +#endif + for (int i = 0; i < m; i += MC) { + mc = s_min(m - i, MC); +#if __aarch64__ + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); + // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA); +#else + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); +#endif + if (bias1 == nullptr) { + InnerKernelWithPRelu(mc, nc, packedA, packedB, packedC, &C(i, j), ldc, + p + i, mode, bias + i, nullptr); + } else { + InnerKernelWithPRelu(mc, nc, packedA, packedB, packedC, &C(i, j), ldc, + p + i, mode, bias + i, bias1 + i * ldc + j); + } + } + } - "subs %[kc2], %[kc2], #1 \n\t" - "blt 4f \n\t" - "3: \n\t" + paddle_mobile::memory::Free(packedA); + paddle_mobile::memory::Free(packedB); + paddle_mobile::memory::Free(packedC); + paddle_mobile::memory::Free(zero); +} - "vld1.32 {d0-d2}, [%[a_ptr]]! \n\t" - "vld1.32 {q2, q3}, [%[b_ptr]]! \n\t" +// 32位 float 矩阵乘法 +void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, + const float *B, int ldb, float beta, float *C, int ldc, + bool relu, float *bias) { +#ifndef __aarch64__ + if (m == 1 && bias == nullptr) { + return VectorKernel(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, relu); + } +#endif // __aarch64__ +#ifdef _OPENMP + int max_threads = omp_get_max_threads(); +#else + int max_threads = 1; +#endif - "vmla.f32 q4, q2, d0[0] \n\t" - "vmla.f32 q5, q3, d0[0] \n\t" - "vmla.f32 q6, q2, d0[1] \n\t" - "vmla.f32 q7, q3, d0[1] \n\t" - "vmla.f32 q8, q2, d1[0] \n\t" - "vmla.f32 q9, q3, d1[0] \n\t" - "vmla.f32 q10, q2, d1[1] \n\t" - "vmla.f32 q11, q3, d1[1] \n\t" - "vmla.f32 q12, q2, d2[0] \n\t" - "vmla.f32 q13, q3, d2[0] \n\t" - "vmla.f32 q14, q2, d2[1] \n\t" - "vmla.f32 q15, q3, d2[1] \n\t" + // int L1 = 64 / max_threads * 1024; + int L = (max_threads > 2) ? 
64 : 32; + int L1 = L / max_threads * 1024; + KC = k; + zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); + memset(static_cast(zero), 0, sizeof(float) * KC); + if (m > n) { + // 对 A 分块 + MC = L1 / (KC * sizeof(float)); + if (MC == 0) { + MC = MR; + } else { + int mblock_num = (m + MC - 1) / MC; + MC = (m + mblock_num - 1) / mblock_num; + MC = (MC + MR - 1) / MR * MR; + } + // 补齐 B + NC = (n + NR - 1) / NR * NR; - "subs %[kc2], %[kc2], #1 \n\t" - "bge 3b \n\t" - "4: \n\t" +#if __aarch64__ + procPackA = &Gemm::PackMatrixA_6r; + procPackB = &Gemm::PackMatrixB_omp_16c; + procAddDot = &Gemm::AddDot6x16; +#else + procPackA = &Gemm::PackMatrixA_6r; + procPackB = &Gemm::PackMatrixB_omp_8c; + procAddDot = &Gemm::AddDot6x8; +#endif - "mov r5, %[c] \n\t" - "mov r6, %[step] \n\t" - "vst1.32 {q4, q5}, [r5], r6 \n\t" - "vst1.32 {q6, q7}, [r5], r6 \n\t" - "vst1.32 {q8, q9}, [r5], r6 \n\t" - "vst1.32 {q10, q11}, [r5], r6 \n\t" - "vst1.32 {q12, q13}, [r5], r6 \n\t" - "vst1.32 {q14, q15}, [r5] \n\t" + packedB = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); + (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB); + packedA = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); + } else { + // 对 B 分块 + NC = L1 / (KC * sizeof(float)); + if (NC == 0) { + NC = NR; + } else { + int nblock_num = (n + NC - 1) / NC; + NC = (n + nblock_num - 1) / nblock_num; + NC = (NC + NR - 1) / NR * NR; + } + // 补齐 A + MC = (m + MR - 1) / MR * MR; - : - : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), - [kc2] "r"(kc2), [step] "r"(step) - : "cc", "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", - "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#if __aarch64__ + procPackA = &Gemm::PackMatrixA_omp_6r; + procPackB = &Gemm::PackMatrixB_16c; + procAddDot = &Gemm::AddDot6x16; +#else -#endif // __aarch64__ + procPackA = &Gemm::PackMatrixA_omp_6r; + procPackB = &Gemm::PackMatrixB_8c; + procAddDot = &Gemm::AddDot6x8; +#endif -#endif // __ARM_NEON -} + packedA = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); + (*this.*procPackA)(m, KC, m % MR, A, lda, packedA); + packedB = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); + } + packedC = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads)); -#if __aarch64__ -void Gemm::AddDot8x12(int k, const float *a, const float *b, float *c, - int ldc) { - const float *a_ptr, *b_ptr; - a_ptr = a; - b_ptr = b; - int kc1 = k; - int step = 4 * ldc; - asm volatile( - "dup v5.4s, wzr \n\t" - "dup v6.4s, wzr \n\t" - "dup v7.4s, wzr \n\t" - "dup v8.4s, wzr \n\t" - "dup v9.4s, wzr \n\t" - "dup v10.4s, wzr \n\t" - "dup v11.4s, wzr \n\t" - "dup v12.4s, wzr \n\t" - "dup v13.4s, wzr \n\t" - "dup v14.4s, wzr \n\t" - "dup v15.4s, wzr \n\t" - "dup v16.4s, wzr \n\t" + if (m > n) { +#pragma omp parallel for + for (int i = 0; i < m; i += MC) { +#ifdef _OPENMP + int local_threads = omp_get_thread_num(); +#else + int local_threads = 0; +#endif - "dup v17.4s, wzr \n\t" - "dup v18.4s, wzr \n\t" - "dup v19.4s, wzr \n\t" - "dup v20.4s, wzr \n\t" - "dup v21.4s, wzr \n\t" - "dup v22.4s, wzr \n\t" - "dup v23.4s, wzr \n\t" - "dup v24.4s, wzr \n\t" - "dup v25.4s, wzr \n\t" - "dup v26.4s, wzr \n\t" - "dup v27.4s, wzr \n\t" - "dup v28.4s, wzr \n\t" + int mc; + mc = s_min(m - i, MC); + float *local_A = packedA + MC * KC * local_threads; + float *local_C = packedC + MC * NC * local_threads; + (*this.*procPackA)(mc, KC, mc % MR, 
&A(i, 0), lda, local_A); + if (bias == nullptr) { + InnerKernelWithBias(mc, n, alpha, local_A, packedB, beta, local_C, + &C(i, 0), ldc, relu, nullptr); + } else { + InnerKernelWithBias(mc, n, alpha, local_A, packedB, beta, local_C, + &C(i, 0), ldc, relu, bias + i); + } + } + } else { +#pragma omp parallel for + for (int j = 0; j < n; j += NC) { +#ifdef _OPENMP + int local_threads = omp_get_thread_num(); +#else + int local_threads = 0; +#endif - "subs %[kc1], %[kc1], #1 \n\t" - "blt 2f \n\t" - "1: \n\t" + int nc; + nc = s_min(n - j, NC); + float *local_B = packedB + KC * NC * local_threads; + float *local_C = packedC + MC * NC * local_threads; + (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B); + InnerKernelWithBias(m, nc, alpha, packedA, local_B, beta, local_C, + &C(0, j), ldc, relu, bias); + } + } - "prfm pldl1keep, [%[a_ptr], #32] \n\t" - "prfm pldl1keep, [%[b_ptr], #48] \n\t" + paddle_mobile::memory::Free(packedA); + paddle_mobile::memory::Free(packedB); + paddle_mobile::memory::Free(packedC); + paddle_mobile::memory::Free(zero); +} - "ld1 {v0.4s, v1.4s}, [%[a_ptr]], #32 \n\t" - "ld1 {v2.4s, v3.4s, v4.4s}, [%[b_ptr]], #48 \n\t" +void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, + int lda, const float *B, int ldb, float beta, + float *C, int ldc, bool relu, float *new_scale, + float *new_bias, float *bias) { +#ifdef _OPENMP + int max_threads = omp_get_max_threads(); +#else + int max_threads = 1; +#endif - "fmla v5.4s, v2.4s, v0.s[0] \n\t" - "fmla v6.4s, v3.4s, v0.s[0] \n\t" - "fmla v7.4s, v4.4s, v0.s[0] \n\t" - "fmla v8.4s, v2.4s, v0.s[1] \n\t" - "fmla v9.4s, v3.4s, v0.s[1] \n\t" - "fmla v10.4s, v4.4s, v0.s[1] \n\t" - "fmla v11.4s, v2.4s, v0.s[2] \n\t" - "fmla v12.4s, v3.4s, v0.s[2] \n\t" - "fmla v13.4s, v4.4s, v0.s[2] \n\t" - "fmla v14.4s, v2.4s, v0.s[3] \n\t" - "fmla v15.4s, v3.4s, v0.s[3] \n\t" - "fmla v16.4s, v4.4s, v0.s[3] \n\t" + int L1 = 64 / max_threads * 1024; + KC = k; + zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); + memset(static_cast(zero), 0, sizeof(float) * KC); + if (m > n) { + // 对 A 分块 + MC = L1 / (KC * sizeof(float)); + if (MC == 0) { + MC = MR; + } else { + int mblock_num = (m + MC - 1) / MC; + MC = (m + mblock_num - 1) / mblock_num; + MC = (MC + MR - 1) / MR * MR; + } + // 补齐 B + NC = (n + NR - 1) / NR * NR; - "fmla v17.4s, v2.4s, v1.s[0] \n\t" - "fmla v18.4s, v3.4s, v1.s[0] \n\t" - "fmla v19.4s, v4.4s, v1.s[0] \n\t" - "fmla v20.4s, v2.4s, v1.s[1] \n\t" - "fmla v21.4s, v3.4s, v1.s[1] \n\t" - "fmla v22.4s, v4.4s, v1.s[1] \n\t" - "fmla v23.4s, v2.4s, v1.s[2] \n\t" - "fmla v24.4s, v3.4s, v1.s[2] \n\t" - "fmla v25.4s, v4.4s, v1.s[2] \n\t" - "fmla v26.4s, v2.4s, v1.s[3] \n\t" - "fmla v27.4s, v3.4s, v1.s[3] \n\t" - "fmla v28.4s, v4.4s, v1.s[3] \n\t" +#if __aarch64__ + procPackA = &Gemm::PackMatrixA_6r; + procPackB = &Gemm::PackMatrixB_omp_16c; + procAddDot = &Gemm::AddDot6x16; +#else + procPackA = &Gemm::PackMatrixA_6r; + procPackB = &Gemm::PackMatrixB_omp_8c; + procAddDot = &Gemm::AddDot6x8; +#endif - "subs %[kc1], %[kc1], #1 \n\t" - "bge 1b \n\t" - "2: \n\t" + packedB = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); + (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB); + packedA = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); + } else { + // 对 B 分块 + NC = L1 / (KC * sizeof(float)); + if (NC == 0) { + NC = NR; + } else { + int nblock_num = (n + NC - 1) / NC; + NC = (n + nblock_num - 1) / nblock_num; + NC = (NC + NR - 1) / NR * NR; + } + // 补齐 A + MC 
= (m + MR - 1) / MR * MR; - "st1 {v5.4s, v6.4s, v7.4s}, [%[c]], %[step] \n\t" - "st1 {v8.4s, v9.4s, v10.4s}, [%[c]], %[step] \n\t" - "st1 {v11.4s, v12.4s, v13.4s}, [%[c]], %[step] \n\t" - "st1 {v14.4s, v15.4s, v16.4s}, [%[c]], %[step] \n\t" - "st1 {v17.4s, v18.4s, v19.4s}, [%[c]], %[step] \n\t" - "st1 {v20.4s, v21.4s, v22.4s}, [%[c]], %[step] \n\t" - "st1 {v23.4s, v24.4s, v25.4s}, [%[c]], %[step] \n\t" - "st1 {v26.4s, v27.4s, v28.4s}, [%[c]], %[step] \n\t" - : - : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), - [step] "r"(step) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28"); -} +#if __aarch64__ + procPackA = &Gemm::PackMatrixA_omp_6r; + procPackB = &Gemm::PackMatrixB_16c; + procAddDot = &Gemm::AddDot6x16; +#else + procPackA = &Gemm::PackMatrixA_omp_6r; + procPackB = &Gemm::PackMatrixB_8c; + procAddDot = &Gemm::AddDot6x8; +#endif -void Gemm::AddDot6x16(int k, const float *a, const float *b, float *c, - int ldc) { - const float *a_ptr, *b_ptr; - a_ptr = a; - b_ptr = b; - int kc1 = k; - int step = 4 * ldc; - int step1 = 4 * 6; - asm volatile( + packedA = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); + (*this.*procPackA)(m, KC, m % MR, A, lda, packedA); + packedB = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); + } + packedC = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads)); - "dup v6.4s, wzr \n\t" - "dup v7.4s, wzr \n\t" - "dup v8.4s, wzr \n\t" - "dup v9.4s, wzr \n\t" - "dup v10.4s, wzr \n\t" - "dup v11.4s, wzr \n\t" - "dup v12.4s, wzr \n\t" - "dup v13.4s, wzr \n\t" + if (m > n) { +#pragma omp parallel for + for (int i = 0; i < m; i += MC) { +#ifdef _OPENMP + int local_threads = omp_get_thread_num(); +#else + int local_threads = 0; +#endif - "dup v14.4s, wzr \n\t" - "dup v15.4s, wzr \n\t" - "dup v16.4s, wzr \n\t" - "dup v17.4s, wzr \n\t" - "dup v18.4s, wzr \n\t" - "dup v19.4s, wzr \n\t" - "dup v20.4s, wzr \n\t" - "dup v21.4s, wzr \n\t" + int mc; + mc = s_min(m - i, MC); + float *local_A = packedA + MC * KC * local_threads; + float *local_C = packedC + MC * NC * local_threads; + (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A); + if (bias == nullptr) { + InnerKernelWithBn(mc, n, alpha, local_A, packedB, beta, local_C, + &C(i, 0), ldc, relu, new_scale + i, new_bias + i); + } else { + InnerKernelWithBnAdd(mc, n, alpha, local_A, packedB, beta, local_C, + &C(i, 0), ldc, relu, new_scale + i, new_bias + i, + bias + i * ldc); + } + } + } else { +#pragma omp parallel for + for (int j = 0; j < n; j += NC) { +#ifdef _OPENMP + int local_threads = omp_get_thread_num(); +#else + int local_threads = 0; +#endif - "dup v22.4s, wzr \n\t" - "dup v23.4s, wzr \n\t" - "dup v24.4s, wzr \n\t" - "dup v25.4s, wzr \n\t" - "dup v26.4s, wzr \n\t" - "dup v27.4s, wzr \n\t" - "dup v28.4s, wzr \n\t" - "dup v29.4s, wzr \n\t" + int nc; + nc = s_min(n - j, NC); + float *local_B = packedB + KC * NC * local_threads; + float *local_C = packedC + MC * NC * local_threads; + (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B); + if (bias == nullptr) { + InnerKernelWithBn(m, nc, alpha, packedA, local_B, beta, local_C, + &C(0, j), ldc, relu, new_scale, new_bias); + } else { + InnerKernelWithBnAdd(m, nc, alpha, packedA, local_B, beta, local_C, + &C(0, j), ldc, relu, new_scale, new_bias, + bias + j); + } + } + } - "subs %[kc1], %[kc1], #1 
\n\t" - "blt 2f \n\t" - "1: \n\t" + paddle_mobile::memory::Free(packedA); + paddle_mobile::memory::Free(packedB); + paddle_mobile::memory::Free(packedC); + paddle_mobile::memory::Free(zero); +} - "prfm pldl1keep, [%[a_ptr], #24] \n\t" - "prfm pldl1keep, [%[b_ptr], #64] \n\t" +void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, + const float *B, int ldb, float *C, int ldc, + float *p, std::string mode, float *bias, + float *bias1) { +#ifdef _OPENMP + int max_threads = omp_get_max_threads(); +#else + int max_threads = 1; +#endif - "ld1 {v0.4s, v1.4s}, [%[a_ptr]], %[step1] \n\t" - "ld1 {v2.4s, v3.4s, v4.4s, v5.4s}, [%[b_ptr]], #64 \n\t" + int L1 = 8 * 1024; + KC = k; + zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); + memset(static_cast(zero), 0, sizeof(float) * KC); + if (m > n) { + // 对 A 分块 + MC = L1 / (KC * sizeof(float)); + if (MC == 0) { + MC = MR; + } else { + int mblock_num = (m + MC - 1) / MC; + MC = (m + mblock_num - 1) / mblock_num; + MC = (MC + MR - 1) / MR * MR; + } + // 补齐 B + NC = (n + NR - 1) / NR * NR; - "fmla v6.4s, v2.4s, v0.s[0] \n\t" - "fmla v7.4s, v3.4s, v0.s[0] \n\t" - "fmla v8.4s, v4.4s, v0.s[0] \n\t" - "fmla v9.4s, v5.4s, v0.s[0] \n\t" +#if __aarch64__ + procPackA = &Gemm::PackMatrixA_6r; + procPackB = &Gemm::PackMatrixB_omp_16c; + procAddDot = &Gemm::AddDot6x16; +#else + procPackA = &Gemm::PackMatrixA_6r; + procPackB = &Gemm::PackMatrixB_omp_8c; + procAddDot = &Gemm::AddDot6x8; +#endif - "fmla v10.4s, v2.4s, v0.s[1] \n\t" - "fmla v11.4s, v3.4s, v0.s[1] \n\t" - "fmla v12.4s, v4.4s, v0.s[1] \n\t" - "fmla v13.4s, v5.4s, v0.s[1] \n\t" + packedB = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); + (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB); + packedA = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); + } else { + // 对 B 分块 + NC = L1 / (KC * sizeof(float)); + if (NC == 0) { + NC = NR; + } else { + int nblock_num = (n + NC - 1) / NC; + NC = (n + nblock_num - 1) / nblock_num; + NC = (NC + NR - 1) / NR * NR; + } + // 补齐 A + MC = (m + MR - 1) / MR * MR; - "fmla v14.4s, v2.4s, v0.s[2] \n\t" - "fmla v15.4s, v3.4s, v0.s[2] \n\t" - "fmla v16.4s, v4.4s, v0.s[2] \n\t" - "fmla v17.4s, v5.4s, v0.s[2] \n\t" +#if __aarch64__ + procPackA = &Gemm::PackMatrixA_omp_6r; + procPackB = &Gemm::PackMatrixB_16c; + procAddDot = &Gemm::AddDot6x16; +#else + procPackA = &Gemm::PackMatrixA_omp_6r; + procPackB = &Gemm::PackMatrixB_8c; + procAddDot = &Gemm::AddDot6x8; +#endif - "fmla v18.4s, v2.4s, v0.s[3] \n\t" - "fmla v19.4s, v3.4s, v0.s[3] \n\t" - "fmla v20.4s, v4.4s, v0.s[3] \n\t" - "fmla v21.4s, v5.4s, v0.s[3] \n\t" + packedA = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); + (*this.*procPackA)(m, KC, m % MR, A, lda, packedA); + packedB = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); + } + packedC = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * MC * NC * max_threads)); - "fmla v22.4s, v2.4s, v1.s[0] \n\t" - "fmla v23.4s, v3.4s, v1.s[0] \n\t" - "fmla v24.4s, v4.4s, v1.s[0] \n\t" - "fmla v25.4s, v5.4s, v1.s[0] \n\t" + if (m > n) { +#pragma omp parallel for + for (int i = 0; i < m; i += MC) { +#ifdef _OPENMP + int local_threads = omp_get_thread_num(); +#else + int local_threads = 0; +#endif - "fmla v26.4s, v2.4s, v1.s[1] \n\t" - "fmla v27.4s, v3.4s, v1.s[1] \n\t" - "fmla v28.4s, v4.4s, v1.s[1] \n\t" - "fmla v29.4s, v5.4s, v1.s[1] \n\t" + int mc; + mc = s_min(m - i, MC); + float *local_A = packedA + MC * KC * 
local_threads; + float *local_C = packedC + MC * NC * local_threads; + (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A); + if (bias1 == nullptr) { + InnerKernelWithPRelu(mc, n, local_A, packedB, local_C, &C(i, 0), ldc, + p + i, mode, bias + i, nullptr); + } else { + InnerKernelWithPRelu(mc, n, local_A, packedB, local_C, &C(i, 0), ldc, + p + i, mode, bias + i, bias1 + i * ldc); + } + } + } else { +#pragma omp parallel for + for (int j = 0; j < n; j += NC) { +#ifdef _OPENMP + int local_threads = omp_get_thread_num(); +#else + int local_threads = 0; +#endif - "subs %[kc1], %[kc1], #1 \n\t" - "bge 1b \n\t" - "2: \n\t" + int nc; + nc = s_min(n - j, NC); + float *local_B = packedB + KC * NC * local_threads; + float *local_C = packedC + MC * NC * local_threads; + (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B); + if (bias1 == nullptr) { + InnerKernelWithPRelu(m, nc, packedA, local_B, local_C, &C(0, j), ldc, p, + mode, bias, nullptr); + } else { + InnerKernelWithPRelu(m, nc, packedA, local_B, local_C, &C(0, j), ldc, p, + mode, bias, bias1 + j); + } + } + } - "st1 {v6.4s, v7.4s, v8.4s, v9.4s}, [%[c]], %[step] \n\t" - "st1 {v10.4s, v11.4s, v12.4s, v13.4s}, [%[c]], %[step] \n\t" - "st1 {v14.4s, v15.4s, v16.4s, v17.4s}, [%[c]], %[step] \n\t" - "st1 {v18.4s, v19.4s, v20.4s, v21.4s}, [%[c]], %[step] \n\t" - "st1 {v22.4s, v23.4s, v24.4s, v25.4s}, [%[c]], %[step] \n\t" - "st1 {v26.4s, v27.4s, v28.4s, v29.4s}, [%[c]], %[step] \n\t" - : - : [a_ptr] "r"(a_ptr), [b_ptr] "r"(b_ptr), [c] "r"(c), [kc1] "r"(kc1), - [step] "r"(step), [step1] "r"(step1) - : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", - "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", - "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29"); + paddle_mobile::memory::Free(packedA); + paddle_mobile::memory::Free(packedB); + paddle_mobile::memory::Free(packedC); + paddle_mobile::memory::Free(zero); } -#endif // __aarch64__ - } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index b6af1c838d337fb84aae58e1531ab0acb07f7bea..effab20b2045fbe93590189e28bac24d1f72ab2c 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -46,15 +46,6 @@ namespace math { class Gemm { public: - /* -// 将 A 矩阵分块复制到连续内存(ColMajor) -void PackMatrixA(int m, int k, int m_tail, const float *A, int lda, - float *buffer); - -// 将 B 矩阵分块复制到连续内存(ColMajor) -void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, - float *buffer); -*/ typedef void (Gemm::*FnPack)(int, int, int, const float *, int, float *); typedef void (Gemm::*FnAddDot)(int, const float *, const float *, float *, int); @@ -62,31 +53,31 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, FnPack procPackB; FnAddDot procAddDot; - // 将 A 矩阵分块复制到连续内存(RowMajor) + // 将 A\B 矩阵分块复制到连续内存(RowMajor) void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda, float *buffer); void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda, float *buffer); - void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda, - float *buffer); void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda, float *buffer); + void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda, + float *buffer); void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda, float *buffer); - - // 将 B 矩阵分块复制到连续内存(RowMajor) void PackMatrixB_8c(int k, int n, int n_tail, const 
float *B, int ldb, float *buffer); - void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer); - void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer); void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb, float *buffer); +#if __aarch64__ + void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb, + float *buffer); void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb, float *buffer); + void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb, + float *buffer); void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb, float *buffer); +#endif // 分块矩阵乘法 void InnerKernel(int mc, int nc, float alpha, const float *a, const float *b, @@ -106,22 +97,16 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, float *c, float *C, int ldc, float *p, std::string mode, float *bias, float *bias1); - // 向量矩阵乘法 (M = 1) - void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda, - const float *B, int ldb, float beta, float *C, int ldc, - bool relu); - /* - void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A, - int lda, const float *B, int ldb, float beta, float - *C, int ldc, bool relu, float *new_scale, float *new_bias); - */ - // 计算一个更小的 C 矩阵分块 - void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc); - void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc); +#if __aarch64__ void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc); void AddDot8x12(int k, const float *a, const float *b, float *c, int ldc); void AddDot6x16(int k, const float *a, const float *b, float *c, int ldc); +#else + void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc); + void AddDot4x8(int k, const float *a, const float *b, float *c, int ldc); + void AddDot6x8(int k, const float *a, const float *b, float *c, int ldc); +#endif // 分块矩阵乘法结果回写 // C = A * B @@ -149,6 +134,18 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, float *new_scale, float *new_bias, float *bias1); + // 向量矩阵乘法 (M = 1) +#if __aarch64__ +#else + void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda, + const float *B, int ldb, float beta, float *C, int ldc, + bool relu); + + void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A, + int lda, const float *B, int ldb, float beta, + float *C, int ldc, bool relu, float *new_scale, + float *new_bias); + // 向量矩阵乘法结果回写 // C = A * B void VecWriteBasic(int n, float *c, float *C, int ldc); @@ -158,14 +155,13 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, void VecWriteWithAdd(int n, float *c, float *C, int ldc); // C = A * B + C, relu(C) void VecWriteWithAddRelu(int n, float *c, float *C, int ldc); - /* - // C = A * B, batchnorm(C) - void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale, - float *new_bias); - // C = A * B, batchnorm(C), relu(C) - void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float - *new_scale, float *new_bias); - */ + // C = A * B, batchnorm(C) + void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale, + float *new_bias); + // C = A * B, batchnorm(C), relu(C) + void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale, + float *new_bias); +#endif // 32位 float 矩阵乘法 void Sgemm(int m, int n, int k, float alpha, const 
diff --git a/src/operators/op_param.h b/src/operators/op_param.h
index a877b27eb19f66011f02d361a1d4b5ea459be51a..00fbfbc771cfe9329b8ba76f120a5bc304dc80fc 100644
--- a/src/operators/op_param.h
+++ b/src/operators/op_param.h
@@ -3066,5 +3066,52 @@ class ReadFromArrayParam : public OpParam {
 };
 #endif

+#ifdef IS_EMPTY_OP
+template <typename Dtype>
+class IsEmptyParam : public OpParam {
+  typedef typename DtypeTensorTrait<Dtype>::gtype GType;
+  typedef typename DtypeTensorTrait<Dtype>::rtype RType;
+
+ public:
+  IsEmptyParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
+               const AttributeMap &attrs, const Scope &scope) {
+    input_x_ = InputXFrom<GType>(inputs, scope);
+    output_ = OutFrom<GType>(outputs, scope);
+  }
+
+  const GType *InputX() const { return input_x_; }
+  GType *Out() const { return output_; }
+
+ public:
+  GType *input_x_;
+  GType *output_;
+};
+#endif  // IS_EMPTY_OP
+
+#ifdef INCREMENT_OP
+template <typename Dtype>
+class IncrementParam : public OpParam {
+  typedef typename DtypeTensorTrait<Dtype>::gtype GType;
+  typedef typename DtypeTensorTrait<Dtype>::rtype RType;
+
+ public:
+  IncrementParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
+                 const AttributeMap &attrs, const Scope &scope) {
+    input_x_ = InputXFrom<GType>(inputs, scope);
+    output_ = OutFrom<GType>(outputs, scope);
+    step_ = OpParam::GetAttr<int>("step", attrs);
+  }
+
+  const GType *InputX() const { return input_x_; }
+  GType *Out() const { return output_; }
+  int Step() const { return step_; }
+
+ public:
+  GType *input_x_;
+  GType *output_;
+  int step_;
+};
+#endif  // INCREMENT_OP
+
 }  // namespace operators
 }  // namespace paddle_mobile
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 8b52faf184bf79211b39ce46ae21e0668d1dafc2..f3dffbad1c065561d86da0e976792d206198c61e 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -437,6 +437,14 @@ if (NOT FOUND_MATCH)
     ADD_EXECUTABLE(test-logical-xor-op operators/test_logical_xor_op.cpp test_helper.h test_include.h)
     target_link_libraries(test-logical-xor-op paddle-mobile)

+    # gen test
+    ADD_EXECUTABLE(test-increment-op operators/test_increment_op.cpp test_helper.h test_include.h)
+    target_link_libraries(test-increment-op paddle-mobile)
+
+    # gen test
+    ADD_EXECUTABLE(test-is-empty-op operators/test_is_empty_op.cpp test_helper.h test_include.h)
+    target_link_libraries(test-is-empty-op paddle-mobile)
+
     ADD_EXECUTABLE(test-conv-bn-relu-op operators/test_conv_bn_relu_op.cpp test_helper.h test_include.h)
     target_link_libraries(test-conv-bn-relu-op paddle-mobile)
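IsEmptyParam and IncrementParam above only wire up the inputs, the output and the integer step attribute. Assuming the usual semantics of these two ops (increment adds step to a one-element tensor, is_empty reports whether the input has any elements), a plain-C++ reference sketch looks like this (illustrative only, no framework types):

    #include <cassert>
    #include <cstdio>
    #include <vector>

    // increment: the single input element plus the integer step attribute.
    float IncrementRef(const std::vector<float> &x, int step) {
      assert(x.size() == 1);
      return x[0] + static_cast<float>(step);
    }

    // is_empty: a single bool, true when the input holds no elements.
    bool IsEmptyRef(const std::vector<float> &x) { return x.empty(); }

    int main() {
      std::printf("%.1f\n", IncrementRef({3.0f}, 4));          // 7.0
      std::printf("%d\n", static_cast<int>(IsEmptyRef({})));   // 1
      return 0;
    }

The unit tests added below exercise exactly these two behaviours against the registered operators.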
diff --git a/test/operators/test_increment_op.cpp b/test/operators/test_increment_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..cb1fc9478c65dd70a758dfdf4a3f795470f720af
--- /dev/null
+++ b/test/operators/test_increment_op.cpp
@@ -0,0 +1,75 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "../test_include.h"
+#include "operators/increment_op.h"
+
+namespace paddle_mobile {
+
+template <typename T>
+void Increment(const framework::Tensor *input, framework::Tensor *out,
+               int step) {
+  auto input_data = input->data<T>();
+  auto out_data = out->data<T>();
+  *out_data = *input_data + step;
+}
+
+int TestIncrementOp(const std::vector<int> input_shape, int step) {
+  framework::DDim input_dims = framework::make_ddim(input_shape);
+  VariableNameMap inputs;
+  VariableNameMap outputs;
+  auto scope = std::make_shared<framework::Scope>();
+  inputs["X"] = std::vector<std::string>({"inputX"});
+  outputs["Out"] = std::vector<std::string>({"output"});
+
+  auto x_var = scope.get()->Var("inputX");
+  auto x = x_var->template GetMutable<framework::LoDTensor>();
+  SetupTensor<float>(x, input_dims, 0, 100);
+
+  auto output_var = scope.get()->Var("output");
+  framework::AttributeMap attrs;
+  attrs["step"].Set<int>(step);
+
+  auto *op = new operators::IncrementOp<CPU, float>("increment", inputs,
+                                                    outputs, attrs, scope);
+
+  op->InferShape();
+  op->Init();
+  op->Run();
+
+  auto output = output_var->template Get<framework::LoDTensor>();
+  framework::Tensor output_cmp;
+  float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
+  Increment<float>(x, &output_cmp, step);
+
+  const float *output_data = output->data<float>();
+  for (int i = 0; i < output->numel(); ++i) {
+    float gap = output_data[i] - output_cmp_data[i];
+    if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
+      LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
+                     << ", output_cmp_data[" << i
+                     << "] = " << output_cmp_data[i];
+      delete op;
+      exit(1);
+    }
+  }
+  delete op;
+  return 0;
+}
+}  // namespace paddle_mobile
+
+int main() {
+  paddle_mobile::TestIncrementOp({1}, 4);
+  paddle_mobile::TestIncrementOp({1}, 10);
+  DLOG << "test increment op pass.";
+  return 0;
+}
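The increment test above accepts an element when its relative error against the reference is within 1e-3, with a small epsilon guarding against division by zero. The same check, pulled out as a standalone helper for clarity (a sketch; the test keeps the expression inline):

    #include <cmath>

    // |(got - expected) / (got + 1e-5)| <= 1e-3, as in the loop above.
    inline bool NearlyEqual(float got, float expected) {
      const float gap = got - expected;
      return std::abs(gap / (got + 1e-5f)) <= 1e-3f;
    }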
diff --git a/test/operators/test_is_empty_op.cpp b/test/operators/test_is_empty_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5283b2bd69e47ece6d7569f3b68706008c89ef94
--- /dev/null
+++ b/test/operators/test_is_empty_op.cpp
@@ -0,0 +1,69 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "../test_include.h"
+#include "operators/is_empty_op.h"
+
+namespace paddle_mobile {
+
+void IsEmpty(const framework::Tensor *input, framework::Tensor *out) {
+  out->data<bool>()[0] = input->numel() == 0;
+}
+
+int TestIsEmptyOp(const std::vector<int> input_shape) {
+  framework::DDim input_dims = framework::make_ddim(input_shape);
+  VariableNameMap inputs;
+  VariableNameMap outputs;
+  auto scope = std::make_shared<framework::Scope>();
+  inputs["X"] = std::vector<std::string>({"inputX"});
+  outputs["Out"] = std::vector<std::string>({"output"});
+
+  auto x_var = scope.get()->Var("inputX");
+  auto x = x_var->template GetMutable<framework::LoDTensor>();
+  SetupTensor<float>(x, input_dims, 0, 100);
+
+  auto output_var = scope.get()->Var("output");
+  framework::AttributeMap attrs;
+
+  auto *op = new operators::IsEmptyOp<CPU, float>("is_empty", inputs, outputs,
+                                                  attrs, scope);
+
+  op->InferShape();
+  op->Init();
+  op->Run();
+
+  auto output = output_var->template Get<framework::LoDTensor>();
+  framework::Tensor output_cmp;
+  bool *output_cmp_data = output_cmp.mutable_data<bool>(output->dims());
+  IsEmpty(x, &output_cmp);
+
+  const bool *output_data = output->data<bool>();
+  for (int i = 0; i < output->numel(); ++i) {
+    if (output_data[i] != output_cmp_data[i]) {
+      LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
+                     << ", output_cmp_data[" << i
+                     << "] = " << output_cmp_data[i];
+      delete op;
+      exit(1);
+    }
+  }
+  delete op;
+  return 0;
+}
+}  // namespace paddle_mobile
+
+int main() {
+  paddle_mobile::TestIsEmptyOp({1, 3, 100, 100});
+  paddle_mobile::TestIsEmptyOp({0});
+  DLOG << "test is_empty op pass.";
+  return 0;
+}
diff --git a/tools/op.cmake b/tools/op.cmake
index 0f37eb7e98f053f519a1fb0d6a2829b33e82e45a..c2df1dfa57fd16d9f52f2b364ab69a808709cc80 100644
--- a/tools/op.cmake
+++ b/tools/op.cmake
@@ -288,6 +288,8 @@ if(NOT FOUND_MATCH)
     set(WHILE_OP ON)
     set(WRITE_TO_ARRAY_OP ON)
     set(READ_FROM_ARRAY_OP ON)
+    set(IS_EMPTY_OP ON)
+    set(INCREMENT_OP ON)
     set(ANCHOR_GENERATOR_OP ON)
     set(PROPOSAL_OP ON)
     set(PSROI_POOL_OP ON)
@@ -575,6 +577,12 @@ endif()
 if (READ_FROM_ARRAY_OP)
     add_definitions(-DREAD_FROM_ARRAY_OP)
 endif()
+if (IS_EMPTY_OP)
+    add_definitions(-DIS_EMPTY_OP)
+endif()
+if (INCREMENT_OP)
+    add_definitions(-DINCREMENT_OP)
+endif()
 if (ANCHOR_GENERATOR_OP)
     add_definitions(-DANCHOR_GENERATOR_OP)
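The op.cmake switches are what make the new operator code visible to the build: setting IS_EMPTY_OP and INCREMENT_OP to ON adds -DIS_EMPTY_OP and -DINCREMENT_OP, and the new operator code (for example the op_param.h additions above) is wrapped in the matching #ifdef guard, so an op left OFF is dropped from the binary entirely. A compile-time sketch of that wiring (illustrative only, not repository code):

    #include <cstdio>

    int main() {
    // With -DINCREMENT_OP on the compile line (as op.cmake now arranges),
    // the guarded branch is compiled in; without it, the op is absent.
    #ifdef INCREMENT_OP
      std::puts("increment op compiled in");
    #else
      std::puts("increment op compiled out");
    #endif
      return 0;
    }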