From 4ae166a52aaa8344ac87678d24a10d8de329b26d Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Thu, 13 Dec 2018 12:40:41 +0800 Subject: [PATCH] Optimize gru kerenl, thanks to smilejames --- src/common/enforce.h | 2 +- src/common/types.h | 18 +- src/io/paddle_mobile.cpp | 6 +- .../kernel/arm/dequantize_bn_kernel.cpp | 20 +- src/operators/kernel/arm/quantize_kernel.cpp | 96 +++- src/operators/kernel/arm/relu_kernel.cpp | 4 +- .../kernel/central-arm-func/pool_arm_func.h | 16 +- src/operators/math/activation.h | 71 ++- src/operators/math/activation_functions.h | 92 ---- src/operators/math/gemm.cpp | 495 +++++++++--------- src/operators/math/gemm.h | 35 +- src/operators/math/gru_compute.cpp | 11 +- src/operators/math/gru_compute.h | 2 +- src/operators/math/gru_cpu_kernel.h | 161 ++++-- src/operators/math/gru_kernel.h | 51 -- src/operators/math/pooling.cpp | 4 +- src/operators/math/pooling.h | 26 +- src/operators/math/pooling3x3.cpp | 8 +- test/operators/test_pool_op.cpp | 98 ++-- 19 files changed, 641 insertions(+), 575 deletions(-) delete mode 100644 src/operators/math/activation_functions.h delete mode 100644 src/operators/math/gru_kernel.h diff --git a/src/common/enforce.h b/src/common/enforce.h index bf21b5b9a2..1bacfb88d3 100644 --- a/src/common/enforce.h +++ b/src/common/enforce.h @@ -16,9 +16,9 @@ limitations under the License. */ #ifdef ENABLE_EXCEPTION #include +#include #include #include - #endif namespace paddle_mobile { diff --git a/src/common/types.h b/src/common/types.h index c607efb9a2..ee36250ea4 100644 --- a/src/common/types.h +++ b/src/common/types.h @@ -93,18 +93,18 @@ enum RoundType { }; enum ActivationType { - Linear = 0, - Relu = 1, - Relu6 = 2, - PRelu = 3, - LeakyRelu = 4, - Tanh = 5, - Sigmoid = 6, + IDENTITY = 0, + RELU = 1, + RELU6 = 2, + PRELU = 3, + LEAKY_RELU = 4, + TANH = 5, + SIGMOID = 6, }; enum PoolingType { - Max = 0, - Avg = 1, + MAX = 0, + AVG = 1, }; extern const char *G_OP_TYPE_CONV; diff --git a/src/io/paddle_mobile.cpp b/src/io/paddle_mobile.cpp index cb70514687..5f724ce4a2 100644 --- a/src/io/paddle_mobile.cpp +++ b/src/io/paddle_mobile.cpp @@ -143,12 +143,10 @@ double PaddleMobile::GetPredictTime() { int t1 = 1; int t2 = 1; for (int i = 0; i < m * k; ++i) { - unsigned int seed = 100; - a[i] = t1 + rand_r(&seed) % t2; + a[i] = t1 + rand() % t2; // NOLINT } for (int i = 0; i < k * n; ++i) { - unsigned int seed = 200; - b[i] = t1 + rand_r(&seed) % t2; + b[i] = t1 + rand() % t2; // NOLINT } paddle_mobile::operators::math::Gemm gemm; auto time1 = paddle_mobile::time(); diff --git a/src/operators/kernel/arm/dequantize_bn_kernel.cpp b/src/operators/kernel/arm/dequantize_bn_kernel.cpp index d4b41b8b87..4fa00f3a37 100644 --- a/src/operators/kernel/arm/dequantize_bn_kernel.cpp +++ b/src/operators/kernel/arm/dequantize_bn_kernel.cpp @@ -131,7 +131,7 @@ bool FusionDequantBNKernel::Init(FusionDequantBNParam *param) { template <> void FusionDequantBNKernel::Compute( const FusionDequantBNParam ¶m) { - DequantBNCompute(¶m); + DequantBNCompute(¶m); } #endif // FUSION_DEQUANT_BN_OP @@ -146,7 +146,7 @@ bool FusionDequantBNReluKernel::Init( template <> void FusionDequantBNReluKernel::Compute( const FusionDequantBNParam ¶m) { - DequantBNCompute(¶m); + DequantBNCompute(¶m); } #endif // FUSION_DEQUANT_BN_RELU_OP @@ -162,7 +162,7 @@ bool FusionDequantAddBNKernel::Init( template <> void FusionDequantAddBNKernel::Compute( const FusionDequantAddBNParam ¶m) { - DequantBNCompute(¶m); + DequantBNCompute(¶m); } #endif // FUSION_DEQUANT_ADD_BN_OP @@ -178,7 +178,7 @@ bool FusionDequantAddBNReluKernel::Init( template <> void FusionDequantAddBNReluKernel::Compute( const FusionDequantAddBNParam ¶m) { - DequantBNCompute(¶m); + DequantBNCompute(¶m); } #endif // FUSION_DEQUANT_ADD_BN_RELU_OP @@ -292,13 +292,13 @@ void FusionDequantAddBNQuantKernel::Compute( const FusionDequantAddBNQuantParam ¶m) { switch (param.round_type_) { case ROUND_NEAREST_TO_EVEN: - DequantBNQuantCompute(¶m); + DequantBNQuantCompute(¶m); break; case ROUND_NEAREST_TOWARDS_ZERO: - DequantBNQuantCompute(¶m); + DequantBNQuantCompute(¶m); break; case ROUND_NEAREST_AWAY_ZERO: - DequantBNQuantCompute(¶m); + DequantBNQuantCompute(¶m); break; default: LOG(kLOG_ERROR) << "round type is not supported."; @@ -321,13 +321,13 @@ void FusionDequantAddBNReluQuantKernel::Compute( const FusionDequantAddBNQuantParam ¶m) { switch (param.round_type_) { case ROUND_NEAREST_TO_EVEN: - DequantBNQuantCompute(¶m); + DequantBNQuantCompute(¶m); break; case ROUND_NEAREST_TOWARDS_ZERO: - DequantBNQuantCompute(¶m); + DequantBNQuantCompute(¶m); break; case ROUND_NEAREST_AWAY_ZERO: - DequantBNQuantCompute(¶m); + DequantBNQuantCompute(¶m); break; default: LOG(kLOG_ERROR) << "round type is not supported."; diff --git a/src/operators/kernel/arm/quantize_kernel.cpp b/src/operators/kernel/arm/quantize_kernel.cpp index f719132636..1f76e604e7 100644 --- a/src/operators/kernel/arm/quantize_kernel.cpp +++ b/src/operators/kernel/arm/quantize_kernel.cpp @@ -34,14 +34,66 @@ inline float32_t vmaxvq_f32(float32x4_t r) { #endif template -static void Quantize(const Tensor *input, const float scale, Tensor *output) { +inline void QuantizeOffline(const Tensor *input, const float scale, + const float max_abs, Tensor *output) { const float *x = input->data(); int8_t *y = output->mutable_data(); size_t remain = input->numel(); #if defined(__ARM_NEON__) || defined(__ARM_NEON) size_t loop = remain >> 4; remain = remain & 0xF; + float32x4_t __scale = vdupq_n_f32(scale); + float32x4_t __postive_max = vdupq_n_f32(max_abs); + float32x4_t __negtive_max = vdupq_n_f32(-max_abs); + #pragma omp parallel for + for (size_t i = 0; i < loop; ++i) { + const float *local_x = x + (i << 4); + int8_t *local_y = y + (i << 4); + float32x4_t r0 = vld1q_f32(local_x); + float32x4_t r1 = vld1q_f32(local_x + 4); + float32x4_t r2 = vld1q_f32(local_x + 8); + float32x4_t r3 = vld1q_f32(local_x + 12); + r0 = vmaxq_f32(vminq_f32(r0, __postive_max), __negtive_max); + r1 = vmaxq_f32(vminq_f32(r1, __postive_max), __negtive_max); + r2 = vmaxq_f32(vminq_f32(r2, __postive_max), __negtive_max); + r3 = vmaxq_f32(vminq_f32(r3, __postive_max), __negtive_max); + r0 = vmulq_f32(r0, __scale); + r1 = vmulq_f32(r1, __scale); + r2 = vmulq_f32(r2, __scale); + r3 = vmulq_f32(r3, __scale); + int32x4_t q0 = math::vRoundq_f32(r0); + int32x4_t q1 = math::vRoundq_f32(r1); + int32x4_t q2 = math::vRoundq_f32(r2); + int32x4_t q3 = math::vRoundq_f32(r3); + int16x4_t d0 = vmovn_s32(q0); + int16x4_t d1 = vmovn_s32(q1); + int16x4_t d2 = vmovn_s32(q2); + int16x4_t d3 = vmovn_s32(q3); + int16x8_t q5 = vcombine_s16(d0, d1); + int16x8_t q6 = vcombine_s16(d2, d3); + int8x8_t d5 = vmovn_s16(q5); + int8x8_t d6 = vmovn_s16(q6); + vst1_s8(local_y, d5); + vst1_s8(local_y + 8, d6); + } + x += (loop << 4); + y += (loop << 4); +#endif + for (size_t i = 0; i < remain; ++i) { + float x_temp = std::max(std::min(x[i], max_abs), -max_abs); + y[i] = math::Round(x_temp * scale); + } +} +template +inline void QuantizeOnline(const Tensor *input, const float scale, + Tensor *output) { + const float *x = input->data(); + int8_t *y = output->mutable_data(); + size_t remain = input->numel(); +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + size_t loop = remain >> 4; + remain = remain & 0xF; float32x4_t __scale = vdupq_n_f32(scale); #pragma omp parallel for for (size_t i = 0; i < loop; ++i) { @@ -78,6 +130,17 @@ static void Quantize(const Tensor *input, const float scale, Tensor *output) { } } +template +static void Quantize(const Tensor *input, const float max_abs, + const bool offline, Tensor *output) { + float scale = 127.f / max_abs; + if (offline) { + QuantizeOffline(input, scale, max_abs, output); + } else { + QuantizeOnline(input, scale, output); + } +} + float find_abs_max(const Tensor *input) { float max_abs = 0.f; const float *x = input->data(); @@ -133,23 +196,22 @@ void QuantizeKernel::Compute(const QuantizeParam ¶m) { max_abs = find_abs_max(input); } max_abs = std::max(max_abs, 1e-6f); - // only support int8 currently - float scale = 127 / max_abs; param.online_scale_->mutable_data()[0] = max_abs; - switch (param.round_type_) { - case ROUND_NEAREST_TO_EVEN: - Quantize(input, scale, output); - break; - case ROUND_NEAREST_TOWARDS_ZERO: - Quantize(input, scale, output); - break; - case ROUND_NEAREST_AWAY_ZERO: - Quantize(input, scale, output); - break; - default: - LOG(kLOG_ERROR) << "round type is not supported."; - break; - } + // switch (param.round_type_) { + // case ROUND_NEAREST_TO_EVEN: + // Quantize(input, scale, output); + // break; + // case ROUND_NEAREST_TOWARDS_ZERO: + // Quantize(input, scale, output); + // break; + // case ROUND_NEAREST_AWAY_ZERO: + // Quantize(input, scale, output); + // break; + // default: + // LOG(kLOG_ERROR) << "round type is not supported."; + // break; + // } + Quantize(input, max_abs, param.offline_, output); } } // namespace operators diff --git a/src/operators/kernel/arm/relu_kernel.cpp b/src/operators/kernel/arm/relu_kernel.cpp index 979a1d6e2a..0333e9db44 100644 --- a/src/operators/kernel/arm/relu_kernel.cpp +++ b/src/operators/kernel/arm/relu_kernel.cpp @@ -74,7 +74,7 @@ template <> void ReluKernel::Compute(const ReluParam ¶m) { const Tensor *input = param.InputX(); Tensor *output = param.Out(); - ReluCompute()(input, output); + ReluCompute()(input, output); } template <> @@ -86,7 +86,7 @@ template <> void Relu6Kernel::Compute(const ReluParam ¶m) { const Tensor *input = param.InputX(); Tensor *output = param.Out(); - ReluCompute()(input, output); + ReluCompute()(input, output); } } // namespace operators diff --git a/src/operators/kernel/central-arm-func/pool_arm_func.h b/src/operators/kernel/central-arm-func/pool_arm_func.h index 529798dd80..757d64480f 100644 --- a/src/operators/kernel/central-arm-func/pool_arm_func.h +++ b/src/operators/kernel/central-arm-func/pool_arm_func.h @@ -40,28 +40,28 @@ void PoolCompute(const PoolParam ¶m) { if (ksize[0] == 3 && ksize[0] == ksize[1]) { if (pooling_type == "max" && strides[0] == strides[1]) { if (strides[0] == 1) { - math::Pooling3x3()(*input, paddings, output); + math::Pooling3x3()(*input, paddings, output); } else if (strides[0] == 2) { - math::Pooling3x3()(*input, paddings, output); + math::Pooling3x3()(*input, paddings, output); } else { - math::Pooling()(*input, ksize, strides, paddings, output); + math::Pooling()(*input, ksize, strides, paddings, output); } } else if (pooling_type == "avg" && strides[0] == strides[1]) { if (strides[0] == 1) { - math::Pooling3x3()(*input, paddings, output); + math::Pooling3x3()(*input, paddings, output); } else if (strides[0] == 2) { - math::Pooling3x3()(*input, paddings, output); + math::Pooling3x3()(*input, paddings, output); } else { - math::Pooling()(*input, ksize, strides, paddings, output); + math::Pooling()(*input, ksize, strides, paddings, output); } } else { // Others } } else { if (pooling_type == "max") { - math::Pooling()(*input, ksize, strides, paddings, output); + math::Pooling()(*input, ksize, strides, paddings, output); } else if (pooling_type == "avg") { - math::Pooling()(*input, ksize, strides, paddings, output); + math::Pooling()(*input, ksize, strides, paddings, output); } else { // Others } diff --git a/src/operators/math/activation.h b/src/operators/math/activation.h index 1274f0fd8a..51ce378978 100644 --- a/src/operators/math/activation.h +++ b/src/operators/math/activation.h @@ -16,50 +16,109 @@ limitations under the License. */ #include #include +#include +#include "common/enforce.h" #include "common/types.h" #if defined(__ARM_NEON__) || defined(__ARM_NEON) #include +#include "operators/math/math_func_neon.h" #endif namespace paddle_mobile { namespace operators { namespace math { +#define SIGMOID_THRESHOLD_MIN -40.0 +#define SIGMOID_THRESHOLD_MAX 13.0 +#define EXP_MAX_INPUT 40.0 + +inline ActivationType GetActivationType(const std::string &type) { + if (type == "sigmoid") { + return ActivationType::SIGMOID; + } else if (type == "relu") { + return ActivationType::RELU; + } else if (type == "tanh") { + return ActivationType::TANH; + } else if (type == "identity" || type == "") { + return ActivationType::IDENTITY; + } + PADDLE_MOBILE_THROW_EXCEPTION("Not support activation type."); +} + #if defined(__ARM_NEON__) || defined(__ARM_NEON) -template +template inline float32x4_t vActiveq_f32(const float32x4_t &x) { return x; } template <> -inline float32x4_t vActiveq_f32(const float32x4_t &x) { +inline float32x4_t vActiveq_f32(const float32x4_t &x) { float32x4_t __zero = vdupq_n_f32(0.f); return vmaxq_f32(x, __zero); } template <> -inline float32x4_t vActiveq_f32(const float32x4_t &x) { +inline float32x4_t vActiveq_f32(const float32x4_t &x) { float32x4_t __zero = vdupq_n_f32(0.f); float32x4_t __six = vdupq_n_f32(6.f); return vminq_f32(vmaxq_f32(x, __zero), __six); } + +template <> +inline float32x4_t vActiveq_f32(const float32x4_t &x) { + float32x4_t __one = vdupq_n_f32(1.f); + float32x4_t __x = vnegq_f32(x); + __x = exp_ps(__x); + __x = vaddq_f32(__x, __one); + float32x4_t __out = vrecpeq_f32(__x); + return vmulq_f32(vrecpsq_f32(__x, __out), __out); +} + +template <> +inline float32x4_t vActiveq_f32(const float32x4_t &x) { + float32x4_t __one = vdupq_n_f32(1.f); + float32x4_t __x = vnegq_f32(x); + __x = vmulq_n_f32(__x, 2.f); + __x = exp_ps(__x); + __x = vaddq_f32(__x, __one); + float32x4_t __out = vrecpeq_f32(__x); + __out = vmulq_f32(vrecpsq_f32(__x, __out), __out); + __out = vmulq_n_f32(__out, 2.f); + return vsubq_f32(__out, __one); +} #endif -template +template inline float Active(const float &x) { return x; } template <> -inline float Active(const float &x) { +inline float Active(const float &x) { return std::max(x, 0.f); } template <> -inline float Active(const float &x) { +inline float Active(const float &x) { return std::min(std::max(x, 0.f), 6.f); } +template <> +inline float Active(const float &x) { + // float tmp = x > SIGMOID_THRESHOLD_MAX ? SIGMOID_THRESHOLD_MAX : x; + // tmp = x > SIGMOID_THRESHOLD_MIN ? x : SIGMOID_THRESHOLD_MIN; + // return 1.f / (1.f + exp(-tmp)); + return 1.f / (1.f + exp(-x)); +} + +template <> +inline float Active(const float &x) { + // float tmp = -2.f * x; + // tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; + // return (2.f / (1.f + exp(tmp))) - 1.f; + return 2.f / (1.f + exp(-2.f * x)) - 1.f; +} + } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/activation_functions.h b/src/operators/math/activation_functions.h deleted file mode 100644 index 8604065a25..0000000000 --- a/src/operators/math/activation_functions.h +++ /dev/null @@ -1,92 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include "common/enforce.h" -namespace paddle_mobile { -namespace operators { -namespace math { - -#define SIGMOID_THRESHOLD_MIN -40.0 -#define SIGMOID_THRESHOLD_MAX 13.0 -#define EXP_MAX_INPUT 40.0 - -enum ActivationType { - kSigmoid, - kReLU, - kTanh, - kIdentity, -}; - -inline ActivationType GetActivationType(const std::string &type) { - if (type == "sigmoid") { - return ActivationType::kSigmoid; - } else if (type == "relu") { - return ActivationType::kReLU; - } else if (type == "tanh") { - return ActivationType::kTanh; - } else if (type == "identity" || type == "") { - return ActivationType::kIdentity; - } - PADDLE_MOBILE_THROW_EXCEPTION("Not support activation type."); -} - -namespace forward { - -template -T Identity(const T a) { - return a; -} - -template -T Relu(const T a) { - return a > static_cast(0.0) ? a : static_cast(0.0); -} - -template -T Sigmoid(const T a) { - const T min = SIGMOID_THRESHOLD_MIN; - const T max = SIGMOID_THRESHOLD_MAX; - T tmp = (a < min) ? min : ((a > max) ? max : a); - return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); -} - -template -T Tanh(const T a) { - T tmp = -2.0 * a; - tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; - return (2.0 / (1.0 + exp(tmp))) - 1.0; -} - -} // namespace forward - -template -struct Active { - typedef T (*Act)(T); -}; - -static Active::Act kActFloat[] = { - &forward::Sigmoid, &forward::Relu, &forward::Tanh, - &forward::Identity}; - -namespace forward { -inline float activation(float a, int index) { return kActFloat[index](a); } - -} // namespace forward - -} // namespace math -} // namespace operators -} // namespace paddle_mobile diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index c17b2a5e4d..69c09e255b 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -1260,10 +1260,10 @@ void Gemm::AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) { "q10", "q11", "q12", "q13"); } -/* -void Gemm::VectorKernel(int m, int n, int k, float alpha, const float *A, int -lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu) { float -*bufferC = static_cast(memory::Alloc(sizeof(float) * n)); +void Gemm::VectorKernel(int m, int n, int k, float alpha, const float *A, + int lda, const float *B, int ldb, float beta, float *C, + int ldc, bool relu) { + float *bufferC = static_cast(memory::Alloc(sizeof(float) * n)); const float *a0, *b0, *b1, *b2, *b3; float *c0, *C0; @@ -1482,6 +1482,7 @@ lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu) { float } } +/* void Gemm::VectorKernelWithBn(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu, float *new_scale, float *new_bias) { @@ -2579,278 +2580,278 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, } } - /* - // C = A * B - void Gemm::VecWriteBasic(int n, float *c, float *C, int ldc) { - int nc1 = n / 16; - int _nc1 = n % 16; - int nc2 = _nc1 / 4; - int nc3 = 16 - 4 * (_nc1 % 4); +// C = A * B +void Gemm::VecWriteBasic(int n, float *c, float *C, int ldc) { + int nc1 = n / 16; + int _nc1 = n % 16; + int nc2 = _nc1 / 4; + int nc3 = 16 - 4 * (_nc1 % 4); - asm volatile( - "subs %[nc1], %[nc1], #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" + asm volatile( + "subs %[nc1], %[nc1], #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" - "vld1.32 {q0, q1}, [%[c]]! \n\t" - "vst1.32 {q0, q1}, [%[C]]! \n\t" + "vld1.32 {q0, q1}, [%[c]]! \n\t" + "vst1.32 {q0, q1}, [%[C]]! \n\t" - "vld1.32 {q2, q3}, [%[c]]! \n\t" - "vst1.32 {q2, q3}, [%[C]]! \n\t" + "vld1.32 {q2, q3}, [%[c]]! \n\t" + "vst1.32 {q2, q3}, [%[C]]! \n\t" - "subs %[nc1], %[nc1], #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" - "subs %[nc2], %[nc2], #1 \n\t" - "blt end_nc2_%= \n\t" - "loop_nc2_%=: \n\t" + "subs %[nc2], %[nc2], #1 \n\t" + "blt end_nc2_%= \n\t" + "loop_nc2_%=: \n\t" - "vld1.32 {q4}, [%[c]]! \n\t" - "vst1.32 {q4}, [%[C]]! \n\t" + "vld1.32 {q4}, [%[c]]! \n\t" + "vst1.32 {q4}, [%[C]]! \n\t" - "subs %[nc2], %[nc2], #1 \n\t" - "bge loop_nc2_%= \n\t" - "end_nc2_%=: \n\t" + "subs %[nc2], %[nc2], #1 \n\t" + "bge loop_nc2_%= \n\t" + "end_nc2_%=: \n\t" - "cmp %[nc3], #16 \n\t" - "beq end_nc3_%= \n\t" - "sub %[c], %[c], %[nc3] \n\t" - "sub %[C], %[C], %[nc3] \n\t" - "vld1.32 {q5}, [%[c]]! \n\t" - "vst1.32 {q5}, [%[C]]! \n\t" - "end_nc3_%=: \n\t" + "cmp %[nc3], #16 \n\t" + "beq end_nc3_%= \n\t" + "sub %[c], %[c], %[nc3] \n\t" + "sub %[C], %[C], %[nc3] \n\t" + "vld1.32 {q5}, [%[c]]! \n\t" + "vst1.32 {q5}, [%[C]]! \n\t" + "end_nc3_%=: \n\t" - : - : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3) - : "memory", "q0", "q1", "q2", "q3", "q4", "q5"); - } + : + : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5"); +} - // C = alpha * A * B + beta * C - void Gemm::VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc) {} +// C = alpha * A * B + beta * C +void Gemm::VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc) {} - // C = A * B + C - void Gemm::VecWriteWithAdd(int n, float *c, float *C, int ldc) { - int nc1 = n / 16; - int _nc1 = n % 16; +// C = A * B + C +void Gemm::VecWriteWithAdd(int n, float *c, float *C, int ldc) { + int nc1 = n / 16; + int _nc1 = n % 16; - asm volatile( - "subs %[nc1], %[nc1], #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" + asm volatile( + "subs %[nc1], %[nc1], #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" - "vld1.32 {q0, q1}, [%[c]]! \n\t" - "vld1.32 {q2, q3}, [%[C]] \n\t" - "vadd.f32 q10, q0, q2 \n\t" - "vadd.f32 q11, q1, q3 \n\t" - "vst1.32 {q10, q11}, [%[C]]! \n\t" + "vld1.32 {q0, q1}, [%[c]]! \n\t" + "vld1.32 {q2, q3}, [%[C]] \n\t" + "vadd.f32 q10, q0, q2 \n\t" + "vadd.f32 q11, q1, q3 \n\t" + "vst1.32 {q10, q11}, [%[C]]! \n\t" - "vld1.32 {q4, q5}, [%[c]]! \n\t" - "vld1.32 {q6, q7}, [%[C]] \n\t" - "vadd.f32 q12, q4, q6 \n\t" - "vadd.f32 q13, q5, q7 \n\t" - "vst1.32 {q12, q13}, [%[C]]! \n\t" + "vld1.32 {q4, q5}, [%[c]]! \n\t" + "vld1.32 {q6, q7}, [%[C]] \n\t" + "vadd.f32 q12, q4, q6 \n\t" + "vadd.f32 q13, q5, q7 \n\t" + "vst1.32 {q12, q13}, [%[C]]! \n\t" - "subs %[nc1], %[nc1], #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" - : [C] "+r"(C), [c] "+r"(c) - : [nc1] "r"(nc1) - : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", - "q11", "q12", "q13"); + : [C] "+r"(C), [c] "+r"(c) + : [nc1] "r"(nc1) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11", + "q12", "q13"); - if (_nc1 != 0) { - for (int j = 0; j < _nc1; j++) { - *C++ += *c++; - } + if (_nc1 != 0) { + for (int j = 0; j < _nc1; j++) { + *C++ += *c++; } } +} - // C = A * B + C, relu(C) - void Gemm::VecWriteWithAddRelu(int n, float *c, float *C, int ldc) { - int nc1 = n / 16; - int _nc1 = n % 16; +// C = A * B + C, relu(C) +void Gemm::VecWriteWithAddRelu(int n, float *c, float *C, int ldc) { + int nc1 = n / 16; + int _nc1 = n % 16; - asm volatile( - "vmov.f32 q14, #0.0 \n\t" - "subs %[nc1], %[nc1], #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" + asm volatile( + "vmov.f32 q14, #0.0 \n\t" + "subs %[nc1], %[nc1], #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" - "vld1.32 {q0, q1}, [%[c]]! \n\t" - "vld1.32 {q2, q3}, [%[C]] \n\t" - "vadd.f32 q10, q0, q2 \n\t" - "vadd.f32 q11, q1, q3 \n\t" - "vmax.f32 q10, q10, q14 \n\t" - "vmax.f32 q11, q11, q14 \n\t" - "vst1.32 {q10, q11}, [%[C]]! \n\t" + "vld1.32 {q0, q1}, [%[c]]! \n\t" + "vld1.32 {q2, q3}, [%[C]] \n\t" + "vadd.f32 q10, q0, q2 \n\t" + "vadd.f32 q11, q1, q3 \n\t" + "vmax.f32 q10, q10, q14 \n\t" + "vmax.f32 q11, q11, q14 \n\t" + "vst1.32 {q10, q11}, [%[C]]! \n\t" - "vld1.32 {q4, q5}, [%[c]]! \n\t" - "vld1.32 {q6, q7}, [%[C]] \n\t" - "vadd.f32 q12, q4, q6 \n\t" - "vadd.f32 q13, q5, q7 \n\t" - "vmax.f32 q12, q12, q14 \n\t" - "vmax.f32 q13, q13, q14 \n\t" - "vst1.32 {q12, q13}, [%[C]]! \n\t" + "vld1.32 {q4, q5}, [%[c]]! \n\t" + "vld1.32 {q6, q7}, [%[C]] \n\t" + "vadd.f32 q12, q4, q6 \n\t" + "vadd.f32 q13, q5, q7 \n\t" + "vmax.f32 q12, q12, q14 \n\t" + "vmax.f32 q13, q13, q14 \n\t" + "vst1.32 {q12, q13}, [%[C]]! \n\t" - "subs %[nc1], %[nc1], #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" - : [C] "+r"(C), [c] "+r"(c) - : [nc1] "r"(nc1) - : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", - "q11", "q12", "q13"); + : [C] "+r"(C), [c] "+r"(c) + : [nc1] "r"(nc1) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11", + "q12", "q13"); - if (_nc1 != 0) { - for (int j = 0; j < _nc1; j++) { - *C += *c; - if (*C < 0) { - *C = 0; - } - C++; - c++; + if (_nc1 != 0) { + for (int j = 0; j < _nc1; j++) { + *C += *c; + if (*C < 0) { + *C = 0; } + C++; + c++; } } +} - // C = A * B, batchnorm(C) - void Gemm::VecWriteWithBn(int n, float *c, float *C, int ldc, float *scale, - float *bias) { - int nc1 = n / 16; - int _nc1 = n % 16; - int nc2 = _nc1 / 4; - int nc3 = 16 - 4 * (_nc1 % 4); - - asm volatile( - "subs %[nc1], %[nc1], #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" - - "vld1.32 {q0, q1}, [%[c]]! \n\t" - "vld1.32 {q2, q3}, [%[scale]]! \n\t" - "vld1.32 {q10, q11}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q2 \n\t" - "vmla.f32 q11, q1, q3 \n\t" - "vst1.32 {q10, q11}, [%[C]]! \n\t" - - "vld1.32 {q4, q5}, [%[c]]! \n\t" - "vld1.32 {q6, q7}, [%[scale]]! \n\t" - "vld1.32 {q12, q13}, [%[bias]]! \n\t" - "vmla.f32 q12, q4, q6 \n\t" - "vmla.f32 q13, q5, q7 \n\t" - "vst1.32 {q12, q13}, [%[C]]! \n\t" - - "subs %[nc1], %[nc1], #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" - - "subs %[nc2], %[nc2], #1 \n\t" - "blt end_nc2_%= \n\t" - "loop_nc2_%=: \n\t" - - "vld1.32 {q0}, [%[c]]! \n\t" - "vld1.32 {q1}, [%[scale]]! \n\t" - "vld1.32 {q10}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q1 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" - - "subs %[nc2], %[nc2], #1 \n\t" - "bge loop_nc2_%= \n\t" - "end_nc2_%=: \n\t" - - "cmp %[nc3], #16 \n\t" - "beq end_nc3_%= \n\t" - - "sub %[c], %[c], %[nc3] \n\t" - "sub %[scale], %[scale], %[nc3] \n\t" - "sub %[bias], %[bias], %[nc3] \n\t" - "sub %[C], %[C], %[nc3] \n\t" - - "vld1.32 {q0}, [%[c]]! \n\t" - "vld1.32 {q1}, [%[scale]]! \n\t" - "vld1.32 {q10}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q1 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" - "end_nc3_%=: \n\t" - - : - : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] - "r"(nc3), [scale] "r"(scale), [bias] "r"(bias) : "memory", "q0", "q1", "q2", - "q3", "q4", "q5", "q6", "q7", "q10", "q11", "q12", "q13"); - } - - // C = A * B, batchnorm(C), relu(C) - void Gemm::VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float - *scale, float *bias) { int nc1 = n / 16; int _nc1 = n % 16; int nc2 = _nc1 / - 4; int nc3 = 16 - 4 * (_nc1 % 4); - - asm volatile( - "vmov.f32 q14, #0.0 \n\t" - "subs %[nc1], %[nc1], #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" - - "vld1.32 {q0, q1}, [%[c]]! \n\t" - "vld1.32 {q2, q3}, [%[scale]]! \n\t" - "vld1.32 {q10, q11}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q2 \n\t" - "vmla.f32 q11, q1, q3 \n\t" - "vmax.f32 q10, q10, q14 \n\t" - "vmax.f32 q11, q11, q14 \n\t" - "vst1.32 {q10, q11}, [%[C]]! \n\t" - - "vld1.32 {q4, q5}, [%[c]]! \n\t" - "vld1.32 {q6, q7}, [%[scale]]! \n\t" - "vld1.32 {q12, q13}, [%[bias]]! \n\t" - "vmla.f32 q12, q4, q6 \n\t" - "vmla.f32 q13, q5, q7 \n\t" - "vmax.f32 q12, q12, q14 \n\t" - "vmax.f32 q13, q13, q14 \n\t" - "vst1.32 {q12, q13}, [%[C]]! \n\t" - - "subs %[nc1], %[nc1], #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" - - "subs %[nc2], %[nc2], #1 \n\t" - "blt end_nc2_%= \n\t" - "loop_nc2_%=: \n\t" - - "vld1.32 {q0}, [%[c]]! \n\t" - "vld1.32 {q1}, [%[scale]]! \n\t" - "vld1.32 {q10}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q1 \n\t" - "vmax.f32 q10, q10, q14 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" - - "subs %[nc2], %[nc2], #1 \n\t" - "bge loop_nc2_%= \n\t" - "end_nc2_%=: \n\t" - - "cmp %[nc3], #16 \n\t" - "beq end_nc3_%= \n\t" - - "sub %[c], %[c], %[nc3] \n\t" - "sub %[scale], %[scale], %[nc3] \n\t" - "sub %[bias], %[bias], %[nc3] \n\t" - "sub %[C], %[C], %[nc3] \n\t" + /* + // C = A * B, batchnorm(C) + void Gemm::VecWriteWithBn(int n, float *c, float *C, int ldc, float *scale, + float *bias) { + int nc1 = n / 16; + int _nc1 = n % 16; + int nc2 = _nc1 / 4; + int nc3 = 16 - 4 * (_nc1 % 4); - "vld1.32 {q0}, [%[c]]! \n\t" - "vld1.32 {q1}, [%[scale]]! \n\t" - "vld1.32 {q10}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q1 \n\t" - "vmax.f32 q10, q10, q14 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" - "end_nc3_%=: \n\t" + asm volatile( + "subs %[nc1], %[nc1], #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" + + "vld1.32 {q0, q1}, [%[c]]! \n\t" + "vld1.32 {q2, q3}, [%[scale]]! \n\t" + "vld1.32 {q10, q11}, [%[bias]]! \n\t" + "vmla.f32 q10, q0, q2 \n\t" + "vmla.f32 q11, q1, q3 \n\t" + "vst1.32 {q10, q11}, [%[C]]! \n\t" + + "vld1.32 {q4, q5}, [%[c]]! \n\t" + "vld1.32 {q6, q7}, [%[scale]]! \n\t" + "vld1.32 {q12, q13}, [%[bias]]! \n\t" + "vmla.f32 q12, q4, q6 \n\t" + "vmla.f32 q13, q5, q7 \n\t" + "vst1.32 {q12, q13}, [%[C]]! \n\t" + + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + "subs %[nc2], %[nc2], #1 \n\t" + "blt end_nc2_%= \n\t" + "loop_nc2_%=: \n\t" + + "vld1.32 {q0}, [%[c]]! \n\t" + "vld1.32 {q1}, [%[scale]]! \n\t" + "vld1.32 {q10}, [%[bias]]! \n\t" + "vmla.f32 q10, q0, q1 \n\t" + "vst1.32 {q10}, [%[C]]! \n\t" + + "subs %[nc2], %[nc2], #1 \n\t" + "bge loop_nc2_%= \n\t" + "end_nc2_%=: \n\t" + + "cmp %[nc3], #16 \n\t" + "beq end_nc3_%= \n\t" + + "sub %[c], %[c], %[nc3] \n\t" + "sub %[scale], %[scale], %[nc3] \n\t" + "sub %[bias], %[bias], %[nc3] \n\t" + "sub %[C], %[C], %[nc3] \n\t" + + "vld1.32 {q0}, [%[c]]! \n\t" + "vld1.32 {q1}, [%[scale]]! \n\t" + "vld1.32 {q10}, [%[bias]]! \n\t" + "vmla.f32 q10, q0, q1 \n\t" + "vst1.32 {q10}, [%[C]]! \n\t" + "end_nc3_%=: \n\t" + + : + : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] + "r"(nc3), [scale] "r"(scale), [bias] "r"(bias) : "memory", "q0", "q1", "q2", + "q3", "q4", "q5", "q6", "q7", "q10", "q11", "q12", "q13"); + } + + // C = A * B, batchnorm(C), relu(C) + void Gemm::VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float + *scale, float *bias) { int nc1 = n / 16; int _nc1 = n % 16; int nc2 = _nc1 / + 4; int nc3 = 16 - 4 * (_nc1 % 4); - : - : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] - "r"(nc3), [scale] "r"(scale), [bias] "r"(bias) : "memory", "q0", "q1", "q2", - "q3", "q4", "q5", "q6", "q7", "q10", "q11", "q12", "q13", "q14"); - } - */ + asm volatile( + "vmov.f32 q14, #0.0 \n\t" + "subs %[nc1], %[nc1], #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" + + "vld1.32 {q0, q1}, [%[c]]! \n\t" + "vld1.32 {q2, q3}, [%[scale]]! \n\t" + "vld1.32 {q10, q11}, [%[bias]]! \n\t" + "vmla.f32 q10, q0, q2 \n\t" + "vmla.f32 q11, q1, q3 \n\t" + "vmax.f32 q10, q10, q14 \n\t" + "vmax.f32 q11, q11, q14 \n\t" + "vst1.32 {q10, q11}, [%[C]]! \n\t" + + "vld1.32 {q4, q5}, [%[c]]! \n\t" + "vld1.32 {q6, q7}, [%[scale]]! \n\t" + "vld1.32 {q12, q13}, [%[bias]]! \n\t" + "vmla.f32 q12, q4, q6 \n\t" + "vmla.f32 q13, q5, q7 \n\t" + "vmax.f32 q12, q12, q14 \n\t" + "vmax.f32 q13, q13, q14 \n\t" + "vst1.32 {q12, q13}, [%[C]]! \n\t" + + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + "subs %[nc2], %[nc2], #1 \n\t" + "blt end_nc2_%= \n\t" + "loop_nc2_%=: \n\t" + + "vld1.32 {q0}, [%[c]]! \n\t" + "vld1.32 {q1}, [%[scale]]! \n\t" + "vld1.32 {q10}, [%[bias]]! \n\t" + "vmla.f32 q10, q0, q1 \n\t" + "vmax.f32 q10, q10, q14 \n\t" + "vst1.32 {q10}, [%[C]]! \n\t" + + "subs %[nc2], %[nc2], #1 \n\t" + "bge loop_nc2_%= \n\t" + "end_nc2_%=: \n\t" + + "cmp %[nc3], #16 \n\t" + "beq end_nc3_%= \n\t" + + "sub %[c], %[c], %[nc3] \n\t" + "sub %[scale], %[scale], %[nc3] \n\t" + "sub %[bias], %[bias], %[nc3] \n\t" + "sub %[C], %[C], %[nc3] \n\t" + + "vld1.32 {q0}, [%[c]]! \n\t" + "vld1.32 {q1}, [%[scale]]! \n\t" + "vld1.32 {q10}, [%[bias]]! \n\t" + "vmla.f32 q10, q0, q1 \n\t" + "vmax.f32 q10, q10, q14 \n\t" + "vst1.32 {q10}, [%[C]]! \n\t" + "end_nc3_%=: \n\t" + + : + : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] + "r"(nc3), [scale] "r"(scale), [bias] "r"(bias) : "memory", "q0", "q1", "q2", + "q3", "q4", "q5", "q6", "q7", "q10", "q11", "q12", "q13", "q14"); + } + */ #endif // __aarch64__ #else @@ -3149,13 +3150,17 @@ void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda, void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu, float *bias) { + if (m == 1 && bias == nullptr) { + return VectorKernel(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, relu); + } #ifdef _OPENMP int max_threads = omp_get_max_threads(); #else int max_threads = 1; #endif - int L1 = 64 / max_threads * 1024; + // int L1 = 64 / max_threads * 1024; + int L1 = 32 / max_threads * 1024; KC = k; zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); memset(static_cast(zero), 0, sizeof(float) * KC); diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index fb2c248c9b..99c68de7c3 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -105,16 +105,15 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, float *c, float *C, int ldc, float *p, std::string mode, float *bias, float *bias1); - /* // 向量矩阵乘法 (M = 1) void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu); - - void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A, - int lda, const float *B, int ldb, float beta, float - *C, int ldc, bool relu, float *new_scale, float *new_bias); - */ + /* + void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A, + int lda, const float *B, int ldb, float beta, float + *C, int ldc, bool relu, float *new_scale, float *new_bias); + */ // 计算一个更小的 C 矩阵分块 void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc); @@ -149,7 +148,6 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, float *new_scale, float *new_bias, float *bias1); - /* // 向量矩阵乘法结果回写 // C = A * B void VecWriteBasic(int n, float *c, float *C, int ldc); @@ -159,13 +157,14 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, void VecWriteWithAdd(int n, float *c, float *C, int ldc); // C = A * B + C, relu(C) void VecWriteWithAddRelu(int n, float *c, float *C, int ldc); - // C = A * B, batchnorm(C) - void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale, - float *new_bias); - // C = A * B, batchnorm(C), relu(C) - void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale, - float *new_bias); - */ + /* + // C = A * B, batchnorm(C) + void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale, + float *new_bias); + // C = A * B, batchnorm(C), relu(C) + void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float + *new_scale, float *new_bias); + */ // 32位 float 矩阵乘法 void Sgemm(int m, int n, int k, float alpha, const float *A, int lda, @@ -392,7 +391,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, packedB_int8 = static_cast( paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC)); #if __aarch64__ - // TODO() + // TODO(paddle mobile) #else PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8); #endif @@ -414,7 +413,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, packedA_int8 = static_cast( paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC)); #if __aarch64__ - // TODO() + // TODO(paddle mobile) #else PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8); #endif @@ -438,7 +437,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, int8_t *local_A = packedA_int8 + MC * KC * local_threads; int32_t *local_C = packedC_int32 + MC * NC * local_threads; #if __aarch64__ - // TODO() + // TODO(paddle mobile) #else PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A); #endif @@ -468,7 +467,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, int8_t *local_B = packedB_int8 + KC * NC * local_threads; int32_t *local_C = packedC_int32 + MC * NC * local_threads; #if __aarch64__ - // TODO() + // TODO(paddle mobile) #else PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B); #endif diff --git a/src/operators/math/gru_compute.cpp b/src/operators/math/gru_compute.cpp index bbf1b01a21..19c7a2685c 100644 --- a/src/operators/math/gru_compute.cpp +++ b/src/operators/math/gru_compute.cpp @@ -11,13 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #ifdef GRU_OP + #include "operators/math/gru_compute.h" #include "common/types.h" -#include "operators/math/activation_functions.h" +#include "operators/math/activation.h" #include "operators/math/gemm.h" #include "operators/math/gru_cpu_kernel.h" -#include "operators/math/gru_kernel.h" namespace paddle_mobile { namespace operators { @@ -43,8 +44,7 @@ struct GRUUnitFunctor { #endif } - forward_reset_output(forward::gru_resetOutput(), value, frame_size, - batch_size, active_gate); + forward_reset_output(value, frame_size, batch_size, active_gate); if (value.prev_out_value) { #ifdef _OPENMP @@ -60,8 +60,7 @@ struct GRUUnitFunctor { #endif } - forward_final_output(forward::gru_finalOutput(), value, frame_size, - batch_size, active_node); + forward_final_output(value, frame_size, batch_size, active_node); } }; diff --git a/src/operators/math/gru_compute.h b/src/operators/math/gru_compute.h index 89cac1b8e4..00f4da9022 100644 --- a/src/operators/math/gru_compute.h +++ b/src/operators/math/gru_compute.h @@ -11,7 +11,7 @@ limitations under the License. */ #ifdef GRU_OP #pragma once -#include "operators/math/activation_functions.h" +#include "operators/math/activation.h" namespace paddle_mobile { namespace operators { diff --git a/src/operators/math/gru_cpu_kernel.h b/src/operators/math/gru_cpu_kernel.h index ea24c4f1d9..a010fb616b 100644 --- a/src/operators/math/gru_cpu_kernel.h +++ b/src/operators/math/gru_cpu_kernel.h @@ -11,21 +11,22 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #ifdef GRU_OP + #pragma once + #include -#include "operators/math/activation_functions.h" +#include "operators/math/activation.h" #include "operators/math/gru_compute.h" namespace paddle_mobile { namespace operators { namespace math { -template -void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output, - T *gate_value, T *reset_output_value, - T *prev_output_value, int frame_size, - ActivationType active_gate) { +template +void hl_naive_gru_forward_reset_output(T *gate_value, T *reset_output_value, + T *prev_output_value, int frame_size) { T r_value_update_gate; T r_value_reset_gate; T r_value_reset_output; @@ -33,27 +34,57 @@ void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output, T *update_gate = gate_value; T *reset_gate = gate_value + frame_size; - for (int i = 0; i < frame_size; i++) { + int remain = frame_size; +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + int loop = remain >> 3; + remain = remain & 0x7; + float32x4_t prev0 = vdupq_n_f32(0.f); + float32x4_t prev1 = vdupq_n_f32(0.f); + for (int i = 0; i < loop; ++i) { + float32x4_t update0 = vld1q_f32(update_gate); + float32x4_t update1 = vld1q_f32(update_gate + 4); + float32x4_t reset0 = vld1q_f32(reset_gate); + float32x4_t reset1 = vld1q_f32(reset_gate + 4); + if (prev_output_value) { + prev0 = vld1q_f32(prev_output_value); + prev1 = vld1q_f32(prev_output_value + 4); + prev_output_value += 8; + } + update0 = vActiveq_f32(update0); + update1 = vActiveq_f32(update1); + reset0 = vActiveq_f32(reset0); + reset1 = vActiveq_f32(reset1); + float32x4_t output0 = vmulq_f32(prev0, reset0); + float32x4_t output1 = vmulq_f32(prev1, reset1); + vst1q_f32(update_gate, update0); + vst1q_f32(update_gate + 4, update1); + vst1q_f32(reset_gate, reset0); + vst1q_f32(reset_gate + 4, reset1); + vst1q_f32(reset_output_value, output0); + vst1q_f32(reset_output_value + 4, output1); + update_gate += 8; + reset_gate += 8; + reset_output_value += 8; + } +#endif // __ARM_NEON__ + for (int i = 0; i < remain; i++) { r_value_update_gate = update_gate[i]; r_value_reset_gate = reset_gate[i]; if (prev_output_value) { r_prev_out = prev_output_value[i]; } - - op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out, - &r_value_reset_output, active_gate); - + r_value_update_gate = Active(r_value_update_gate); + r_value_reset_gate = Active(r_value_reset_gate); + r_value_reset_output = r_prev_out * r_value_reset_gate; update_gate[i] = r_value_update_gate; reset_gate[i] = r_value_reset_gate; reset_output_value[i] = r_value_reset_output; } } -template -void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, - T *gate_value, T *prev_output_value, - T *output_value, int frame_size, - ActivationType active_node) { +template +void hl_naive_gru_forward_final_output(T *gate_value, T *prev_output_value, + T *output_value, int frame_size) { T r_value_update_gate; T r_value_frame_state; T r_prev_out = 0; @@ -61,30 +92,73 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, T *update_gate = gate_value; T *frame_state = gate_value + frame_size * 2; - for (int i = 0; i < frame_size; i++) { + int remain = frame_size; +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + int loop = remain >> 3; + remain = remain & 0x7; + float32x4_t prev0 = vdupq_n_f32(0.f); + float32x4_t prev1 = vdupq_n_f32(0.f); + for (int i = 0; i < loop; ++i) { + float32x4_t update0 = vld1q_f32(update_gate); + float32x4_t update1 = vld1q_f32(update_gate + 4); + float32x4_t state0 = vld1q_f32(frame_state); + float32x4_t state1 = vld1q_f32(frame_state + 4); + if (prev_output_value) { + prev0 = vld1q_f32(prev_output_value); + prev1 = vld1q_f32(prev_output_value + 4); + prev_output_value += 8; + } + state0 = vActiveq_f32(state0); + state1 = vActiveq_f32(state1); + float32x4_t output0 = vmlsq_f32(prev0, update0, prev0); + float32x4_t output1 = vmlsq_f32(prev1, update1, prev1); + output0 = vmlaq_f32(output0, update0, state0); + output1 = vmlaq_f32(output1, update1, state1); + vst1q_f32(frame_state, state0); + vst1q_f32(frame_state + 4, state1); + vst1q_f32(output_value, output0); + vst1q_f32(output_value + 4, output1); + update_gate += 8; + frame_state += 8; + output_value += 8; + } +#endif // __ARM_NEON__ + for (int i = 0; i < remain; i++) { r_value_update_gate = update_gate[i]; r_value_frame_state = frame_state[i]; if (prev_output_value) { r_prev_out = prev_output_value[i]; } - - op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out, - &r_output, active_node); - + r_value_frame_state = Active(r_value_frame_state); + r_output = r_prev_out - r_value_update_gate * r_prev_out + + r_value_update_gate * r_value_frame_state; frame_state[i] = r_value_frame_state; output_value[i] = r_output; } } -template -inline void forward_reset_output(OpResetOutput op_reset_output, - GRUMetaValue value, int frame_size, - int batch_size, ActivationType active_gate) { - for (int b = 0; b < batch_size; b++) { - hl_naive_gru_forward_reset_output( - op_reset_output, value.gate_value, value.reset_output_value, - value.prev_out_value, frame_size, active_gate); +#define FORWARD_RESET_OUTPUT(active_type, value, frame_size) \ + hl_naive_gru_forward_reset_output( \ + value.gate_value, value.reset_output_value, value.prev_out_value, \ + frame_size); +template +inline void forward_reset_output(GRUMetaValue value, int frame_size, + int batch_size, ActivationType active_node) { + for (int b = 0; b < batch_size; ++b) { + switch (active_node) { + case RELU: + FORWARD_RESET_OUTPUT(RELU, value, frame_size); + break; + case SIGMOID: + FORWARD_RESET_OUTPUT(SIGMOID, value, frame_size); + break; + case TANH: + FORWARD_RESET_OUTPUT(TANH, value, frame_size); + break; + default: + FORWARD_RESET_OUTPUT(IDENTITY, value, frame_size); + } value.gate_value += frame_size * 3; value.reset_output_value += frame_size; if (value.prev_out_value) { @@ -93,15 +167,27 @@ inline void forward_reset_output(OpResetOutput op_reset_output, } } -template -inline void forward_final_output(OpFinalOutput op_final_output, - GRUMetaValue value, int frame_size, - int batch_size, ActivationType active_node) { - for (int b = 0; b < batch_size; b++) { - hl_naive_gru_forward_final_output(op_final_output, value.gate_value, - value.prev_out_value, value.output_value, - frame_size, active_node); +#define FORWARD_FINAL_OUTPUT(active_type, value, frame_size) \ + hl_naive_gru_forward_final_output( \ + value.gate_value, value.prev_out_value, value.output_value, frame_size) +template +inline void forward_final_output(GRUMetaValue value, int frame_size, + int batch_size, ActivationType active_node) { + for (int b = 0; b < batch_size; ++b) { + switch (active_node) { + case RELU: + FORWARD_FINAL_OUTPUT(RELU, value, frame_size); + break; + case SIGMOID: + FORWARD_FINAL_OUTPUT(SIGMOID, value, frame_size); + break; + case TANH: + FORWARD_FINAL_OUTPUT(TANH, value, frame_size); + break; + default: + FORWARD_FINAL_OUTPUT(IDENTITY, value, frame_size); + } value.gate_value += frame_size * 3; value.output_value += frame_size; if (value.prev_out_value) { @@ -113,4 +199,5 @@ inline void forward_final_output(OpFinalOutput op_final_output, } // namespace math } // namespace operators } // namespace paddle_mobile + #endif diff --git a/src/operators/math/gru_kernel.h b/src/operators/math/gru_kernel.h deleted file mode 100644 index 6113ce8da9..0000000000 --- a/src/operators/math/gru_kernel.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#ifdef GRU_OP -#pragma once -#include -#include "operators/math/activation_functions.h" - -namespace paddle_mobile { -namespace operators { -namespace math { - -namespace forward { - -template -class gru_resetOutput { - public: - void operator()(T *value_update_gate, T *value_reset_gate, T *prev_out, - T *value_reset_output, ActivationType act_gate) { - *value_update_gate = activation(*value_update_gate, act_gate); - *value_reset_gate = activation(*value_reset_gate, act_gate); - *value_reset_output = (*prev_out) * (*value_reset_gate); - } -}; - -template -class gru_finalOutput { - public: - void operator()(T *value_update_gate, T *value_frame_state, T *prev_out, - T *value_output, ActivationType act_input) { - *value_frame_state = activation(*value_frame_state, act_input); - *value_output = *prev_out - ((*value_update_gate) * (*prev_out)) + - ((*value_update_gate) * (*value_frame_state)); - } -}; -} // namespace forward - -} // namespace math -} // namespace operators -} // namespace paddle_mobile -#endif diff --git a/src/operators/math/pooling.cpp b/src/operators/math/pooling.cpp index 1270e6a898..46b4453e73 100644 --- a/src/operators/math/pooling.cpp +++ b/src/operators/math/pooling.cpp @@ -72,8 +72,8 @@ void Pooling

::operator()(const framework::Tensor &input, } } -template struct Pooling; -template struct Pooling; +template struct Pooling; +template struct Pooling; } // namespace math } // namespace operators diff --git a/src/operators/math/pooling.h b/src/operators/math/pooling.h index 4a3b2b8389..4239cf8cbc 100644 --- a/src/operators/math/pooling.h +++ b/src/operators/math/pooling.h @@ -30,7 +30,7 @@ namespace paddle_mobile { namespace operators { namespace math { -template +template struct PoolingVal { float val; int count; @@ -44,11 +44,11 @@ struct PoolingVal { }; template <> -struct PoolingVal { +struct PoolingVal { float val; int count; PoolingVal() : val(0.f), count(0) {} - inline PoolingVal &operator+=(const float &x) { + inline PoolingVal &operator+=(const float &x) { val += x; ++count; return *this; @@ -57,57 +57,57 @@ struct PoolingVal { }; #if defined(__ARM_NEON) || defined(__ARM_NEON__) -template +template inline float32x4_t vPoolInitq_f32() { return vdupq_n_f32(-std::numeric_limits::max()); } template <> -inline float32x4_t vPoolInitq_f32() { +inline float32x4_t vPoolInitq_f32() { return vdupq_n_f32(0.f); } -template +template inline float32x4_t vPoolPreq_f32(const float32x4_t &x1, const float32x4_t &x2) { return vmaxq_f32(x1, x2); } template <> -inline float32x4_t vPoolPreq_f32(const float32x4_t &x1, +inline float32x4_t vPoolPreq_f32(const float32x4_t &x1, const float32x4_t &x2) { return vaddq_f32(x1, x2); } -template +template inline float32x4_t vPoolPostq_f32(const float32x4_t &x, const float32x4_t &post) { return x; } template <> -inline float32x4_t vPoolPostq_f32(const float32x4_t &x, +inline float32x4_t vPoolPostq_f32(const float32x4_t &x, const float32x4_t &post) { return vmulq_f32(x, post); } #endif // __ARM_NEON__ -template +template inline float PoolPre(const float &x1, const float &x2) { return std::max(x1, x2); } template <> -inline float PoolPre(const float &x1, const float &x2) { +inline float PoolPre(const float &x1, const float &x2) { return x1 + x2; } -template +template inline float PoolPost(const float &x, const float &post) { return x; } template <> -inline float PoolPost(const float &x, const float &post) { +inline float PoolPost(const float &x, const float &post) { return x * post; } diff --git a/src/operators/math/pooling3x3.cpp b/src/operators/math/pooling3x3.cpp index e556768ce0..72ffb6161a 100644 --- a/src/operators/math/pooling3x3.cpp +++ b/src/operators/math/pooling3x3.cpp @@ -1016,10 +1016,10 @@ struct Pooling3x3 { } }; -template struct Pooling3x3; -template struct Pooling3x3; -template struct Pooling3x3; -template struct Pooling3x3; +template struct Pooling3x3; +template struct Pooling3x3; +template struct Pooling3x3; +template struct Pooling3x3; } // namespace math } // namespace operators diff --git a/test/operators/test_pool_op.cpp b/test/operators/test_pool_op.cpp index 3668b8cb28..5d3c4374a4 100644 --- a/test/operators/test_pool_op.cpp +++ b/test/operators/test_pool_op.cpp @@ -74,11 +74,11 @@ int TestPoolOp(int in_channels, int in_height, int in_width) { output_cmp.mutable_data(output->dims()); if (pooling_type == "avg") { - math::Pooling()(*input, std::vector{kernel_h, kernel_w}, + math::Pooling()(*input, std::vector{kernel_h, kernel_w}, std::vector{stride_h, stride_w}, std::vector{pad_h, pad_w}, &output_cmp); } else { - math::Pooling()(*input, std::vector{kernel_h, kernel_w}, + math::Pooling()(*input, std::vector{kernel_h, kernel_w}, std::vector{stride_h, stride_w}, std::vector{pad_h, pad_w}, &output_cmp); } @@ -117,57 +117,57 @@ int main(int argc, char *argv[]) { int in_channels = atoi(argv[1]); int in_height = atoi(argv[2]); int in_width = atoi(argv[3]); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=max, kernel=3, pad=0, stride=1"; - // paddle_mobile::TestPoolOp<0, 3, 0, 1>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=max, kernel=3, pad=1, stride=1"; - // paddle_mobile::TestPoolOp<0, 3, 1, 1>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=max, kernel=3, pad=2, stride=1"; - // paddle_mobile::TestPoolOp<0, 3, 2, 1>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=max, kernel=3, pad=5, stride=1"; - // paddle_mobile::TestPoolOp<0, 3, 5, 1>(in_channels, in_height, in_width); - // - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=avg, kernel=3, pad=0, stride=1"; - // paddle_mobile::TestPoolOp<1, 3, 0, 1>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=avg, kernel=3, pad=1, stride=1"; - // paddle_mobile::TestPoolOp<1, 3, 1, 1>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=avg, kernel=3, pad=2, stride=1"; - // paddle_mobile::TestPoolOp<1, 3, 2, 1>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=avg, kernel=3, pad=5, stride=1"; - // paddle_mobile::TestPoolOp<1, 3, 5, 1>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=max, kernel=3, pad=0, stride=1"; + paddle_mobile::TestPoolOp<0, 3, 0, 1>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=max, kernel=3, pad=1, stride=1"; + paddle_mobile::TestPoolOp<0, 3, 1, 1>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=max, kernel=3, pad=2, stride=1"; + paddle_mobile::TestPoolOp<0, 3, 2, 1>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=max, kernel=3, pad=5, stride=1"; + paddle_mobile::TestPoolOp<0, 3, 5, 1>(in_channels, in_height, in_width); + + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=3, pad=0, stride=1"; + paddle_mobile::TestPoolOp<1, 3, 0, 1>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=3, pad=1, stride=1"; + paddle_mobile::TestPoolOp<1, 3, 1, 1>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=3, pad=2, stride=1"; + paddle_mobile::TestPoolOp<1, 3, 2, 1>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=3, pad=5, stride=1"; + paddle_mobile::TestPoolOp<1, 3, 5, 1>(in_channels, in_height, in_width); LOG(paddle_mobile::kLOG_INFO) << "float, pooling_type=max, kernel=3, pad=0, stride=2"; paddle_mobile::TestPoolOp<0, 3, 0, 2>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=max, kernel=3, pad=1, stride=2"; - // paddle_mobile::TestPoolOp<0, 3, 1, 2>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=max, kernel=3, pad=2, stride=2"; - // paddle_mobile::TestPoolOp<0, 3, 2, 2>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=max, kernel=3, pad=5, stride=2"; - // paddle_mobile::TestPoolOp<0, 3, 5, 2>(in_channels, in_height, in_width); - // - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=avg, kernel=3, pad=0, stride=2"; - // paddle_mobile::TestPoolOp<1, 3, 0, 2>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=avg, kernel=3, pad=1, stride=2"; - // paddle_mobile::TestPoolOp<1, 3, 1, 2>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=avg, kernel=3, pad=2, stride=2"; - // paddle_mobile::TestPoolOp<1, 3, 2, 2>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=avg, kernel=3, pad=5, stride=2"; - // paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=max, kernel=3, pad=1, stride=2"; + paddle_mobile::TestPoolOp<0, 3, 1, 2>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=max, kernel=3, pad=2, stride=2"; + paddle_mobile::TestPoolOp<0, 3, 2, 2>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=max, kernel=3, pad=5, stride=2"; + paddle_mobile::TestPoolOp<0, 3, 5, 2>(in_channels, in_height, in_width); + + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=3, pad=0, stride=2"; + paddle_mobile::TestPoolOp<1, 3, 0, 2>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=3, pad=1, stride=2"; + paddle_mobile::TestPoolOp<1, 3, 1, 2>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=3, pad=2, stride=2"; + paddle_mobile::TestPoolOp<1, 3, 2, 2>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=3, pad=5, stride=2"; + paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width); // // kernel = 5, pad = 0, stride = 1 // LOG(paddle_mobile::kLOG_INFO) -- GitLab