diff --git a/lite/backends/arm/math/gemm_prepacked_int8.cc b/lite/backends/arm/math/gemm_prepacked_int8.cc
index f45ad274153ad7c932e2fa888ecf95519391ded3..d3150ac58dd6ae9200e9652c28c5ba117610155e 100644
--- a/lite/backends/arm/math/gemm_prepacked_int8.cc
+++ b/lite/backends/arm/math/gemm_prepacked_int8.cc
@@ -4252,18 +4252,18 @@ void gemm_prepack_int8(const int8_t* A_packed,
   }
 #else
   gemm_prepack_oth_int8(A_packed,
-                      B,
-                      bias,
-                      C,
-                      M,
-                      N,
-                      K,
-                      is_bias,
-                      flag_act,
-                      is_transB,
-                      scale,
-                      alpha,
-                      ctx);
+                        B,
+                        bias,
+                        C,
+                        M,
+                        N,
+                        K,
+                        is_bias,
+                        flag_act,
+                        is_transB,
+                        scale,
+                        alpha,
+                        ctx);
 #endif
 }
 
@@ -4319,22 +4319,33 @@ void gemm_prepack_int8(const int8_t* A_packed,
                             ctx);
   } else {
     gemm_prepack_oth_int8(A_packed,
-                        B,
-                        bias,
-                        C,
-                        M,
-                        N,
-                        K,
-                        is_bias,
-                        flag_act,
-                        is_transB,
-                        scale,
-                        alpha,
-                        ctx);
+                          B,
+                          bias,
+                          C,
+                          M,
+                          N,
+                          K,
+                          is_bias,
+                          flag_act,
+                          is_transB,
+                          scale,
+                          alpha,
+                          ctx);
   }
 #else
-  gemm_prepack_oth_int8(
-      A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx);
+  gemm_prepack_oth_int8(A_packed,
+                        B,
+                        bias,
+                        C,
+                        M,
+                        N,
+                        K,
+                        is_bias,
+                        flag_act,
+                        is_transB,
+                        scale,
+                        alpha,
+                        ctx);
 #endif
 }
 
diff --git a/lite/backends/arm/math/gemv_arm_int8.cc b/lite/backends/arm/math/gemv_arm_int8.cc
index 86f7f47d8d9a7ab2ae4091215679195917f8c353..f9843a787db970f58bef6a08136426de969306cb 100644
--- a/lite/backends/arm/math/gemv_arm_int8.cc
+++ b/lite/backends/arm/math/gemv_arm_int8.cc
@@ -65,17 +65,15 @@ inline void write_gemv_out(const int* in,
         vout1 = vmaxq_f32(vout1, vzero);
         vout0 = vminq_f32(vout0, vsix);
         vout1 = vminq_f32(vout1, vsix);
+      } else if (act == lite_api::ActivationType::kLeakyRelu) {
+        float32x4_t valpha = vdupq_n_f32(alpha);
+        uint32x4_t maska = vcgeq_f32(vout0, vzero);
+        uint32x4_t maskb = vcgeq_f32(vout1, vzero);
+        float32x4_t suma = vmulq_f32(vout0, valpha);
+        float32x4_t sumb = vmulq_f32(vout1, valpha);
+        vout0 = vbslq_f32(maska, vout0, suma);
+        vout1 = vbslq_f32(maskb, vout1, sumb);
       }
-      vout0 = vmaxq_f32(vout0, vzero);
-      vout1 = vmaxq_f32(vout1, vzero);
-    } else if (act == lite_api::ActivationType::kLeakyRelu) {
-      float32x4_t valpha = vdupq_n_f32(alpha);
-      uint32x4_t maska = vcgeq_f32(vout0, vzero);
-      uint32x4_t maskb = vcgeq_f32(vout1, vzero);
-      float32x4_t suma = vmulq_f32(vout0, valpha);
-      float32x4_t sumb = vmulq_f32(vout1, valpha);
-      vout0 = vbslq_f32(maska, vout0, suma);
-      vout1 = vbslq_f32(maskb, vout1, sumb);
     }
     vst1q_f32(out, vout0);
     vst1q_f32(out + 4, vout1);
diff --git a/lite/tests/math/gemm_int8_compute_test.cc b/lite/tests/math/gemm_int8_compute_test.cc
index 5c5a19a95c373586c99325fd79fe5ff39fe90c95..ef8e59261b17e29b8a7f17d10b981922c4cdf370 100644
--- a/lite/tests/math/gemm_int8_compute_test.cc
+++ b/lite/tests/math/gemm_int8_compute_test.cc
@@ -23,6 +23,7 @@
 #include "lite/core/profile/timer.h"
 #include "lite/core/tensor.h"
 #include "lite/tests/utils/tensor_utils.h"
+#include "lite/operators/op_params.h"
 
 typedef paddle::lite::Tensor Tensor;
 using paddle::lite::profile::Timer;
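
Note: the gemv_arm_int8.cc change above moves the leaky-ReLU handling into the same
if/else chain as ReLU and ReLU6, so exactly one activation is applied per output and
the stray max(x, 0) that previously ran after the ReLU6 clamp is gone. A minimal
scalar sketch of that intended behaviour, using stand-in names (ActKind, apply_act)
rather than the library's actual API:

    #include <algorithm>

    // Stand-in for lite_api::ActivationType, for illustration only.
    enum class ActKind { kRelu, kRelu6, kLeakyRelu };

    // Scalar equivalent of the corrected NEON branch: each activation is a
    // mutually exclusive case; the leaky-ReLU select mirrors the vbslq_f32
    // above, choosing x or alpha * x by the sign of x.
    inline float apply_act(float x, ActKind act, float six, float alpha) {
      switch (act) {
        case ActKind::kRelu:
          return std::max(x, 0.f);
        case ActKind::kRelu6:
          return std::min(std::max(x, 0.f), six);
        case ActKind::kLeakyRelu:
          return x >= 0.f ? x : x * alpha;
      }
      return x;
    }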