From e0d94147571f45e7f0a55916d60f8149dc98364e Mon Sep 17 00:00:00 2001
From: chenjiaoAngel
Date: Thu, 7 May 2020 05:33:02 -0400
Subject: [PATCH] fix v7 build bug

---
 lite/backends/arm/math/gemm_prepacked_int8.cc |  6 ++--
 lite/backends/arm/math/gemv_arm_int8.cc       | 35 ++++++++++---------
 lite/tests/math/conv_int8_compute_test.cc     | 20 +++++++++--
 lite/tests/math/gemv_int8_compute_test.cc     |  6 ++++
 4 files changed, 44 insertions(+), 23 deletions(-)

diff --git a/lite/backends/arm/math/gemm_prepacked_int8.cc b/lite/backends/arm/math/gemm_prepacked_int8.cc
index 8c079c0c2f..61101c861d 100644
--- a/lite/backends/arm/math/gemm_prepacked_int8.cc
+++ b/lite/backends/arm/math/gemm_prepacked_int8.cc
@@ -1919,7 +1919,7 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
   "vmax.f32 q3,q3, q15\n" /* relu */ \
   "vmax.f32 q4,q4, q15\n" /* relu */ \
   "vmax.f32 q5,q5, q15\n" /* relu */ \
-  "b: 9f\n"
+  "b 9f\n"
 
 #define GEMM_INT8_RELU6 \
   /* do relu6 */ \
@@ -1944,7 +1944,7 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
   "vmin.f32 q3,q3, q14\n" /* relu6 */ \
   "vmin.f32 q4,q4, q14\n" /* relu6 */ \
   "vmin.f32 q5,q5, q14\n" /* relu6 */ \
-  "b: 9f\n"
+  "b 9f\n"
 
 #define GEMM_INT8_LEAKY_RELU \
   /* do relu6 */ \
@@ -1975,7 +1975,7 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
   "vmul.f32 q11, q5, q14 @ vmulq_f32 \n" \
   "vbif q4, q7, q6 @ choose \n" \
   "vbif q5, q11, q10 @ choose \n" \
-  "9: \n"
+  "9:\n"
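The three changes above are the ARMv7 assembler fix named in the subject: GNU as writes a forward branch to the numeric local label `9:` as `b 9f`, while `b: 9f` parses as a definition of a label called `b` followed by a stray token and does not assemble. The sketch below is illustrative only (it is not code from this patch) and shows the local-label branch syntax the fix restores:

    // relu on a scalar via a GNU-as numeric local label (ARMv7)
    int relu_scalar(int x) {
      asm volatile(
          "cmp %[v], #0\n"  // is x negative?
          "bge 9f\n"        // if x >= 0, branch forward to the next "9:"
          "mov %[v], #0\n"  // clamp negatives to zero
          "9:\n"            // numeric local label; "9f" above resolves here
          : [v] "+r"(x)
          :
          : "cc");
      return x;
    }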
diff --git a/lite/backends/arm/math/gemv_arm_int8.cc b/lite/backends/arm/math/gemv_arm_int8.cc
index 03bf380c4c..8978e61c70 100644
--- a/lite/backends/arm/math/gemv_arm_int8.cc
+++ b/lite/backends/arm/math/gemv_arm_int8.cc
@@ -112,33 +112,34 @@ inline void write_gemv_out(const int* in,
                            float alpha) {
   if (bias) {
     for (int i = 0; i < size; ++i) {
-      out[0] =
-          saturate_cast<signed char>(roundf(*(in++) * *(scale++) + *(bias++)));
-      out[0] = out[0] < -127 ? -127 : out[0];  // -127 - 127
+      float tmp = *(in++) * *(scale++) + *(bias++);
       if (flag_act) {
         if (act == lite_api::ActivationType::kRelu) {
-          out[0] = out[0] > 0.f ? out[0] : 0.f;
+          tmp = tmp > 0.f ? tmp : 0.f;
         } else if (act == lite_api::ActivationType::kRelu6) {
-          out[0] = out[0] > 0.f ? (out[0] > six ? six : out[0]) : 0.f;
+          tmp = tmp > 0.f ? (tmp > six ? six : tmp) : 0.f;
         } else if (act == lite_api::ActivationType::kLeakyRelu) {
-          out[0] = out[0] > 0.f ? out[0] : out[0] * alpha;
+          tmp = tmp > 0.f ? tmp : (tmp * alpha);
         }
       }
+      out[0] = saturate_cast<signed char>(roundf(tmp));
+      out[0] = out[0] < -127 ? -127 : out[0];  // -127 - 127
       out++;
     }
   } else {
     for (int i = 0; i < size; ++i) {
-      out[0] = saturate_cast<signed char>(roundf(*(in++) * *(scale++)));
-      out[0] = out[0] < -127 ? -127 : out[0];  // -127 - 127
+      float tmp = *(in++) * *(scale++);
       if (flag_act) {
         if (act == lite_api::ActivationType::kRelu) {
-          out[0] = out[0] > 0.f ? out[0] : 0.f;
+          tmp = tmp > 0.f ? tmp : 0.f;
         } else if (act == lite_api::ActivationType::kRelu6) {
-          out[0] = out[0] > 0.f ? (out[0] > six ? six : out[0]) : 0.f;
+          tmp = tmp > 0.f ? (tmp > six ? six : tmp) : 0.f;
         } else if (act == lite_api::ActivationType::kLeakyRelu) {
-          out[0] = out[0] > 0.f ? out[0] : out[0] * alpha;
+          tmp = tmp > 0.f ? tmp : tmp * alpha;
         }
       }
+      out[0] = saturate_cast<signed char>(roundf(tmp));
+      out[0] = out[0] < -127 ? -127 : out[0];  // -127 - 127
       out++;
     }
   }
@@ -711,14 +712,14 @@ bool gemv_int8(const int8_t* A,
 #if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
   if (ctx->has_dot()) {
     return gemv_int8_sdot(
-        A, x, y, transA, M, N, scale, is_bias, bias, flag_act, six, alpha);
+        A, x, y, transA, M, N, scale, is_bias, bias, flag_act, act, six, alpha);
   } else {
     return gemv_int8_oth(
-        A, x, y, transA, M, N, scale, is_bias, bias, flag_act, six, alpha);
+        A, x, y, transA, M, N, scale, is_bias, bias, flag_act, act, six, alpha);
   }
 #else
   return gemv_int8_oth(
-      A, x, y, transA, M, N, scale, is_bias, bias, flag_act, six, alpha);
+      A, x, y, transA, M, N, scale, is_bias, bias, flag_act, act, six, alpha);
 #endif
 }
 
@@ -740,14 +741,14 @@ bool gemv_int8(const int8_t* A,
 #if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
   if (ctx->has_dot()) {
     return gemv_int8_sdot(
-        A, x, y, transA, M, N, scale, is_bias, bias, flag_act, six, alpha);
+        A, x, y, transA, M, N, scale, is_bias, bias, flag_act, act, six, alpha);
   } else {
     return gemv_int8_oth(
-        A, x, y, transA, M, N, scale, is_bias, bias, flag_act, six, alpha);
+        A, x, y, transA, M, N, scale, is_bias, bias, flag_act, act, six, alpha);
   }
 #else
   return gemv_int8_oth(
-      A, x, y, transA, M, N, scale, is_bias, bias, flag_act, six, alpha);
+      A, x, y, transA, M, N, scale, is_bias, bias, flag_act, act, six, alpha);
 #endif
 }
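Two things change in gemv_arm_int8.cc. First, the int8 write-out previously rounded, saturated, and clamped to -127 before applying the activation, so relu6's `six` threshold and leaky-relu's `alpha` were applied to an already-quantized value; write_gemv_out now keeps the accumulator in float, applies the activation, and quantizes last. Second, the two gemv_int8 overloads now forward the `act` argument that gemv_int8_sdot and gemv_int8_oth take ahead of `six` and `alpha`. A condensed sketch of the corrected per-element ordering follows; the helper name and enum are illustrative, not the library's API:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    enum class Act { kNone, kRelu, kRelu6, kLeakyRelu };

    // Activation in float first, quantization last, as write_gemv_out now does.
    inline int8_t write_one(int acc, float scale, float bias,
                            Act act, float six, float alpha) {
      float tmp = acc * scale + bias;
      if (act == Act::kRelu) {
        tmp = std::max(tmp, 0.f);
      } else if (act == Act::kRelu6) {
        tmp = std::min(std::max(tmp, 0.f), six);
      } else if (act == Act::kLeakyRelu) {
        tmp = tmp > 0.f ? tmp : tmp * alpha;
      }
      float r = std::round(tmp);                 // round to nearest
      r = std::min(std::max(r, -127.f), 127.f);  // symmetric int8 range
      return static_cast<int8_t>(r);
    }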
diff --git a/lite/tests/math/conv_int8_compute_test.cc b/lite/tests/math/conv_int8_compute_test.cc
index dc8badff1f..24bdac7a87 100644
--- a/lite/tests/math/conv_int8_compute_test.cc
+++ b/lite/tests/math/conv_int8_compute_test.cc
@@ -176,13 +176,15 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
       act_param.active_type = (paddle::lite_api::ActivationType)
           flag_act;  // 1-relu, 2-relu6, 4-leakyrelu
       if (flag_act == 1) {
-        param.fuse_relu = true;
+        param_fp32_out.fuse_relu = true;
+        param_int8_out.fuse_relu = true;
       } else if (flag_act == 2) {
         act_param.Relu_clipped_coef = six;
       } else if (flag_act == 4) {
-        act_param.Leaky_relu_alpha = leakey_relu_scale;
+        act_param.Leaky_relu_alpha = alpha;
       }
-      param.activation_param = act_param;
+      param_fp32_out.activation_param = act_param;
+      param_int8_out.activation_param = act_param;
     }
 
   std::vector<float> scale_in{1.f / 127};
@@ -319,6 +321,18 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
                  1,
                  1,
                  dim_out.production());
+      if (flag_act == 2) {  // relu6
+        for (int i = 0; i < dim_out.production(); i++) {
+          dout_basic_int8[i] = dout_basic_int8[i] > six ? six : dout_basic_int8[i];
+        }
+      } else if (flag_act == 4) {  // leakyRelu
+        for (int i = 0; i < dim_out.production(); i++) {
+          float tmp = dout_basic_fp32[i] / scale_out.data()[0];
+          tmp = tmp > 0 ? tmp : tmp * alpha;
+          dout_basic_int8[i] = static_cast<int8_t>(roundf(tmp));
+          dout_basic_int8[i] = dout_basic_int8[i] < -127 ? -127 : dout_basic_int8[i];
+        }
+      }
     }
 
     double gops = 2.0 * dim_out.production() * dim_in[1] * weight_dim[2] *
diff --git a/lite/tests/math/gemv_int8_compute_test.cc b/lite/tests/math/gemv_int8_compute_test.cc
index c069a3a030..1e80464ad1 100644
--- a/lite/tests/math/gemv_int8_compute_test.cc
+++ b/lite/tests/math/gemv_int8_compute_test.cc
@@ -23,6 +23,7 @@
 #include "lite/core/profile/timer.h"
 #include "lite/core/tensor.h"
 #include "lite/tests/utils/tensor_utils.h"
+#include "lite/backends/arm/math/saturate.h"
 
 typedef paddle::lite::Tensor Tensor;
 using paddle::lite::profile::Timer;
@@ -154,6 +155,11 @@ bool test_gemv_int8(bool tra,
                     1,
                     1,
                     tc_basic_fp32.numel());
+    if (flag_act == 2) {  // relu6
+      for (int i = 0; i < tc_basic_int8.numel(); i++) {
+        dc_basic_int8[i] = dc_basic_int8[i] > six ? six : dc_basic_int8[i];
+      }
+    }
   }
   Timer t0;
   //! compute
-- 
GitLab
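The test-side changes keep the reference results consistent with the new kernel behavior: conv_int8_compute_test folds relu6 and leaky-relu into the int8 reference output (the leaky-relu reference is rebuilt from the fp32 reference divided by `scale_out`, then rounded and floored at -127), the fused-activation parameters are applied to both `param_fp32_out` and `param_int8_out`, and gemv_int8_compute_test clamps its int8 reference with `six`, pulling in saturate.h. A sketch of that leaky-relu requantization step under the same assumptions (the helper name is illustrative, not from the patch):

    #include <cmath>
    #include <cstdint>

    // Rebuild an int8 reference value from the fp32 reference output.
    inline int8_t requant_ref(float ref_fp32, float scale_out, float alpha) {
      float tmp = ref_fp32 / scale_out;     // map back into the int8 domain
      tmp = tmp > 0.f ? tmp : tmp * alpha;  // leaky-relu applied in float
      int v = static_cast<int>(std::round(tmp));
      if (v < -127) v = -127;               // keep the symmetric int8 range
      if (v > 127) v = 127;                 // saturate the high side too
      return static_cast<int8_t>(v);
    }

Note the -127 floor rather than -128: kernels and tests both quantize to a symmetric int8 range so positive and negative values carry the same magnitude.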