提交 2562652a 编写于 作者: C chenjiaoAngel

fix relu6 bug

上级 e0d94147
...@@ -155,10 +155,10 @@ bool gemv_int8_oth(const int8_t* A, ...@@ -155,10 +155,10 @@ bool gemv_int8_oth(const int8_t* A,
const float* scale, const float* scale,
bool is_bias, bool is_bias,
const float* bias, const float* bias,
bool flag_act, bool flag_act,
lite_api::ActivationType act, lite_api::ActivationType act,
float six, float six,
float alpha) { float alpha) {
if (transA) { if (transA) {
LOG(ERROR) << "ERROR: sgemv, transA is not supported now"; LOG(ERROR) << "ERROR: sgemv, transA is not supported now";
return false; return false;
......
...@@ -155,11 +155,12 @@ bool test_gemv_int8(bool tra, ...@@ -155,11 +155,12 @@ bool test_gemv_int8(bool tra,
1, 1,
1, 1,
tc_basic_fp32.numel()); tc_basic_fp32.numel());
if (flag_act == 2) { // relu6 if (flag_act == 2) { // relu6
for (int i = 0; i < tc_basic_int8.numel(); i++) { for (int i = 0; i < tc_basic_int8.numel(); i++) {
dc_basic_int8[i] = dc_basic_int8[i] > six ? six : dc_basic_int8[i]; dc_basic_fp32[i] = dc_basic_fp32[i] > six ? six : dc_basic_fp32[i];
dc_basic_int8[i] = dc_basic_int8[i] > six ? six : dc_basic_int8[i];
} }
} }
} }
Timer t0; Timer t0;
//! compute //! compute
......
...@@ -108,6 +108,11 @@ bool test_sgemv(bool tra, ...@@ -108,6 +108,11 @@ bool test_sgemv(bool tra,
flag_act, flag_act,
six, six,
alpha); alpha);
if (flag_act == 2) { // relu6
for (int i = 0; i < tc_basic.numel(); i++) {
dc_basic[i] = dc_basic[i] > six ? six : dc_basic[i];
}
}
} }
paddle::lite::profile::Timer t0; paddle::lite::profile::Timer t0;
//! compute //! compute
......
...@@ -202,7 +202,7 @@ static void basic_gemv(int m, ...@@ -202,7 +202,7 @@ static void basic_gemv(int m,
c[i] = tmp > (type2)0 ? tmp : (type2)0; c[i] = tmp > (type2)0 ? tmp : (type2)0;
} else if (flag_act == 2) { // relu 6 } else if (flag_act == 2) { // relu 6
c[i] = tmp > (type2)0 ? tmp : (type2)0; c[i] = tmp > (type2)0 ? tmp : (type2)0;
c[i] = c[i] < six ? c[i] : six; // c[i] = c[i] < six ? c[i] : six; // ut compute
} else if (flag_act == 4) { // leakey relu } else if (flag_act == 4) { // leakey relu
c[i] = tmp < (type2)0 ? (type2)(tmp * leakey_relu_alpha) : tmp; c[i] = tmp < (type2)0 ? (type2)(tmp * leakey_relu_alpha) : tmp;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册