Commit e0d94147 authored by chenjiaoAngel

fix v7 build bug

Parent 9a273b81
@@ -1919,7 +1919,7 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
   "vmax.f32 q3,q3, q15\n" /* relu */ \
   "vmax.f32 q4,q4, q15\n" /* relu */ \
   "vmax.f32 q5,q5, q15\n" /* relu */ \
-  "b: 9f\n"
+  "b 9f\n"
 #define GEMM_INT8_RELU6 \
   /* do relu6 */ \
@@ -1944,7 +1944,7 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
   "vmin.f32 q3,q3, q14\n" /* relu6 */ \
   "vmin.f32 q4,q4, q14\n" /* relu6 */ \
   "vmin.f32 q5,q5, q14\n" /* relu6 */ \
-  "b: 9f\n"
+  "b 9f\n"
 #define GEMM_INT8_LEAKY_RELU \
   /* do relu6 */ \
@@ -1975,7 +1975,7 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
   "vmul.f32 q11, q5, q14 @ vmulq_f32 \n" \
   "vbif q4, q7, q6 @ choose \n" \
   "vbif q5, q11, q10 @ choose \n" \
   "9: \n"
 #define GEMM_INT8_FP32_OUT \
   GEMM_INT8_TRANS_INT32_TO_FP32 \
...
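The v7 build break in these macros is the branch syntax: in GNU assembler, "b 9f" branches forward to the nearest numeric local label "9:", while the old "b: 9f" instead defines a label named "b" and leaves "9f" as a stray token. A minimal sketch of the local-label syntax (illustrative inline asm, not code from this patch):

// Sketch: numeric local labels in GCC/Clang inline asm.
// "b 9f" = branch forward to the next "9:"; "b 8f" likewise.
int relu_like(int x) {
  int out;
  asm volatile(
      "cmp %[x], #0        \n"  // x <= 0 ?
      "ble 9f              \n"  // forward branch to local label 9
      "mov %[out], %[x]    \n"  // positive: keep x
      "b 8f                \n"  // skip the clamp
      "9:                  \n"
      "mov %[out], #0      \n"  // non-positive: clamp to 0
      "8:                  \n"
      : [out] "=&r"(out)
      : [x] "r"(x)
      : "cc");
  return out;
}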
@@ -112,33 +112,34 @@ inline void write_gemv_out(const int* in,
                            float alpha) {
   if (bias) {
     for (int i = 0; i < size; ++i) {
-      out[0] =
-          saturate_cast<signed char>(roundf(*(in++) * *(scale++) + *(bias++)));
-      out[0] = out[0] < -127 ? -127 : out[0];  // -127 - 127
+      float tmp = *(in++) * *(scale++) + *(bias++);
       if (flag_act) {
         if (act == lite_api::ActivationType::kRelu) {
-          out[0] = out[0] > 0.f ? out[0] : 0.f;
+          tmp = tmp > 0.f ? tmp : 0.f;
         } else if (act == lite_api::ActivationType::kRelu6) {
-          out[0] = out[0] > 0.f ? (out[0] > six ? six : out[0]) : 0.f;
+          tmp = tmp > 0.f ? (tmp > six ? six : tmp) : 0.f;
         } else if (act == lite_api::ActivationType::kLeakyRelu) {
-          out[0] = out[0] > 0.f ? out[0] : out[0] * alpha;
+          tmp = tmp > 0.f ? tmp : (tmp * alpha);
         }
       }
+      out[0] = saturate_cast<signed char>(roundf(tmp));
+      out[0] = out[0] < -127 ? -127 : out[0];  // -127 - 127
       out++;
     }
   } else {
     for (int i = 0; i < size; ++i) {
-      out[0] = saturate_cast<signed char>(roundf(*(in++) * *(scale++)));
-      out[0] = out[0] < -127 ? -127 : out[0];  // -127 - 127
+      float tmp = *(in++) * *(scale++);
       if (flag_act) {
         if (act == lite_api::ActivationType::kRelu) {
-          out[0] = out[0] > 0.f ? out[0] : 0.f;
+          tmp = tmp > 0.f ? tmp : 0.f;
         } else if (act == lite_api::ActivationType::kRelu6) {
-          out[0] = out[0] > 0.f ? (out[0] > six ? six : out[0]) : 0.f;
+          tmp = tmp > 0.f ? (tmp > six ? six : tmp) : 0.f;
         } else if (act == lite_api::ActivationType::kLeakyRelu) {
-          out[0] = out[0] > 0.f ? out[0] : out[0] * alpha;
+          tmp = tmp > 0.f ? tmp : tmp * alpha;
         }
       }
+      out[0] = saturate_cast<signed char>(roundf(tmp));
+      out[0] = out[0] < -127 ? -127 : out[0];  // -127 - 127
       out++;
     }
   }
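The rewrite of write_gemv_out changes the order of operations: the activation now runs on the float accumulator before rounding, so six and alpha are applied in float rather than to an already rounded, saturated int8 value. A compact sketch of the resulting per-element pipeline (quantize_one is an illustrative name, not the Paddle-Lite API):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Activation first (in float), then round, saturate, and clamp to the
// symmetric int8 range [-127, 127] used by the kernels above.
inline int8_t quantize_one(int acc, float scale, float bias,
                           float six, float alpha,
                           int act /* 0 none, 1 relu, 2 relu6, 4 leaky */) {
  float tmp = acc * scale + bias;
  if (act == 1) {
    tmp = std::max(tmp, 0.f);
  } else if (act == 2) {
    tmp = std::min(std::max(tmp, 0.f), six);
  } else if (act == 4) {
    tmp = tmp > 0.f ? tmp : tmp * alpha;
  }
  float r = std::round(tmp);
  r = std::min(std::max(r, -127.f), 127.f);
  return static_cast<int8_t>(r);
}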
@@ -711,14 +712,14 @@ bool gemv_int8<float>(const int8_t* A,
 #if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
   if (ctx->has_dot()) {
     return gemv_int8_sdot<float>(
-        A, x, y, transA, M, N, scale, is_bias, bias, flag_act, six, alpha);
+        A, x, y, transA, M, N, scale, is_bias, bias, flag_act, act, six, alpha);
   } else {
     return gemv_int8_oth<float>(
-        A, x, y, transA, M, N, scale, is_bias, bias, flag_act, six, alpha);
+        A, x, y, transA, M, N, scale, is_bias, bias, flag_act, act, six, alpha);
   }
 #else
   return gemv_int8_oth<float>(
-      A, x, y, transA, M, N, scale, is_bias, bias, flag_act, six, alpha);
+      A, x, y, transA, M, N, scale, is_bias, bias, flag_act, act, six, alpha);
 #endif
 }
@@ -740,14 +741,14 @@ bool gemv_int8<int8_t>(const int8_t* A,
 #if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
   if (ctx->has_dot()) {
     return gemv_int8_sdot<int8_t>(
-        A, x, y, transA, M, N, scale, is_bias, bias, flag_act, six, alpha);
+        A, x, y, transA, M, N, scale, is_bias, bias, flag_act, act, six, alpha);
   } else {
     return gemv_int8_oth<int8_t>(
-        A, x, y, transA, M, N, scale, is_bias, bias, flag_act, six, alpha);
+        A, x, y, transA, M, N, scale, is_bias, bias, flag_act, act, six, alpha);
   }
 #else
   return gemv_int8_oth<int8_t>(
-      A, x, y, transA, M, N, scale, is_bias, bias, flag_act, six, alpha);
+      A, x, y, transA, M, N, scale, is_bias, bias, flag_act, act, six, alpha);
 #endif
 }
...
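Both gemv_int8 specializations previously dropped the activation type from the kernel calls; once gemv_int8_sdot and gemv_int8_oth carry an act parameter, those calls no longer match the signatures, which is presumably the build error this commit fixes. A stripped-down sketch of the dispatch shape (stub names and signatures, not the real ones):

enum class Act { kNone, kRelu, kRelu6, kLeakyRelu };

// Stand-ins for gemv_int8_sdot / gemv_int8_oth (illustrative only).
static bool gemv_sdot_stub(Act act) { return act != Act::kNone; }
static bool gemv_oth_stub(Act act) { return act != Act::kNone; }

// The fix threads `act` through both branches, so the fast (sdot) path
// and the fallback path always agree on the fused activation.
bool gemv_dispatch(bool has_dot, Act act) {
  return has_dot ? gemv_sdot_stub(act) : gemv_oth_stub(act);
}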
@@ -176,13 +176,15 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
     act_param.active_type = (paddle::lite_api::ActivationType)
         flag_act;  // 1-relu, 2-relu6, 4-leakyrelu
     if (flag_act == 1) {
-      param.fuse_relu = true;
+      param_fp32_out.fuse_relu = true;
+      param_int8_out.fuse_relu = true;
     } else if (flag_act == 2) {
       act_param.Relu_clipped_coef = six;
     } else if (flag_act == 4) {
-      act_param.Leaky_relu_alpha = leakey_relu_scale;
+      act_param.Leaky_relu_alpha = alpha;
     }
-    param.activation_param = act_param;
+    param_fp32_out.activation_param = act_param;
+    param_int8_out.activation_param = act_param;
   }
   std::vector<float> scale_in{1.f / 127};
@@ -319,6 +321,18 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
         1,
         1,
         dim_out.production());
+    if (flag_act == 2) {  // relu6
+      for (int i = 0; i < dim_out.production(); i++) {
+        dout_basic_int8[i] =
+            dout_basic_int8[i] > six ? six : dout_basic_int8[i];
+      }
+    } else if (flag_act == 4) {  // leakyRelu
+      for (int i = 0; i < dim_out.production(); i++) {
+        float tmp = dout_basic_fp32[i] / scale_out.data()[0];
+        tmp = tmp > 0 ? tmp : tmp * alpha;
+        dout_basic_int8[i] = static_cast<int8_t>(roundf(tmp));
+        dout_basic_int8[i] =
+            dout_basic_int8[i] < -127 ? -127 : dout_basic_int8[i];
+      }
+    }
   }
   double gops = 2.0 * dim_out.production() * dim_in[1] * weight_dim[2] *
...
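The test now post-processes its int8 reference so it matches the fused kernels: relu6 clamps the already quantized values at six, while leaky-relu is recomputed from the fp32 reference expressed in units of the output scale, then rounded and clamped. A self-contained sketch of that reference fix-up (illustrative names, not the test's actual buffers):

#include <cmath>
#include <cstdint>
#include <vector>

void fix_int8_reference(std::vector<int8_t>& ref_int8,
                        const std::vector<float>& ref_fp32,
                        float scale_out, int flag_act,
                        float six, float alpha) {
  if (flag_act == 2) {  // relu6: clamp the quantized values at six
    for (auto& v : ref_int8) v = v > six ? static_cast<int8_t>(six) : v;
  } else if (flag_act == 4) {  // leaky relu: redo from the fp32 reference
    for (size_t i = 0; i < ref_int8.size(); ++i) {
      float tmp = ref_fp32[i] / scale_out;  // fp32 result in int8 steps
      tmp = tmp > 0.f ? tmp : tmp * alpha;
      float r = std::round(tmp);
      // clamp to the symmetric int8 range (the patch clamps the low end only)
      r = r < -127.f ? -127.f : (r > 127.f ? 127.f : r);
      ref_int8[i] = static_cast<int8_t>(r);
    }
  }
}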
@@ -23,6 +23,7 @@
 #include "lite/core/profile/timer.h"
 #include "lite/core/tensor.h"
 #include "lite/tests/utils/tensor_utils.h"
+#include "lite/backends/arm/math/saturate.h"
 typedef paddle::lite::Tensor Tensor;
 using paddle::lite::profile::Timer;
@@ -154,6 +155,11 @@ bool test_gemv_int8(bool tra,
       1,
       1,
       tc_basic_fp32.numel());
+    if (flag_act == 2) {  // relu6
+      for (int i = 0; i < tc_basic_int8.numel(); i++) {
+        dc_basic_int8[i] = dc_basic_int8[i] > six ? six : dc_basic_int8[i];
+      }
+    }
   }
   Timer t0;
   //! compute
...
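The new include makes saturate_cast available to the gemv test. Judging by the name and the call sites above, its contract is a cast that clamps to the target type's range instead of wrapping on overflow; a hypothetical stand-in, not the actual header:

#include <limits>

// saturate_cast_sketch<int8_t>(300.f) == 127; a plain static_cast would wrap.
template <typename Out, typename In>
Out saturate_cast_sketch(In v) {
  const In lo = static_cast<In>(std::numeric_limits<Out>::min());
  const In hi = static_cast<In>(std::numeric_limits<Out>::max());
  return static_cast<Out>(v < lo ? lo : (v > hi ? hi : v));
}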