提交 2562652a 编写于 作者: C chenjiaoAngel

fix relu6 bug

上级 e0d94147
...@@ -157,6 +157,7 @@ bool test_gemv_int8(bool tra, ...@@ -157,6 +157,7 @@ bool test_gemv_int8(bool tra,
tc_basic_fp32.numel()); tc_basic_fp32.numel());
if (flag_act == 2) { // relu6 if (flag_act == 2) { // relu6
for (int i = 0; i < tc_basic_int8.numel(); i++) { for (int i = 0; i < tc_basic_int8.numel(); i++) {
dc_basic_fp32[i] = dc_basic_fp32[i] > six ? six : dc_basic_fp32[i];
dc_basic_int8[i] = dc_basic_int8[i] > six ? six : dc_basic_int8[i]; dc_basic_int8[i] = dc_basic_int8[i] > six ? six : dc_basic_int8[i];
} }
} }
......
...@@ -108,6 +108,11 @@ bool test_sgemv(bool tra, ...@@ -108,6 +108,11 @@ bool test_sgemv(bool tra,
flag_act, flag_act,
six, six,
alpha); alpha);
if (flag_act == 2) { // relu6
for (int i = 0; i < tc_basic.numel(); i++) {
dc_basic[i] = dc_basic[i] > six ? six : dc_basic[i];
}
}
} }
paddle::lite::profile::Timer t0; paddle::lite::profile::Timer t0;
//! compute //! compute
......
...@@ -202,7 +202,7 @@ static void basic_gemv(int m, ...@@ -202,7 +202,7 @@ static void basic_gemv(int m,
c[i] = tmp > (type2)0 ? tmp : (type2)0; c[i] = tmp > (type2)0 ? tmp : (type2)0;
} else if (flag_act == 2) { // relu 6 } else if (flag_act == 2) { // relu 6
c[i] = tmp > (type2)0 ? tmp : (type2)0; c[i] = tmp > (type2)0 ? tmp : (type2)0;
c[i] = c[i] < six ? c[i] : six; // c[i] = c[i] < six ? c[i] : six; // ut compute
} else if (flag_act == 4) { // leakey relu } else if (flag_act == 4) { // leakey relu
c[i] = tmp < (type2)0 ? (type2)(tmp * leakey_relu_alpha) : tmp; c[i] = tmp < (type2)0 ? (type2)(tmp * leakey_relu_alpha) : tmp;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册