提交 5214a2c0 编写于 作者: C chenjiaoAngel

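Fix relu6 handling: rescale the relu6 clipping coefficient by the output scale in the int8 GemmLikeConv, and apply the upper clip inside the reference basic_gemv/conv_basic instead of in each test, test=develop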

上级 69cc2315
......@@ -79,6 +79,10 @@ void GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
}
flag_trans_bias_ = true;
}
//! relu6: rescale the clipping coefficient by the output scale (Relu_clipped_coef / output_scale)
if (param.activation_param.active_type == lite_api::ActivationType::kRelu6){
param.activation_param.Relu_clipped_coef = param.activation_param.Relu_clipped_coef / param.output_scale;
}
}
template <>
......
......@@ -221,11 +221,6 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
flag_act,
six,
leakey_relu_scale);
if (flag_act == 2) { // relu6
for (int i = 0; i < dim_out.production(); i++) {
dout_basic[i] = dout_basic[i] > six ? six : dout_basic[i];
}
}
}
/// warm up
for (int i = 0; i < FLAGS_warmup; ++i) {
......
......@@ -321,14 +321,6 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
1,
1,
dim_out.production());
if (flag_act == 2) { // relu6
for (int i = 0; i < dim_out.production(); i++) {
dout_basic_int8[i] =
dout_basic_int8[i] > six ? six : dout_basic_int8[i];
dout_basic_fp32[i] =
dout_basic_fp32[i] > six ? six : dout_basic_fp32[i];
}
}
}
double gops = 2.0 * dim_out.production() * dim_in[1] * weight_dim[2] *
weight_dim[3] / group;
......
......@@ -153,12 +153,6 @@ bool test_gemv_int8(bool tra,
1,
1,
tc_basic_fp32.numel());
if (flag_act == 2) { // relu6
for (int i = 0; i < tc_basic_int8.numel(); i++) {
dc_basic_fp32[i] = dc_basic_fp32[i] > six ? six : dc_basic_fp32[i];
dc_basic_int8[i] = dc_basic_int8[i] > six ? six : dc_basic_int8[i];
}
}
}
Timer t0;
//! compute
......@@ -324,7 +318,7 @@ TEST(TestLiteGemvInt8, gemv_prepacked_int8) {
for (auto& n : {1, 3, 13, 141, 512, 789}) {
for (auto& tra : {false}) {
for (auto& has_bias : {false, true}) {
for (auto& flag_act : {0, 1, 2, 4}) {
for (auto& flag_act : {0, 1, 4}) {
for (auto& th : {1, 2, 4}) {
float six = 6.f;
float alpha = 8.88f;
......
......@@ -108,11 +108,6 @@ bool test_sgemv(bool tra,
flag_act,
six,
alpha);
if (flag_act == 2) { // relu6
for (int i = 0; i < tc_basic.numel(); i++) {
dc_basic[i] = dc_basic[i] > six ? six : dc_basic[i];
}
}
}
paddle::lite::profile::Timer t0;
//! compute
......
......@@ -202,7 +202,7 @@ static void basic_gemv(int m,
c[i] = tmp > (type2)0 ? tmp : (type2)0;
} else if (flag_act == 2) { // relu 6
c[i] = tmp > (type2)0 ? tmp : (type2)0;
// c[i] = c[i] < six ? c[i] : six; // ut compute
c[i] = c[i] < six ? c[i] : six; // ut compute
} else if (flag_act == 4) { // leakey relu
c[i] = tmp < (type2)0 ? (type2)(tmp * leakey_relu_alpha) : tmp;
}
......@@ -301,9 +301,9 @@ static void conv_basic(const Dtype1* din,
dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0
? dst_data_ref[out_idx]
: (Dtype2)0;
// dst_data_ref[out_idx] = dst_data_ref[out_idx] < (Dtype2)six
// ? dst_data_ref[out_idx]
// : (Dtype2)six;
dst_data_ref[out_idx] = dst_data_ref[out_idx] < (Dtype2)six
? dst_data_ref[out_idx]
: (Dtype2)six;
} else if (act_type == 4) {
dst_data_ref[out_idx] =
dst_data_ref[out_idx] > (Dtype2)0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册