diff --git a/lite/backends/arm/math/gemv_arm_int8.cc b/lite/backends/arm/math/gemv_arm_int8.cc index 8978e61c70821235bbe6af39a1e58f4ce9a6fa40..ea271d08bd92ce4224a2b89db88f54199cc31207 100644 --- a/lite/backends/arm/math/gemv_arm_int8.cc +++ b/lite/backends/arm/math/gemv_arm_int8.cc @@ -155,10 +155,10 @@ bool gemv_int8_oth(const int8_t* A, const float* scale, bool is_bias, const float* bias, - bool flag_act, - lite_api::ActivationType act, - float six, - float alpha) { + bool flag_act, + lite_api::ActivationType act, + float six, + float alpha) { if (transA) { LOG(ERROR) << "ERROR: sgemv, transA is not supported now"; return false; diff --git a/lite/tests/math/gemv_int8_compute_test.cc b/lite/tests/math/gemv_int8_compute_test.cc index 1e80464ad120e16bd919d78e35184a6e8bb89c6e..9e29617bb4ee6a64fee112eed7d95ebbd81c57da 100644 --- a/lite/tests/math/gemv_int8_compute_test.cc +++ b/lite/tests/math/gemv_int8_compute_test.cc @@ -155,11 +155,12 @@ bool test_gemv_int8(bool tra, 1, 1, tc_basic_fp32.numel()); - if (flag_act == 2) { // relu6 + if (flag_act == 2) { // relu6 for (int i = 0; i < tc_basic_int8.numel(); i++) { - dc_basic_int8[i] = dc_basic_int8[i] > six ? six : dc_basic_int8[i]; + dc_basic_fp32[i] = dc_basic_fp32[i] > six ? six : dc_basic_fp32[i]; + dc_basic_int8[i] = dc_basic_int8[i] > six ? six : dc_basic_int8[i]; } - } + } } Timer t0; //! compute diff --git a/lite/tests/math/sgemv_compute_test.cc b/lite/tests/math/sgemv_compute_test.cc index 91a1fe1770dfa3eeb3f3b94fcd2361f1c1634b1e..88277c06768c2cdec223ba7d38f1ae749e0ffdd8 100644 --- a/lite/tests/math/sgemv_compute_test.cc +++ b/lite/tests/math/sgemv_compute_test.cc @@ -108,6 +108,11 @@ bool test_sgemv(bool tra, flag_act, six, alpha); + if (flag_act == 2) { // relu6 + for (int i = 0; i < tc_basic.numel(); i++) { + dc_basic[i] = dc_basic[i] > six ? six : dc_basic[i]; + } + } } paddle::lite::profile::Timer t0; //! compute diff --git a/lite/tests/utils/naive_math_impl.h b/lite/tests/utils/naive_math_impl.h index e5ef77ca061d31a0b9b735d49cda9bbeda53c294..916246f70807b14bdafe1269d8e97698b26a4321 100644 --- a/lite/tests/utils/naive_math_impl.h +++ b/lite/tests/utils/naive_math_impl.h @@ -202,7 +202,7 @@ static void basic_gemv(int m, c[i] = tmp > (type2)0 ? tmp : (type2)0; } else if (flag_act == 2) { // relu 6 c[i] = tmp > (type2)0 ? tmp : (type2)0; - c[i] = c[i] < six ? c[i] : six; + // c[i] = c[i] < six ? c[i] : six; // ut compute } else if (flag_act == 4) { // leakey relu c[i] = tmp < (type2)0 ? (type2)(tmp * leakey_relu_alpha) : tmp; }