diff --git a/src/common/enforce.h b/src/common/enforce.h index bf21b5b9a2fe5f70b3bd23a581f0c1dfbf373f42..1bacfb88d328c85de9b284249c8d9d58e7fc8e5e 100644 --- a/src/common/enforce.h +++ b/src/common/enforce.h @@ -16,9 +16,9 @@ limitations under the License. */ #ifdef ENABLE_EXCEPTION #include +#include #include #include - #endif namespace paddle_mobile { diff --git a/src/common/types.h b/src/common/types.h index c607efb9a2636cc09fd2ac7444a117e0f401251d..ee36250ea489da3facdb025594fc95e77e6693bd 100644 --- a/src/common/types.h +++ b/src/common/types.h @@ -93,18 +93,18 @@ enum RoundType { }; enum ActivationType { - Linear = 0, - Relu = 1, - Relu6 = 2, - PRelu = 3, - LeakyRelu = 4, - Tanh = 5, - Sigmoid = 6, + IDENTITY = 0, + RELU = 1, + RELU6 = 2, + PRELU = 3, + LEAKY_RELU = 4, + TANH = 5, + SIGMOID = 6, }; enum PoolingType { - Max = 0, - Avg = 1, + MAX = 0, + AVG = 1, }; extern const char *G_OP_TYPE_CONV; diff --git a/src/io/paddle_mobile.cpp b/src/io/paddle_mobile.cpp index cb7051468715179e1d9a5ead407941a20d9cb87a..5f724ce4a2772de8ad1ac501c670ae3fe57a337b 100644 --- a/src/io/paddle_mobile.cpp +++ b/src/io/paddle_mobile.cpp @@ -143,12 +143,10 @@ double PaddleMobile::GetPredictTime() { int t1 = 1; int t2 = 1; for (int i = 0; i < m * k; ++i) { - unsigned int seed = 100; - a[i] = t1 + rand_r(&seed) % t2; + a[i] = t1 + rand() % t2; // NOLINT } for (int i = 0; i < k * n; ++i) { - unsigned int seed = 200; - b[i] = t1 + rand_r(&seed) % t2; + b[i] = t1 + rand() % t2; // NOLINT } paddle_mobile::operators::math::Gemm gemm; auto time1 = paddle_mobile::time(); diff --git a/src/operators/kernel/arm/dequantize_bn_kernel.cpp b/src/operators/kernel/arm/dequantize_bn_kernel.cpp index d4b41b8b872ebe4fa21fd78f20a74a57d7ed0bc1..4fa00f3a378e7f715c8435ab56ddc81e6124d39a 100644 --- a/src/operators/kernel/arm/dequantize_bn_kernel.cpp +++ b/src/operators/kernel/arm/dequantize_bn_kernel.cpp @@ -131,7 +131,7 @@ bool FusionDequantBNKernel::Init(FusionDequantBNParam *param) { template <> void FusionDequantBNKernel::Compute( const FusionDequantBNParam ¶m) { - DequantBNCompute(¶m); + DequantBNCompute(¶m); } #endif // FUSION_DEQUANT_BN_OP @@ -146,7 +146,7 @@ bool FusionDequantBNReluKernel::Init( template <> void FusionDequantBNReluKernel::Compute( const FusionDequantBNParam ¶m) { - DequantBNCompute(¶m); + DequantBNCompute(¶m); } #endif // FUSION_DEQUANT_BN_RELU_OP @@ -162,7 +162,7 @@ bool FusionDequantAddBNKernel::Init( template <> void FusionDequantAddBNKernel::Compute( const FusionDequantAddBNParam ¶m) { - DequantBNCompute(¶m); + DequantBNCompute(¶m); } #endif // FUSION_DEQUANT_ADD_BN_OP @@ -178,7 +178,7 @@ bool FusionDequantAddBNReluKernel::Init( template <> void FusionDequantAddBNReluKernel::Compute( const FusionDequantAddBNParam ¶m) { - DequantBNCompute(¶m); + DequantBNCompute(¶m); } #endif // FUSION_DEQUANT_ADD_BN_RELU_OP @@ -292,13 +292,13 @@ void FusionDequantAddBNQuantKernel::Compute( const FusionDequantAddBNQuantParam ¶m) { switch (param.round_type_) { case ROUND_NEAREST_TO_EVEN: - DequantBNQuantCompute(¶m); + DequantBNQuantCompute(¶m); break; case ROUND_NEAREST_TOWARDS_ZERO: - DequantBNQuantCompute(¶m); + DequantBNQuantCompute(¶m); break; case ROUND_NEAREST_AWAY_ZERO: - DequantBNQuantCompute(¶m); + DequantBNQuantCompute(¶m); break; default: LOG(kLOG_ERROR) << "round type is not supported."; @@ -321,13 +321,13 @@ void FusionDequantAddBNReluQuantKernel::Compute( const FusionDequantAddBNQuantParam ¶m) { switch (param.round_type_) { case ROUND_NEAREST_TO_EVEN: - DequantBNQuantCompute(¶m); + DequantBNQuantCompute(¶m); break; case ROUND_NEAREST_TOWARDS_ZERO: - DequantBNQuantCompute(¶m); + DequantBNQuantCompute(¶m); break; case ROUND_NEAREST_AWAY_ZERO: - DequantBNQuantCompute(¶m); + DequantBNQuantCompute(¶m); break; default: LOG(kLOG_ERROR) << "round type is not supported."; diff --git a/src/operators/kernel/arm/quantize_kernel.cpp b/src/operators/kernel/arm/quantize_kernel.cpp index f7191326368f3d2b3036443c53efc5cfb27332fb..1f76e604e70e69aae6bb4bd78662ffff1c38162c 100644 --- a/src/operators/kernel/arm/quantize_kernel.cpp +++ b/src/operators/kernel/arm/quantize_kernel.cpp @@ -34,14 +34,66 @@ inline float32_t vmaxvq_f32(float32x4_t r) { #endif template -static void Quantize(const Tensor *input, const float scale, Tensor *output) { +inline void QuantizeOffline(const Tensor *input, const float scale, + const float max_abs, Tensor *output) { const float *x = input->data(); int8_t *y = output->mutable_data(); size_t remain = input->numel(); #if defined(__ARM_NEON__) || defined(__ARM_NEON) size_t loop = remain >> 4; remain = remain & 0xF; + float32x4_t __scale = vdupq_n_f32(scale); + float32x4_t __postive_max = vdupq_n_f32(max_abs); + float32x4_t __negtive_max = vdupq_n_f32(-max_abs); + #pragma omp parallel for + for (size_t i = 0; i < loop; ++i) { + const float *local_x = x + (i << 4); + int8_t *local_y = y + (i << 4); + float32x4_t r0 = vld1q_f32(local_x); + float32x4_t r1 = vld1q_f32(local_x + 4); + float32x4_t r2 = vld1q_f32(local_x + 8); + float32x4_t r3 = vld1q_f32(local_x + 12); + r0 = vmaxq_f32(vminq_f32(r0, __postive_max), __negtive_max); + r1 = vmaxq_f32(vminq_f32(r1, __postive_max), __negtive_max); + r2 = vmaxq_f32(vminq_f32(r2, __postive_max), __negtive_max); + r3 = vmaxq_f32(vminq_f32(r3, __postive_max), __negtive_max); + r0 = vmulq_f32(r0, __scale); + r1 = vmulq_f32(r1, __scale); + r2 = vmulq_f32(r2, __scale); + r3 = vmulq_f32(r3, __scale); + int32x4_t q0 = math::vRoundq_f32(r0); + int32x4_t q1 = math::vRoundq_f32(r1); + int32x4_t q2 = math::vRoundq_f32(r2); + int32x4_t q3 = math::vRoundq_f32(r3); + int16x4_t d0 = vmovn_s32(q0); + int16x4_t d1 = vmovn_s32(q1); + int16x4_t d2 = vmovn_s32(q2); + int16x4_t d3 = vmovn_s32(q3); + int16x8_t q5 = vcombine_s16(d0, d1); + int16x8_t q6 = vcombine_s16(d2, d3); + int8x8_t d5 = vmovn_s16(q5); + int8x8_t d6 = vmovn_s16(q6); + vst1_s8(local_y, d5); + vst1_s8(local_y + 8, d6); + } + x += (loop << 4); + y += (loop << 4); +#endif + for (size_t i = 0; i < remain; ++i) { + float x_temp = std::max(std::min(x[i], max_abs), -max_abs); + y[i] = math::Round(x_temp * scale); + } +} +template +inline void QuantizeOnline(const Tensor *input, const float scale, + Tensor *output) { + const float *x = input->data(); + int8_t *y = output->mutable_data(); + size_t remain = input->numel(); +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + size_t loop = remain >> 4; + remain = remain & 0xF; float32x4_t __scale = vdupq_n_f32(scale); #pragma omp parallel for for (size_t i = 0; i < loop; ++i) { @@ -78,6 +130,17 @@ static void Quantize(const Tensor *input, const float scale, Tensor *output) { } } +template +static void Quantize(const Tensor *input, const float max_abs, + const bool offline, Tensor *output) { + float scale = 127.f / max_abs; + if (offline) { + QuantizeOffline(input, scale, max_abs, output); + } else { + QuantizeOnline(input, scale, output); + } +} + float find_abs_max(const Tensor *input) { float max_abs = 0.f; const float *x = input->data(); @@ -133,23 +196,22 @@ void QuantizeKernel::Compute(const QuantizeParam ¶m) { max_abs = find_abs_max(input); } max_abs = std::max(max_abs, 1e-6f); - // only support int8 currently - float scale = 127 / max_abs; param.online_scale_->mutable_data()[0] = max_abs; - switch (param.round_type_) { - case ROUND_NEAREST_TO_EVEN: - Quantize(input, scale, output); - break; - case ROUND_NEAREST_TOWARDS_ZERO: - Quantize(input, scale, output); - break; - case ROUND_NEAREST_AWAY_ZERO: - Quantize(input, scale, output); - break; - default: - LOG(kLOG_ERROR) << "round type is not supported."; - break; - } + // switch (param.round_type_) { + // case ROUND_NEAREST_TO_EVEN: + // Quantize(input, scale, output); + // break; + // case ROUND_NEAREST_TOWARDS_ZERO: + // Quantize(input, scale, output); + // break; + // case ROUND_NEAREST_AWAY_ZERO: + // Quantize(input, scale, output); + // break; + // default: + // LOG(kLOG_ERROR) << "round type is not supported."; + // break; + // } + Quantize(input, max_abs, param.offline_, output); } } // namespace operators diff --git a/src/operators/kernel/arm/relu_kernel.cpp b/src/operators/kernel/arm/relu_kernel.cpp index 979a1d6e2a958aa0b99d47b5cd022d21a3873452..0333e9db4445aa68498671ed6472a2f8ff113e1c 100644 --- a/src/operators/kernel/arm/relu_kernel.cpp +++ b/src/operators/kernel/arm/relu_kernel.cpp @@ -74,7 +74,7 @@ template <> void ReluKernel::Compute(const ReluParam ¶m) { const Tensor *input = param.InputX(); Tensor *output = param.Out(); - ReluCompute()(input, output); + ReluCompute()(input, output); } template <> @@ -86,7 +86,7 @@ template <> void Relu6Kernel::Compute(const ReluParam ¶m) { const Tensor *input = param.InputX(); Tensor *output = param.Out(); - ReluCompute()(input, output); + ReluCompute()(input, output); } } // namespace operators diff --git a/src/operators/kernel/central-arm-func/pool_arm_func.h b/src/operators/kernel/central-arm-func/pool_arm_func.h index 529798dd809c5b12d0e7621a83f5b25bfd0afc79..757d64480fa2fba46ba599a7f5cf9aaddfa5567a 100644 --- a/src/operators/kernel/central-arm-func/pool_arm_func.h +++ b/src/operators/kernel/central-arm-func/pool_arm_func.h @@ -40,28 +40,28 @@ void PoolCompute(const PoolParam ¶m) { if (ksize[0] == 3 && ksize[0] == ksize[1]) { if (pooling_type == "max" && strides[0] == strides[1]) { if (strides[0] == 1) { - math::Pooling3x3()(*input, paddings, output); + math::Pooling3x3()(*input, paddings, output); } else if (strides[0] == 2) { - math::Pooling3x3()(*input, paddings, output); + math::Pooling3x3()(*input, paddings, output); } else { - math::Pooling()(*input, ksize, strides, paddings, output); + math::Pooling()(*input, ksize, strides, paddings, output); } } else if (pooling_type == "avg" && strides[0] == strides[1]) { if (strides[0] == 1) { - math::Pooling3x3()(*input, paddings, output); + math::Pooling3x3()(*input, paddings, output); } else if (strides[0] == 2) { - math::Pooling3x3()(*input, paddings, output); + math::Pooling3x3()(*input, paddings, output); } else { - math::Pooling()(*input, ksize, strides, paddings, output); + math::Pooling()(*input, ksize, strides, paddings, output); } } else { // Others } } else { if (pooling_type == "max") { - math::Pooling()(*input, ksize, strides, paddings, output); + math::Pooling()(*input, ksize, strides, paddings, output); } else if (pooling_type == "avg") { - math::Pooling()(*input, ksize, strides, paddings, output); + math::Pooling()(*input, ksize, strides, paddings, output); } else { // Others } diff --git a/src/operators/math/activation.h b/src/operators/math/activation.h index 1274f0fd8a8bd3732af29d3ba260699a85dff173..51ce3789785a4e4211298e44312a87ec2de09ea0 100644 --- a/src/operators/math/activation.h +++ b/src/operators/math/activation.h @@ -16,50 +16,109 @@ limitations under the License. */ #include #include +#include +#include "common/enforce.h" #include "common/types.h" #if defined(__ARM_NEON__) || defined(__ARM_NEON) #include +#include "operators/math/math_func_neon.h" #endif namespace paddle_mobile { namespace operators { namespace math { +#define SIGMOID_THRESHOLD_MIN -40.0 +#define SIGMOID_THRESHOLD_MAX 13.0 +#define EXP_MAX_INPUT 40.0 + +inline ActivationType GetActivationType(const std::string &type) { + if (type == "sigmoid") { + return ActivationType::SIGMOID; + } else if (type == "relu") { + return ActivationType::RELU; + } else if (type == "tanh") { + return ActivationType::TANH; + } else if (type == "identity" || type == "") { + return ActivationType::IDENTITY; + } + PADDLE_MOBILE_THROW_EXCEPTION("Not support activation type."); +} + #if defined(__ARM_NEON__) || defined(__ARM_NEON) -template +template inline float32x4_t vActiveq_f32(const float32x4_t &x) { return x; } template <> -inline float32x4_t vActiveq_f32(const float32x4_t &x) { +inline float32x4_t vActiveq_f32(const float32x4_t &x) { float32x4_t __zero = vdupq_n_f32(0.f); return vmaxq_f32(x, __zero); } template <> -inline float32x4_t vActiveq_f32(const float32x4_t &x) { +inline float32x4_t vActiveq_f32(const float32x4_t &x) { float32x4_t __zero = vdupq_n_f32(0.f); float32x4_t __six = vdupq_n_f32(6.f); return vminq_f32(vmaxq_f32(x, __zero), __six); } + +template <> +inline float32x4_t vActiveq_f32(const float32x4_t &x) { + float32x4_t __one = vdupq_n_f32(1.f); + float32x4_t __x = vnegq_f32(x); + __x = exp_ps(__x); + __x = vaddq_f32(__x, __one); + float32x4_t __out = vrecpeq_f32(__x); + return vmulq_f32(vrecpsq_f32(__x, __out), __out); +} + +template <> +inline float32x4_t vActiveq_f32(const float32x4_t &x) { + float32x4_t __one = vdupq_n_f32(1.f); + float32x4_t __x = vnegq_f32(x); + __x = vmulq_n_f32(__x, 2.f); + __x = exp_ps(__x); + __x = vaddq_f32(__x, __one); + float32x4_t __out = vrecpeq_f32(__x); + __out = vmulq_f32(vrecpsq_f32(__x, __out), __out); + __out = vmulq_n_f32(__out, 2.f); + return vsubq_f32(__out, __one); +} #endif -template +template inline float Active(const float &x) { return x; } template <> -inline float Active(const float &x) { +inline float Active(const float &x) { return std::max(x, 0.f); } template <> -inline float Active(const float &x) { +inline float Active(const float &x) { return std::min(std::max(x, 0.f), 6.f); } +template <> +inline float Active(const float &x) { + // float tmp = x > SIGMOID_THRESHOLD_MAX ? SIGMOID_THRESHOLD_MAX : x; + // tmp = x > SIGMOID_THRESHOLD_MIN ? x : SIGMOID_THRESHOLD_MIN; + // return 1.f / (1.f + exp(-tmp)); + return 1.f / (1.f + exp(-x)); +} + +template <> +inline float Active(const float &x) { + // float tmp = -2.f * x; + // tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; + // return (2.f / (1.f + exp(tmp))) - 1.f; + return 2.f / (1.f + exp(-2.f * x)) - 1.f; +} + } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/activation_functions.h b/src/operators/math/activation_functions.h deleted file mode 100644 index 8604065a2570cc17c970c487fcaa898f78c72a85..0000000000000000000000000000000000000000 --- a/src/operators/math/activation_functions.h +++ /dev/null @@ -1,92 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include "common/enforce.h" -namespace paddle_mobile { -namespace operators { -namespace math { - -#define SIGMOID_THRESHOLD_MIN -40.0 -#define SIGMOID_THRESHOLD_MAX 13.0 -#define EXP_MAX_INPUT 40.0 - -enum ActivationType { - kSigmoid, - kReLU, - kTanh, - kIdentity, -}; - -inline ActivationType GetActivationType(const std::string &type) { - if (type == "sigmoid") { - return ActivationType::kSigmoid; - } else if (type == "relu") { - return ActivationType::kReLU; - } else if (type == "tanh") { - return ActivationType::kTanh; - } else if (type == "identity" || type == "") { - return ActivationType::kIdentity; - } - PADDLE_MOBILE_THROW_EXCEPTION("Not support activation type."); -} - -namespace forward { - -template -T Identity(const T a) { - return a; -} - -template -T Relu(const T a) { - return a > static_cast(0.0) ? a : static_cast(0.0); -} - -template -T Sigmoid(const T a) { - const T min = SIGMOID_THRESHOLD_MIN; - const T max = SIGMOID_THRESHOLD_MAX; - T tmp = (a < min) ? min : ((a > max) ? max : a); - return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); -} - -template -T Tanh(const T a) { - T tmp = -2.0 * a; - tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; - return (2.0 / (1.0 + exp(tmp))) - 1.0; -} - -} // namespace forward - -template -struct Active { - typedef T (*Act)(T); -}; - -static Active::Act kActFloat[] = { - &forward::Sigmoid, &forward::Relu, &forward::Tanh, - &forward::Identity}; - -namespace forward { -inline float activation(float a, int index) { return kActFloat[index](a); } - -} // namespace forward - -} // namespace math -} // namespace operators -} // namespace paddle_mobile diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index c17b2a5e4df0f0ca88da79a9ce55c2ecae0316b5..69c09e255b75f51430a2c6f477bb4e439608290e 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -1260,10 +1260,10 @@ void Gemm::AddDot4x4(int k, const float *a, const float *b, float *c, int ldc) { "q10", "q11", "q12", "q13"); } -/* -void Gemm::VectorKernel(int m, int n, int k, float alpha, const float *A, int -lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu) { float -*bufferC = static_cast(memory::Alloc(sizeof(float) * n)); +void Gemm::VectorKernel(int m, int n, int k, float alpha, const float *A, + int lda, const float *B, int ldb, float beta, float *C, + int ldc, bool relu) { + float *bufferC = static_cast(memory::Alloc(sizeof(float) * n)); const float *a0, *b0, *b1, *b2, *b3; float *c0, *C0; @@ -1482,6 +1482,7 @@ lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu) { float } } +/* void Gemm::VectorKernelWithBn(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu, float *new_scale, float *new_bias) { @@ -2579,278 +2580,278 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, } } - /* - // C = A * B - void Gemm::VecWriteBasic(int n, float *c, float *C, int ldc) { - int nc1 = n / 16; - int _nc1 = n % 16; - int nc2 = _nc1 / 4; - int nc3 = 16 - 4 * (_nc1 % 4); +// C = A * B +void Gemm::VecWriteBasic(int n, float *c, float *C, int ldc) { + int nc1 = n / 16; + int _nc1 = n % 16; + int nc2 = _nc1 / 4; + int nc3 = 16 - 4 * (_nc1 % 4); - asm volatile( - "subs %[nc1], %[nc1], #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" + asm volatile( + "subs %[nc1], %[nc1], #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" - "vld1.32 {q0, q1}, [%[c]]! \n\t" - "vst1.32 {q0, q1}, [%[C]]! \n\t" + "vld1.32 {q0, q1}, [%[c]]! \n\t" + "vst1.32 {q0, q1}, [%[C]]! \n\t" - "vld1.32 {q2, q3}, [%[c]]! \n\t" - "vst1.32 {q2, q3}, [%[C]]! \n\t" + "vld1.32 {q2, q3}, [%[c]]! \n\t" + "vst1.32 {q2, q3}, [%[C]]! \n\t" - "subs %[nc1], %[nc1], #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" - "subs %[nc2], %[nc2], #1 \n\t" - "blt end_nc2_%= \n\t" - "loop_nc2_%=: \n\t" + "subs %[nc2], %[nc2], #1 \n\t" + "blt end_nc2_%= \n\t" + "loop_nc2_%=: \n\t" - "vld1.32 {q4}, [%[c]]! \n\t" - "vst1.32 {q4}, [%[C]]! \n\t" + "vld1.32 {q4}, [%[c]]! \n\t" + "vst1.32 {q4}, [%[C]]! \n\t" - "subs %[nc2], %[nc2], #1 \n\t" - "bge loop_nc2_%= \n\t" - "end_nc2_%=: \n\t" + "subs %[nc2], %[nc2], #1 \n\t" + "bge loop_nc2_%= \n\t" + "end_nc2_%=: \n\t" - "cmp %[nc3], #16 \n\t" - "beq end_nc3_%= \n\t" - "sub %[c], %[c], %[nc3] \n\t" - "sub %[C], %[C], %[nc3] \n\t" - "vld1.32 {q5}, [%[c]]! \n\t" - "vst1.32 {q5}, [%[C]]! \n\t" - "end_nc3_%=: \n\t" + "cmp %[nc3], #16 \n\t" + "beq end_nc3_%= \n\t" + "sub %[c], %[c], %[nc3] \n\t" + "sub %[C], %[C], %[nc3] \n\t" + "vld1.32 {q5}, [%[c]]! \n\t" + "vst1.32 {q5}, [%[C]]! \n\t" + "end_nc3_%=: \n\t" - : - : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3) - : "memory", "q0", "q1", "q2", "q3", "q4", "q5"); - } + : + : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5"); +} - // C = alpha * A * B + beta * C - void Gemm::VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc) {} +// C = alpha * A * B + beta * C +void Gemm::VecWriteWithAlphaBeta(int n, float *c, float *C, int ldc) {} - // C = A * B + C - void Gemm::VecWriteWithAdd(int n, float *c, float *C, int ldc) { - int nc1 = n / 16; - int _nc1 = n % 16; +// C = A * B + C +void Gemm::VecWriteWithAdd(int n, float *c, float *C, int ldc) { + int nc1 = n / 16; + int _nc1 = n % 16; - asm volatile( - "subs %[nc1], %[nc1], #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" + asm volatile( + "subs %[nc1], %[nc1], #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" - "vld1.32 {q0, q1}, [%[c]]! \n\t" - "vld1.32 {q2, q3}, [%[C]] \n\t" - "vadd.f32 q10, q0, q2 \n\t" - "vadd.f32 q11, q1, q3 \n\t" - "vst1.32 {q10, q11}, [%[C]]! \n\t" + "vld1.32 {q0, q1}, [%[c]]! \n\t" + "vld1.32 {q2, q3}, [%[C]] \n\t" + "vadd.f32 q10, q0, q2 \n\t" + "vadd.f32 q11, q1, q3 \n\t" + "vst1.32 {q10, q11}, [%[C]]! \n\t" - "vld1.32 {q4, q5}, [%[c]]! \n\t" - "vld1.32 {q6, q7}, [%[C]] \n\t" - "vadd.f32 q12, q4, q6 \n\t" - "vadd.f32 q13, q5, q7 \n\t" - "vst1.32 {q12, q13}, [%[C]]! \n\t" + "vld1.32 {q4, q5}, [%[c]]! \n\t" + "vld1.32 {q6, q7}, [%[C]] \n\t" + "vadd.f32 q12, q4, q6 \n\t" + "vadd.f32 q13, q5, q7 \n\t" + "vst1.32 {q12, q13}, [%[C]]! \n\t" - "subs %[nc1], %[nc1], #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" - : [C] "+r"(C), [c] "+r"(c) - : [nc1] "r"(nc1) - : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", - "q11", "q12", "q13"); + : [C] "+r"(C), [c] "+r"(c) + : [nc1] "r"(nc1) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11", + "q12", "q13"); - if (_nc1 != 0) { - for (int j = 0; j < _nc1; j++) { - *C++ += *c++; - } + if (_nc1 != 0) { + for (int j = 0; j < _nc1; j++) { + *C++ += *c++; } } +} - // C = A * B + C, relu(C) - void Gemm::VecWriteWithAddRelu(int n, float *c, float *C, int ldc) { - int nc1 = n / 16; - int _nc1 = n % 16; +// C = A * B + C, relu(C) +void Gemm::VecWriteWithAddRelu(int n, float *c, float *C, int ldc) { + int nc1 = n / 16; + int _nc1 = n % 16; - asm volatile( - "vmov.f32 q14, #0.0 \n\t" - "subs %[nc1], %[nc1], #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" + asm volatile( + "vmov.f32 q14, #0.0 \n\t" + "subs %[nc1], %[nc1], #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" - "vld1.32 {q0, q1}, [%[c]]! \n\t" - "vld1.32 {q2, q3}, [%[C]] \n\t" - "vadd.f32 q10, q0, q2 \n\t" - "vadd.f32 q11, q1, q3 \n\t" - "vmax.f32 q10, q10, q14 \n\t" - "vmax.f32 q11, q11, q14 \n\t" - "vst1.32 {q10, q11}, [%[C]]! \n\t" + "vld1.32 {q0, q1}, [%[c]]! \n\t" + "vld1.32 {q2, q3}, [%[C]] \n\t" + "vadd.f32 q10, q0, q2 \n\t" + "vadd.f32 q11, q1, q3 \n\t" + "vmax.f32 q10, q10, q14 \n\t" + "vmax.f32 q11, q11, q14 \n\t" + "vst1.32 {q10, q11}, [%[C]]! \n\t" - "vld1.32 {q4, q5}, [%[c]]! \n\t" - "vld1.32 {q6, q7}, [%[C]] \n\t" - "vadd.f32 q12, q4, q6 \n\t" - "vadd.f32 q13, q5, q7 \n\t" - "vmax.f32 q12, q12, q14 \n\t" - "vmax.f32 q13, q13, q14 \n\t" - "vst1.32 {q12, q13}, [%[C]]! \n\t" + "vld1.32 {q4, q5}, [%[c]]! \n\t" + "vld1.32 {q6, q7}, [%[C]] \n\t" + "vadd.f32 q12, q4, q6 \n\t" + "vadd.f32 q13, q5, q7 \n\t" + "vmax.f32 q12, q12, q14 \n\t" + "vmax.f32 q13, q13, q14 \n\t" + "vst1.32 {q12, q13}, [%[C]]! \n\t" - "subs %[nc1], %[nc1], #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" - : [C] "+r"(C), [c] "+r"(c) - : [nc1] "r"(nc1) - : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", - "q11", "q12", "q13"); + : [C] "+r"(C), [c] "+r"(c) + : [nc1] "r"(nc1) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q10", "q11", + "q12", "q13"); - if (_nc1 != 0) { - for (int j = 0; j < _nc1; j++) { - *C += *c; - if (*C < 0) { - *C = 0; - } - C++; - c++; + if (_nc1 != 0) { + for (int j = 0; j < _nc1; j++) { + *C += *c; + if (*C < 0) { + *C = 0; } + C++; + c++; } } +} - // C = A * B, batchnorm(C) - void Gemm::VecWriteWithBn(int n, float *c, float *C, int ldc, float *scale, - float *bias) { - int nc1 = n / 16; - int _nc1 = n % 16; - int nc2 = _nc1 / 4; - int nc3 = 16 - 4 * (_nc1 % 4); - - asm volatile( - "subs %[nc1], %[nc1], #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" - - "vld1.32 {q0, q1}, [%[c]]! \n\t" - "vld1.32 {q2, q3}, [%[scale]]! \n\t" - "vld1.32 {q10, q11}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q2 \n\t" - "vmla.f32 q11, q1, q3 \n\t" - "vst1.32 {q10, q11}, [%[C]]! \n\t" - - "vld1.32 {q4, q5}, [%[c]]! \n\t" - "vld1.32 {q6, q7}, [%[scale]]! \n\t" - "vld1.32 {q12, q13}, [%[bias]]! \n\t" - "vmla.f32 q12, q4, q6 \n\t" - "vmla.f32 q13, q5, q7 \n\t" - "vst1.32 {q12, q13}, [%[C]]! \n\t" - - "subs %[nc1], %[nc1], #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" - - "subs %[nc2], %[nc2], #1 \n\t" - "blt end_nc2_%= \n\t" - "loop_nc2_%=: \n\t" - - "vld1.32 {q0}, [%[c]]! \n\t" - "vld1.32 {q1}, [%[scale]]! \n\t" - "vld1.32 {q10}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q1 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" - - "subs %[nc2], %[nc2], #1 \n\t" - "bge loop_nc2_%= \n\t" - "end_nc2_%=: \n\t" - - "cmp %[nc3], #16 \n\t" - "beq end_nc3_%= \n\t" - - "sub %[c], %[c], %[nc3] \n\t" - "sub %[scale], %[scale], %[nc3] \n\t" - "sub %[bias], %[bias], %[nc3] \n\t" - "sub %[C], %[C], %[nc3] \n\t" - - "vld1.32 {q0}, [%[c]]! \n\t" - "vld1.32 {q1}, [%[scale]]! \n\t" - "vld1.32 {q10}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q1 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" - "end_nc3_%=: \n\t" - - : - : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] - "r"(nc3), [scale] "r"(scale), [bias] "r"(bias) : "memory", "q0", "q1", "q2", - "q3", "q4", "q5", "q6", "q7", "q10", "q11", "q12", "q13"); - } - - // C = A * B, batchnorm(C), relu(C) - void Gemm::VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float - *scale, float *bias) { int nc1 = n / 16; int _nc1 = n % 16; int nc2 = _nc1 / - 4; int nc3 = 16 - 4 * (_nc1 % 4); - - asm volatile( - "vmov.f32 q14, #0.0 \n\t" - "subs %[nc1], %[nc1], #1 \n\t" - "blt end_nc1_%= \n\t" - "loop_nc1_%=: \n\t" - - "vld1.32 {q0, q1}, [%[c]]! \n\t" - "vld1.32 {q2, q3}, [%[scale]]! \n\t" - "vld1.32 {q10, q11}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q2 \n\t" - "vmla.f32 q11, q1, q3 \n\t" - "vmax.f32 q10, q10, q14 \n\t" - "vmax.f32 q11, q11, q14 \n\t" - "vst1.32 {q10, q11}, [%[C]]! \n\t" - - "vld1.32 {q4, q5}, [%[c]]! \n\t" - "vld1.32 {q6, q7}, [%[scale]]! \n\t" - "vld1.32 {q12, q13}, [%[bias]]! \n\t" - "vmla.f32 q12, q4, q6 \n\t" - "vmla.f32 q13, q5, q7 \n\t" - "vmax.f32 q12, q12, q14 \n\t" - "vmax.f32 q13, q13, q14 \n\t" - "vst1.32 {q12, q13}, [%[C]]! \n\t" - - "subs %[nc1], %[nc1], #1 \n\t" - "bge loop_nc1_%= \n\t" - "end_nc1_%=: \n\t" - - "subs %[nc2], %[nc2], #1 \n\t" - "blt end_nc2_%= \n\t" - "loop_nc2_%=: \n\t" - - "vld1.32 {q0}, [%[c]]! \n\t" - "vld1.32 {q1}, [%[scale]]! \n\t" - "vld1.32 {q10}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q1 \n\t" - "vmax.f32 q10, q10, q14 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" - - "subs %[nc2], %[nc2], #1 \n\t" - "bge loop_nc2_%= \n\t" - "end_nc2_%=: \n\t" - - "cmp %[nc3], #16 \n\t" - "beq end_nc3_%= \n\t" - - "sub %[c], %[c], %[nc3] \n\t" - "sub %[scale], %[scale], %[nc3] \n\t" - "sub %[bias], %[bias], %[nc3] \n\t" - "sub %[C], %[C], %[nc3] \n\t" + /* + // C = A * B, batchnorm(C) + void Gemm::VecWriteWithBn(int n, float *c, float *C, int ldc, float *scale, + float *bias) { + int nc1 = n / 16; + int _nc1 = n % 16; + int nc2 = _nc1 / 4; + int nc3 = 16 - 4 * (_nc1 % 4); - "vld1.32 {q0}, [%[c]]! \n\t" - "vld1.32 {q1}, [%[scale]]! \n\t" - "vld1.32 {q10}, [%[bias]]! \n\t" - "vmla.f32 q10, q0, q1 \n\t" - "vmax.f32 q10, q10, q14 \n\t" - "vst1.32 {q10}, [%[C]]! \n\t" - "end_nc3_%=: \n\t" + asm volatile( + "subs %[nc1], %[nc1], #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" + + "vld1.32 {q0, q1}, [%[c]]! \n\t" + "vld1.32 {q2, q3}, [%[scale]]! \n\t" + "vld1.32 {q10, q11}, [%[bias]]! \n\t" + "vmla.f32 q10, q0, q2 \n\t" + "vmla.f32 q11, q1, q3 \n\t" + "vst1.32 {q10, q11}, [%[C]]! \n\t" + + "vld1.32 {q4, q5}, [%[c]]! \n\t" + "vld1.32 {q6, q7}, [%[scale]]! \n\t" + "vld1.32 {q12, q13}, [%[bias]]! \n\t" + "vmla.f32 q12, q4, q6 \n\t" + "vmla.f32 q13, q5, q7 \n\t" + "vst1.32 {q12, q13}, [%[C]]! \n\t" + + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + "subs %[nc2], %[nc2], #1 \n\t" + "blt end_nc2_%= \n\t" + "loop_nc2_%=: \n\t" + + "vld1.32 {q0}, [%[c]]! \n\t" + "vld1.32 {q1}, [%[scale]]! \n\t" + "vld1.32 {q10}, [%[bias]]! \n\t" + "vmla.f32 q10, q0, q1 \n\t" + "vst1.32 {q10}, [%[C]]! \n\t" + + "subs %[nc2], %[nc2], #1 \n\t" + "bge loop_nc2_%= \n\t" + "end_nc2_%=: \n\t" + + "cmp %[nc3], #16 \n\t" + "beq end_nc3_%= \n\t" + + "sub %[c], %[c], %[nc3] \n\t" + "sub %[scale], %[scale], %[nc3] \n\t" + "sub %[bias], %[bias], %[nc3] \n\t" + "sub %[C], %[C], %[nc3] \n\t" + + "vld1.32 {q0}, [%[c]]! \n\t" + "vld1.32 {q1}, [%[scale]]! \n\t" + "vld1.32 {q10}, [%[bias]]! \n\t" + "vmla.f32 q10, q0, q1 \n\t" + "vst1.32 {q10}, [%[C]]! \n\t" + "end_nc3_%=: \n\t" + + : + : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] + "r"(nc3), [scale] "r"(scale), [bias] "r"(bias) : "memory", "q0", "q1", "q2", + "q3", "q4", "q5", "q6", "q7", "q10", "q11", "q12", "q13"); + } + + // C = A * B, batchnorm(C), relu(C) + void Gemm::VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float + *scale, float *bias) { int nc1 = n / 16; int _nc1 = n % 16; int nc2 = _nc1 / + 4; int nc3 = 16 - 4 * (_nc1 % 4); - : - : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] - "r"(nc3), [scale] "r"(scale), [bias] "r"(bias) : "memory", "q0", "q1", "q2", - "q3", "q4", "q5", "q6", "q7", "q10", "q11", "q12", "q13", "q14"); - } - */ + asm volatile( + "vmov.f32 q14, #0.0 \n\t" + "subs %[nc1], %[nc1], #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" + + "vld1.32 {q0, q1}, [%[c]]! \n\t" + "vld1.32 {q2, q3}, [%[scale]]! \n\t" + "vld1.32 {q10, q11}, [%[bias]]! \n\t" + "vmla.f32 q10, q0, q2 \n\t" + "vmla.f32 q11, q1, q3 \n\t" + "vmax.f32 q10, q10, q14 \n\t" + "vmax.f32 q11, q11, q14 \n\t" + "vst1.32 {q10, q11}, [%[C]]! \n\t" + + "vld1.32 {q4, q5}, [%[c]]! \n\t" + "vld1.32 {q6, q7}, [%[scale]]! \n\t" + "vld1.32 {q12, q13}, [%[bias]]! \n\t" + "vmla.f32 q12, q4, q6 \n\t" + "vmla.f32 q13, q5, q7 \n\t" + "vmax.f32 q12, q12, q14 \n\t" + "vmax.f32 q13, q13, q14 \n\t" + "vst1.32 {q12, q13}, [%[C]]! \n\t" + + "subs %[nc1], %[nc1], #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + "subs %[nc2], %[nc2], #1 \n\t" + "blt end_nc2_%= \n\t" + "loop_nc2_%=: \n\t" + + "vld1.32 {q0}, [%[c]]! \n\t" + "vld1.32 {q1}, [%[scale]]! \n\t" + "vld1.32 {q10}, [%[bias]]! \n\t" + "vmla.f32 q10, q0, q1 \n\t" + "vmax.f32 q10, q10, q14 \n\t" + "vst1.32 {q10}, [%[C]]! \n\t" + + "subs %[nc2], %[nc2], #1 \n\t" + "bge loop_nc2_%= \n\t" + "end_nc2_%=: \n\t" + + "cmp %[nc3], #16 \n\t" + "beq end_nc3_%= \n\t" + + "sub %[c], %[c], %[nc3] \n\t" + "sub %[scale], %[scale], %[nc3] \n\t" + "sub %[bias], %[bias], %[nc3] \n\t" + "sub %[C], %[C], %[nc3] \n\t" + + "vld1.32 {q0}, [%[c]]! \n\t" + "vld1.32 {q1}, [%[scale]]! \n\t" + "vld1.32 {q10}, [%[bias]]! \n\t" + "vmla.f32 q10, q0, q1 \n\t" + "vmax.f32 q10, q10, q14 \n\t" + "vst1.32 {q10}, [%[C]]! \n\t" + "end_nc3_%=: \n\t" + + : + : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] + "r"(nc3), [scale] "r"(scale), [bias] "r"(bias) : "memory", "q0", "q1", "q2", + "q3", "q4", "q5", "q6", "q7", "q10", "q11", "q12", "q13", "q14"); + } + */ #endif // __aarch64__ #else @@ -3149,13 +3150,17 @@ void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda, void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu, float *bias) { + if (m == 1 && bias == nullptr) { + return VectorKernel(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, relu); + } #ifdef _OPENMP int max_threads = omp_get_max_threads(); #else int max_threads = 1; #endif - int L1 = 64 / max_threads * 1024; + // int L1 = 64 / max_threads * 1024; + int L1 = 32 / max_threads * 1024; KC = k; zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); memset(static_cast(zero), 0, sizeof(float) * KC); diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index fb2c248c9b38b6ef62fe477930cf83060b95ee1d..99c68de7c34cf6300625cbc1cee3c2b388ffb517 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -105,16 +105,15 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, float *c, float *C, int ldc, float *p, std::string mode, float *bias, float *bias1); - /* // 向量矩阵乘法 (M = 1) void VectorKernel(int m, int n, int k, float alpha, const float *A, int lda, const float *B, int ldb, float beta, float *C, int ldc, bool relu); - - void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A, - int lda, const float *B, int ldb, float beta, float - *C, int ldc, bool relu, float *new_scale, float *new_bias); - */ + /* + void VectorKernelWithBn(int m, int n, int k, float alpha, const float *A, + int lda, const float *B, int ldb, float beta, float + *C, int ldc, bool relu, float *new_scale, float *new_bias); + */ // 计算一个更小的 C 矩阵分块 void AddDot4x4(int k, const float *a, const float *b, float *c, int ldc); @@ -149,7 +148,6 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, void WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc, float *new_scale, float *new_bias, float *bias1); - /* // 向量矩阵乘法结果回写 // C = A * B void VecWriteBasic(int n, float *c, float *C, int ldc); @@ -159,13 +157,14 @@ void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb, void VecWriteWithAdd(int n, float *c, float *C, int ldc); // C = A * B + C, relu(C) void VecWriteWithAddRelu(int n, float *c, float *C, int ldc); - // C = A * B, batchnorm(C) - void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale, - float *new_bias); - // C = A * B, batchnorm(C), relu(C) - void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float *new_scale, - float *new_bias); - */ + /* + // C = A * B, batchnorm(C) + void VecWriteWithBn(int n, float *c, float *C, int ldc, float *new_scale, + float *new_bias); + // C = A * B, batchnorm(C), relu(C) + void VecWriteWithBnRelu(int n, float *c, float *C, int ldc, float + *new_scale, float *new_bias); + */ // 32位 float 矩阵乘法 void Sgemm(int m, int n, int k, float alpha, const float *A, int lda, @@ -392,7 +391,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, packedB_int8 = static_cast( paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC)); #if __aarch64__ - // TODO() + // TODO(paddle mobile) #else PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8); #endif @@ -414,7 +413,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, packedA_int8 = static_cast( paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC)); #if __aarch64__ - // TODO() + // TODO(paddle mobile) #else PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8); #endif @@ -438,7 +437,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, int8_t *local_A = packedA_int8 + MC * KC * local_threads; int32_t *local_C = packedC_int32 + MC * NC * local_threads; #if __aarch64__ - // TODO() + // TODO(paddle mobile) #else PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A); #endif @@ -468,7 +467,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, int8_t *local_B = packedB_int8 + KC * NC * local_threads; int32_t *local_C = packedC_int32 + MC * NC * local_threads; #if __aarch64__ - // TODO() + // TODO(paddle mobile) #else PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B); #endif diff --git a/src/operators/math/gru_compute.cpp b/src/operators/math/gru_compute.cpp index bbf1b01a21a980293f3cfe255885e7127aeb208e..19c7a2685c347340a0d3bd10b1c5828bfd437d4f 100644 --- a/src/operators/math/gru_compute.cpp +++ b/src/operators/math/gru_compute.cpp @@ -11,13 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #ifdef GRU_OP + #include "operators/math/gru_compute.h" #include "common/types.h" -#include "operators/math/activation_functions.h" +#include "operators/math/activation.h" #include "operators/math/gemm.h" #include "operators/math/gru_cpu_kernel.h" -#include "operators/math/gru_kernel.h" namespace paddle_mobile { namespace operators { @@ -43,8 +44,7 @@ struct GRUUnitFunctor { #endif } - forward_reset_output(forward::gru_resetOutput(), value, frame_size, - batch_size, active_gate); + forward_reset_output(value, frame_size, batch_size, active_gate); if (value.prev_out_value) { #ifdef _OPENMP @@ -60,8 +60,7 @@ struct GRUUnitFunctor { #endif } - forward_final_output(forward::gru_finalOutput(), value, frame_size, - batch_size, active_node); + forward_final_output(value, frame_size, batch_size, active_node); } }; diff --git a/src/operators/math/gru_compute.h b/src/operators/math/gru_compute.h index 89cac1b8e49cd11eec551ba60f54e72f3912c846..00f4da90222f4f9d492a8214ee37828aac7aaf2d 100644 --- a/src/operators/math/gru_compute.h +++ b/src/operators/math/gru_compute.h @@ -11,7 +11,7 @@ limitations under the License. */ #ifdef GRU_OP #pragma once -#include "operators/math/activation_functions.h" +#include "operators/math/activation.h" namespace paddle_mobile { namespace operators { diff --git a/src/operators/math/gru_cpu_kernel.h b/src/operators/math/gru_cpu_kernel.h index ea24c4f1d97ebfbc5454e118121a3c79f28008c6..a010fb616b2c222e1ab9c7bfb248aad35d9b0e97 100644 --- a/src/operators/math/gru_cpu_kernel.h +++ b/src/operators/math/gru_cpu_kernel.h @@ -11,21 +11,22 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #ifdef GRU_OP + #pragma once + #include -#include "operators/math/activation_functions.h" +#include "operators/math/activation.h" #include "operators/math/gru_compute.h" namespace paddle_mobile { namespace operators { namespace math { -template -void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output, - T *gate_value, T *reset_output_value, - T *prev_output_value, int frame_size, - ActivationType active_gate) { +template +void hl_naive_gru_forward_reset_output(T *gate_value, T *reset_output_value, + T *prev_output_value, int frame_size) { T r_value_update_gate; T r_value_reset_gate; T r_value_reset_output; @@ -33,27 +34,57 @@ void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output, T *update_gate = gate_value; T *reset_gate = gate_value + frame_size; - for (int i = 0; i < frame_size; i++) { + int remain = frame_size; +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + int loop = remain >> 3; + remain = remain & 0x7; + float32x4_t prev0 = vdupq_n_f32(0.f); + float32x4_t prev1 = vdupq_n_f32(0.f); + for (int i = 0; i < loop; ++i) { + float32x4_t update0 = vld1q_f32(update_gate); + float32x4_t update1 = vld1q_f32(update_gate + 4); + float32x4_t reset0 = vld1q_f32(reset_gate); + float32x4_t reset1 = vld1q_f32(reset_gate + 4); + if (prev_output_value) { + prev0 = vld1q_f32(prev_output_value); + prev1 = vld1q_f32(prev_output_value + 4); + prev_output_value += 8; + } + update0 = vActiveq_f32(update0); + update1 = vActiveq_f32(update1); + reset0 = vActiveq_f32(reset0); + reset1 = vActiveq_f32(reset1); + float32x4_t output0 = vmulq_f32(prev0, reset0); + float32x4_t output1 = vmulq_f32(prev1, reset1); + vst1q_f32(update_gate, update0); + vst1q_f32(update_gate + 4, update1); + vst1q_f32(reset_gate, reset0); + vst1q_f32(reset_gate + 4, reset1); + vst1q_f32(reset_output_value, output0); + vst1q_f32(reset_output_value + 4, output1); + update_gate += 8; + reset_gate += 8; + reset_output_value += 8; + } +#endif // __ARM_NEON__ + for (int i = 0; i < remain; i++) { r_value_update_gate = update_gate[i]; r_value_reset_gate = reset_gate[i]; if (prev_output_value) { r_prev_out = prev_output_value[i]; } - - op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out, - &r_value_reset_output, active_gate); - + r_value_update_gate = Active(r_value_update_gate); + r_value_reset_gate = Active(r_value_reset_gate); + r_value_reset_output = r_prev_out * r_value_reset_gate; update_gate[i] = r_value_update_gate; reset_gate[i] = r_value_reset_gate; reset_output_value[i] = r_value_reset_output; } } -template -void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, - T *gate_value, T *prev_output_value, - T *output_value, int frame_size, - ActivationType active_node) { +template +void hl_naive_gru_forward_final_output(T *gate_value, T *prev_output_value, + T *output_value, int frame_size) { T r_value_update_gate; T r_value_frame_state; T r_prev_out = 0; @@ -61,30 +92,73 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, T *update_gate = gate_value; T *frame_state = gate_value + frame_size * 2; - for (int i = 0; i < frame_size; i++) { + int remain = frame_size; +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + int loop = remain >> 3; + remain = remain & 0x7; + float32x4_t prev0 = vdupq_n_f32(0.f); + float32x4_t prev1 = vdupq_n_f32(0.f); + for (int i = 0; i < loop; ++i) { + float32x4_t update0 = vld1q_f32(update_gate); + float32x4_t update1 = vld1q_f32(update_gate + 4); + float32x4_t state0 = vld1q_f32(frame_state); + float32x4_t state1 = vld1q_f32(frame_state + 4); + if (prev_output_value) { + prev0 = vld1q_f32(prev_output_value); + prev1 = vld1q_f32(prev_output_value + 4); + prev_output_value += 8; + } + state0 = vActiveq_f32(state0); + state1 = vActiveq_f32(state1); + float32x4_t output0 = vmlsq_f32(prev0, update0, prev0); + float32x4_t output1 = vmlsq_f32(prev1, update1, prev1); + output0 = vmlaq_f32(output0, update0, state0); + output1 = vmlaq_f32(output1, update1, state1); + vst1q_f32(frame_state, state0); + vst1q_f32(frame_state + 4, state1); + vst1q_f32(output_value, output0); + vst1q_f32(output_value + 4, output1); + update_gate += 8; + frame_state += 8; + output_value += 8; + } +#endif // __ARM_NEON__ + for (int i = 0; i < remain; i++) { r_value_update_gate = update_gate[i]; r_value_frame_state = frame_state[i]; if (prev_output_value) { r_prev_out = prev_output_value[i]; } - - op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out, - &r_output, active_node); - + r_value_frame_state = Active(r_value_frame_state); + r_output = r_prev_out - r_value_update_gate * r_prev_out + + r_value_update_gate * r_value_frame_state; frame_state[i] = r_value_frame_state; output_value[i] = r_output; } } -template -inline void forward_reset_output(OpResetOutput op_reset_output, - GRUMetaValue value, int frame_size, - int batch_size, ActivationType active_gate) { - for (int b = 0; b < batch_size; b++) { - hl_naive_gru_forward_reset_output( - op_reset_output, value.gate_value, value.reset_output_value, - value.prev_out_value, frame_size, active_gate); +#define FORWARD_RESET_OUTPUT(active_type, value, frame_size) \ + hl_naive_gru_forward_reset_output( \ + value.gate_value, value.reset_output_value, value.prev_out_value, \ + frame_size); +template +inline void forward_reset_output(GRUMetaValue value, int frame_size, + int batch_size, ActivationType active_node) { + for (int b = 0; b < batch_size; ++b) { + switch (active_node) { + case RELU: + FORWARD_RESET_OUTPUT(RELU, value, frame_size); + break; + case SIGMOID: + FORWARD_RESET_OUTPUT(SIGMOID, value, frame_size); + break; + case TANH: + FORWARD_RESET_OUTPUT(TANH, value, frame_size); + break; + default: + FORWARD_RESET_OUTPUT(IDENTITY, value, frame_size); + } value.gate_value += frame_size * 3; value.reset_output_value += frame_size; if (value.prev_out_value) { @@ -93,15 +167,27 @@ inline void forward_reset_output(OpResetOutput op_reset_output, } } -template -inline void forward_final_output(OpFinalOutput op_final_output, - GRUMetaValue value, int frame_size, - int batch_size, ActivationType active_node) { - for (int b = 0; b < batch_size; b++) { - hl_naive_gru_forward_final_output(op_final_output, value.gate_value, - value.prev_out_value, value.output_value, - frame_size, active_node); +#define FORWARD_FINAL_OUTPUT(active_type, value, frame_size) \ + hl_naive_gru_forward_final_output( \ + value.gate_value, value.prev_out_value, value.output_value, frame_size) +template +inline void forward_final_output(GRUMetaValue value, int frame_size, + int batch_size, ActivationType active_node) { + for (int b = 0; b < batch_size; ++b) { + switch (active_node) { + case RELU: + FORWARD_FINAL_OUTPUT(RELU, value, frame_size); + break; + case SIGMOID: + FORWARD_FINAL_OUTPUT(SIGMOID, value, frame_size); + break; + case TANH: + FORWARD_FINAL_OUTPUT(TANH, value, frame_size); + break; + default: + FORWARD_FINAL_OUTPUT(IDENTITY, value, frame_size); + } value.gate_value += frame_size * 3; value.output_value += frame_size; if (value.prev_out_value) { @@ -113,4 +199,5 @@ inline void forward_final_output(OpFinalOutput op_final_output, } // namespace math } // namespace operators } // namespace paddle_mobile + #endif diff --git a/src/operators/math/gru_kernel.h b/src/operators/math/gru_kernel.h deleted file mode 100644 index 6113ce8da997eaa5720886d637a9cc9261ea5227..0000000000000000000000000000000000000000 --- a/src/operators/math/gru_kernel.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#ifdef GRU_OP -#pragma once -#include -#include "operators/math/activation_functions.h" - -namespace paddle_mobile { -namespace operators { -namespace math { - -namespace forward { - -template -class gru_resetOutput { - public: - void operator()(T *value_update_gate, T *value_reset_gate, T *prev_out, - T *value_reset_output, ActivationType act_gate) { - *value_update_gate = activation(*value_update_gate, act_gate); - *value_reset_gate = activation(*value_reset_gate, act_gate); - *value_reset_output = (*prev_out) * (*value_reset_gate); - } -}; - -template -class gru_finalOutput { - public: - void operator()(T *value_update_gate, T *value_frame_state, T *prev_out, - T *value_output, ActivationType act_input) { - *value_frame_state = activation(*value_frame_state, act_input); - *value_output = *prev_out - ((*value_update_gate) * (*prev_out)) + - ((*value_update_gate) * (*value_frame_state)); - } -}; -} // namespace forward - -} // namespace math -} // namespace operators -} // namespace paddle_mobile -#endif diff --git a/src/operators/math/pooling.cpp b/src/operators/math/pooling.cpp index 1270e6a8980f44ef1f337f031d170eea1c536df0..46b4453e73dfac0ab8b5755e1c5cd472584be0a8 100644 --- a/src/operators/math/pooling.cpp +++ b/src/operators/math/pooling.cpp @@ -72,8 +72,8 @@ void Pooling

::operator()(const framework::Tensor &input, } } -template struct Pooling; -template struct Pooling; +template struct Pooling; +template struct Pooling; } // namespace math } // namespace operators diff --git a/src/operators/math/pooling.h b/src/operators/math/pooling.h index 4a3b2b8389eeac708df73d97672682615d1f7912..4239cf8cbc87e786e2e07ac77614f4d2f96d73dd 100644 --- a/src/operators/math/pooling.h +++ b/src/operators/math/pooling.h @@ -30,7 +30,7 @@ namespace paddle_mobile { namespace operators { namespace math { -template +template struct PoolingVal { float val; int count; @@ -44,11 +44,11 @@ struct PoolingVal { }; template <> -struct PoolingVal { +struct PoolingVal { float val; int count; PoolingVal() : val(0.f), count(0) {} - inline PoolingVal &operator+=(const float &x) { + inline PoolingVal &operator+=(const float &x) { val += x; ++count; return *this; @@ -57,57 +57,57 @@ struct PoolingVal { }; #if defined(__ARM_NEON) || defined(__ARM_NEON__) -template +template inline float32x4_t vPoolInitq_f32() { return vdupq_n_f32(-std::numeric_limits::max()); } template <> -inline float32x4_t vPoolInitq_f32() { +inline float32x4_t vPoolInitq_f32() { return vdupq_n_f32(0.f); } -template +template inline float32x4_t vPoolPreq_f32(const float32x4_t &x1, const float32x4_t &x2) { return vmaxq_f32(x1, x2); } template <> -inline float32x4_t vPoolPreq_f32(const float32x4_t &x1, +inline float32x4_t vPoolPreq_f32(const float32x4_t &x1, const float32x4_t &x2) { return vaddq_f32(x1, x2); } -template +template inline float32x4_t vPoolPostq_f32(const float32x4_t &x, const float32x4_t &post) { return x; } template <> -inline float32x4_t vPoolPostq_f32(const float32x4_t &x, +inline float32x4_t vPoolPostq_f32(const float32x4_t &x, const float32x4_t &post) { return vmulq_f32(x, post); } #endif // __ARM_NEON__ -template +template inline float PoolPre(const float &x1, const float &x2) { return std::max(x1, x2); } template <> -inline float PoolPre(const float &x1, const float &x2) { +inline float PoolPre(const float &x1, const float &x2) { return x1 + x2; } -template +template inline float PoolPost(const float &x, const float &post) { return x; } template <> -inline float PoolPost(const float &x, const float &post) { +inline float PoolPost(const float &x, const float &post) { return x * post; } diff --git a/src/operators/math/pooling3x3.cpp b/src/operators/math/pooling3x3.cpp index e556768ce04cd4326d11c10f00aec00e2bd263f8..72ffb6161a96fbde432768fbda455cf4d869de61 100644 --- a/src/operators/math/pooling3x3.cpp +++ b/src/operators/math/pooling3x3.cpp @@ -1016,10 +1016,10 @@ struct Pooling3x3 { } }; -template struct Pooling3x3; -template struct Pooling3x3; -template struct Pooling3x3; -template struct Pooling3x3; +template struct Pooling3x3; +template struct Pooling3x3; +template struct Pooling3x3; +template struct Pooling3x3; } // namespace math } // namespace operators diff --git a/test/operators/test_pool_op.cpp b/test/operators/test_pool_op.cpp index 3668b8cb2846ba4e3cb7f3b0728c1356a343cf4c..5d3c4374a403e0f3050d9b9babd3d09bdff03bc9 100644 --- a/test/operators/test_pool_op.cpp +++ b/test/operators/test_pool_op.cpp @@ -74,11 +74,11 @@ int TestPoolOp(int in_channels, int in_height, int in_width) { output_cmp.mutable_data(output->dims()); if (pooling_type == "avg") { - math::Pooling()(*input, std::vector{kernel_h, kernel_w}, + math::Pooling()(*input, std::vector{kernel_h, kernel_w}, std::vector{stride_h, stride_w}, std::vector{pad_h, pad_w}, &output_cmp); } else { - math::Pooling()(*input, std::vector{kernel_h, kernel_w}, + math::Pooling()(*input, std::vector{kernel_h, kernel_w}, std::vector{stride_h, stride_w}, std::vector{pad_h, pad_w}, &output_cmp); } @@ -117,57 +117,57 @@ int main(int argc, char *argv[]) { int in_channels = atoi(argv[1]); int in_height = atoi(argv[2]); int in_width = atoi(argv[3]); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=max, kernel=3, pad=0, stride=1"; - // paddle_mobile::TestPoolOp<0, 3, 0, 1>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=max, kernel=3, pad=1, stride=1"; - // paddle_mobile::TestPoolOp<0, 3, 1, 1>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=max, kernel=3, pad=2, stride=1"; - // paddle_mobile::TestPoolOp<0, 3, 2, 1>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=max, kernel=3, pad=5, stride=1"; - // paddle_mobile::TestPoolOp<0, 3, 5, 1>(in_channels, in_height, in_width); - // - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=avg, kernel=3, pad=0, stride=1"; - // paddle_mobile::TestPoolOp<1, 3, 0, 1>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=avg, kernel=3, pad=1, stride=1"; - // paddle_mobile::TestPoolOp<1, 3, 1, 1>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=avg, kernel=3, pad=2, stride=1"; - // paddle_mobile::TestPoolOp<1, 3, 2, 1>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=avg, kernel=3, pad=5, stride=1"; - // paddle_mobile::TestPoolOp<1, 3, 5, 1>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=max, kernel=3, pad=0, stride=1"; + paddle_mobile::TestPoolOp<0, 3, 0, 1>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=max, kernel=3, pad=1, stride=1"; + paddle_mobile::TestPoolOp<0, 3, 1, 1>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=max, kernel=3, pad=2, stride=1"; + paddle_mobile::TestPoolOp<0, 3, 2, 1>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=max, kernel=3, pad=5, stride=1"; + paddle_mobile::TestPoolOp<0, 3, 5, 1>(in_channels, in_height, in_width); + + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=3, pad=0, stride=1"; + paddle_mobile::TestPoolOp<1, 3, 0, 1>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=3, pad=1, stride=1"; + paddle_mobile::TestPoolOp<1, 3, 1, 1>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=3, pad=2, stride=1"; + paddle_mobile::TestPoolOp<1, 3, 2, 1>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=3, pad=5, stride=1"; + paddle_mobile::TestPoolOp<1, 3, 5, 1>(in_channels, in_height, in_width); LOG(paddle_mobile::kLOG_INFO) << "float, pooling_type=max, kernel=3, pad=0, stride=2"; paddle_mobile::TestPoolOp<0, 3, 0, 2>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=max, kernel=3, pad=1, stride=2"; - // paddle_mobile::TestPoolOp<0, 3, 1, 2>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=max, kernel=3, pad=2, stride=2"; - // paddle_mobile::TestPoolOp<0, 3, 2, 2>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=max, kernel=3, pad=5, stride=2"; - // paddle_mobile::TestPoolOp<0, 3, 5, 2>(in_channels, in_height, in_width); - // - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=avg, kernel=3, pad=0, stride=2"; - // paddle_mobile::TestPoolOp<1, 3, 0, 2>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=avg, kernel=3, pad=1, stride=2"; - // paddle_mobile::TestPoolOp<1, 3, 1, 2>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=avg, kernel=3, pad=2, stride=2"; - // paddle_mobile::TestPoolOp<1, 3, 2, 2>(in_channels, in_height, in_width); - // LOG(paddle_mobile::kLOG_INFO) - // << "float, pooling_type=avg, kernel=3, pad=5, stride=2"; - // paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=max, kernel=3, pad=1, stride=2"; + paddle_mobile::TestPoolOp<0, 3, 1, 2>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=max, kernel=3, pad=2, stride=2"; + paddle_mobile::TestPoolOp<0, 3, 2, 2>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=max, kernel=3, pad=5, stride=2"; + paddle_mobile::TestPoolOp<0, 3, 5, 2>(in_channels, in_height, in_width); + + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=3, pad=0, stride=2"; + paddle_mobile::TestPoolOp<1, 3, 0, 2>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=3, pad=1, stride=2"; + paddle_mobile::TestPoolOp<1, 3, 1, 2>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=3, pad=2, stride=2"; + paddle_mobile::TestPoolOp<1, 3, 2, 2>(in_channels, in_height, in_width); + LOG(paddle_mobile::kLOG_INFO) + << "float, pooling_type=avg, kernel=3, pad=5, stride=2"; + paddle_mobile::TestPoolOp<1, 3, 5, 2>(in_channels, in_height, in_width); // // kernel = 5, pad = 0, stride = 1 // LOG(paddle_mobile::kLOG_INFO)