From 2562652a111b2497e0064cd2595f582103577b4a Mon Sep 17 00:00:00 2001 From: chenjiaoAngel Date: Thu, 7 May 2020 18:15:53 +0800 Subject: [PATCH] fix relu6 bug --- lite/backends/arm/math/gemv_arm_int8.cc | 8 ++++---- lite/tests/math/gemv_int8_compute_test.cc | 7 ++++--- lite/tests/math/sgemv_compute_test.cc | 5 +++++ lite/tests/utils/naive_math_impl.h | 2 +- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/lite/backends/arm/math/gemv_arm_int8.cc b/lite/backends/arm/math/gemv_arm_int8.cc index 8978e61c70..ea271d08bd 100644 --- a/lite/backends/arm/math/gemv_arm_int8.cc +++ b/lite/backends/arm/math/gemv_arm_int8.cc @@ -155,10 +155,10 @@ bool gemv_int8_oth(const int8_t* A, const float* scale, bool is_bias, const float* bias, - bool flag_act, - lite_api::ActivationType act, - float six, - float alpha) { + bool flag_act, + lite_api::ActivationType act, + float six, + float alpha) { if (transA) { LOG(ERROR) << "ERROR: sgemv, transA is not supported now"; return false; diff --git a/lite/tests/math/gemv_int8_compute_test.cc b/lite/tests/math/gemv_int8_compute_test.cc index 1e80464ad1..9e29617bb4 100644 --- a/lite/tests/math/gemv_int8_compute_test.cc +++ b/lite/tests/math/gemv_int8_compute_test.cc @@ -155,11 +155,12 @@ bool test_gemv_int8(bool tra, 1, 1, tc_basic_fp32.numel()); - if (flag_act == 2) { // relu6 + if (flag_act == 2) { // relu6 for (int i = 0; i < tc_basic_int8.numel(); i++) { - dc_basic_int8[i] = dc_basic_int8[i] > six ? six : dc_basic_int8[i]; + dc_basic_fp32[i] = dc_basic_fp32[i] > six ? six : dc_basic_fp32[i]; + dc_basic_int8[i] = dc_basic_int8[i] > six ? six : dc_basic_int8[i]; } - } + } } Timer t0; //! compute diff --git a/lite/tests/math/sgemv_compute_test.cc b/lite/tests/math/sgemv_compute_test.cc index 91a1fe1770..88277c0676 100644 --- a/lite/tests/math/sgemv_compute_test.cc +++ b/lite/tests/math/sgemv_compute_test.cc @@ -108,6 +108,11 @@ bool test_sgemv(bool tra, flag_act, six, alpha); + if (flag_act == 2) { // relu6 + for (int i = 0; i < tc_basic.numel(); i++) { + dc_basic[i] = dc_basic[i] > six ? six : dc_basic[i]; + } + } } paddle::lite::profile::Timer t0; //! compute diff --git a/lite/tests/utils/naive_math_impl.h b/lite/tests/utils/naive_math_impl.h index e5ef77ca06..916246f708 100644 --- a/lite/tests/utils/naive_math_impl.h +++ b/lite/tests/utils/naive_math_impl.h @@ -202,7 +202,7 @@ static void basic_gemv(int m, c[i] = tmp > (type2)0 ? tmp : (type2)0; } else if (flag_act == 2) { // relu 6 c[i] = tmp > (type2)0 ? tmp : (type2)0; - c[i] = c[i] < six ? c[i] : six; + // c[i] = c[i] < six ? c[i] : six; // ut compute } else if (flag_act == 4) { // leakey relu c[i] = tmp < (type2)0 ? (type2)(tmp * leakey_relu_alpha) : tmp; } -- GitLab