diff --git a/lite/kernels/arm/fc_compute.cc b/lite/kernels/arm/fc_compute.cc index 1269a259072b6ae54759794f06040340cc42e15e..102309d86e2977fbf3f8410919d6bf79be5ffe39 100644 --- a/lite/kernels/arm/fc_compute.cc +++ b/lite/kernels/arm/fc_compute.cc @@ -156,6 +156,8 @@ void FcCompute::Run() { b_data = bias_.data(); } bool flag_relu = false; + operators::ActivationParam act_param; + act_param.has_active = false; if (param.activation_type == "relu") { flag_relu = true; } @@ -170,8 +172,8 @@ void FcCompute::Run() { o_data, nullptr, false, - false, scale_.data(), + act_param, &ctx); if (param.bias) { CHECK_EQ(param.bias->numel(), n_); @@ -210,8 +212,12 @@ void FcCompute::Run() { b_data = bias_.data(); } bool flag_relu = false; + operators::ActivationParam act_param; + act_param.has_active = false; if (param.activation_type == "relu") { flag_relu = true; + act_param.has_active = true; + act_param.active_type = lite_api::ActivationType::kRelu; } if (flag_gemm_) { CHECK(!param.bias) << "fc int8 kernel with int8 output using gemm kernel " @@ -226,8 +232,8 @@ void FcCompute::Run() { o_data, nullptr, false, - flag_relu, scale_.data(), + act_param, &ctx); } else { for (int i = 0; i < m_; ++i) { diff --git a/lite/tests/math/gemm_int8_compute_test.cc b/lite/tests/math/gemm_int8_compute_test.cc index 377b07b92cbaf36eafcf359c89a2ca3375708847..95edee74278a993d9ce56ff6a960bc44ba142341 100644 --- a/lite/tests/math/gemm_int8_compute_test.cc +++ b/lite/tests/math/gemm_int8_compute_test.cc @@ -26,6 +26,7 @@ typedef paddle::lite::Tensor Tensor; using paddle::lite::profile::Timer; +typedef paddle::lite::operators::ActivationParam ActivationParam; DEFINE_int32(power_mode, 3, @@ -92,6 +93,11 @@ bool test_gemm_int8(bool tra, std::vector scale_c = {k / 127.f}; std::vector scale_merge_fp32(static_cast(m)); std::vector scale_merge_int8(static_cast(m)); + ActivationParam act_param; + act_param.has_active = has_relu; + if (has_relu){ + act_param.active_type = (paddle::lite_api::ActivationType)1; + } for (int j = 0; j < m; ++j) { scale_merge_fp32[j] = scale_a[j] * scale_b[0]; scale_merge_int8[j] = scale_merge_fp32[j] / scale_c[0]; @@ -178,9 +184,9 @@ bool test_gemm_int8(bool tra, n, k, has_bias, - has_relu, trb, scale_merge_fp32.data(), + act_param, &ctx); } @@ -202,9 +208,9 @@ bool test_gemm_int8(bool tra, n, k, has_bias, - has_relu, trb, scale_merge_int8.data(), + act_param, &ctx); t0.Stop(); } @@ -229,9 +235,9 @@ bool test_gemm_int8(bool tra, n, k, has_bias, - has_relu, trb, scale_merge_fp32.data(), + act_param, &ctx); t0.Stop(); }