提交 2562652a 编写于 作者: C chenjiaoAngel

fix relu6 bug

上级 e0d94147
......@@ -155,10 +155,10 @@ bool gemv_int8_oth(const int8_t* A,
const float* scale,
bool is_bias,
const float* bias,
bool flag_act,
lite_api::ActivationType act,
float six,
float alpha) {
bool flag_act,
lite_api::ActivationType act,
float six,
float alpha) {
if (transA) {
LOG(ERROR) << "ERROR: sgemv, transA is not supported now";
return false;
......
......@@ -155,11 +155,12 @@ bool test_gemv_int8(bool tra,
1,
1,
tc_basic_fp32.numel());
if (flag_act == 2) { // relu6
if (flag_act == 2) { // relu6
for (int i = 0; i < tc_basic_int8.numel(); i++) {
dc_basic_int8[i] = dc_basic_int8[i] > six ? six : dc_basic_int8[i];
dc_basic_fp32[i] = dc_basic_fp32[i] > six ? six : dc_basic_fp32[i];
dc_basic_int8[i] = dc_basic_int8[i] > six ? six : dc_basic_int8[i];
}
}
}
}
Timer t0;
//! compute
......
......@@ -108,6 +108,11 @@ bool test_sgemv(bool tra,
flag_act,
six,
alpha);
if (flag_act == 2) { // relu6
for (int i = 0; i < tc_basic.numel(); i++) {
dc_basic[i] = dc_basic[i] > six ? six : dc_basic[i];
}
}
}
paddle::lite::profile::Timer t0;
//! compute
......
......@@ -202,7 +202,7 @@ static void basic_gemv(int m,
c[i] = tmp > (type2)0 ? tmp : (type2)0;
} else if (flag_act == 2) { // relu 6
c[i] = tmp > (type2)0 ? tmp : (type2)0;
c[i] = c[i] < six ? c[i] : six;
// c[i] = c[i] < six ? c[i] : six; // ut compute
} else if (flag_act == 4) { // leakey relu
c[i] = tmp < (type2)0 ? (type2)(tmp * leakey_relu_alpha) : tmp;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册