Commit fad46f53 authored by chenjiaoAngel

fix ut conv+leakyRelu

Parent 09820843
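This commit replaces the boolean `flag_relu` argument threaded through the int8 fc/gemm kernels with an `operators::ActivationParam`, so a fused activation can be described by type (and, for leaky ReLU, a slope) rather than a single on/off flag. A minimal sketch of the new parameter follows; `has_active`, `active_type`, and `kRelu` appear in this diff, while `kLeakyRelu` and the `Leaky_relu_alpha` field name are assumptions based on Paddle-Lite's `op_params.h` and may differ:

    operators::ActivationParam act_param;
    act_param.has_active = true;
    // kLeakyRelu and Leaky_relu_alpha are assumed names, not from this diff.
    act_param.active_type = lite_api::ActivationType::kLeakyRelu;
    act_param.Leaky_relu_alpha = 0.01f;  // hypothetical negative-slope value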
@@ -156,6 +156,8 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
     b_data = bias_.data<float>();
   }
   bool flag_relu = false;
+  operators::ActivationParam act_param;
+  act_param.has_active = false;
   if (param.activation_type == "relu") {
     flag_relu = true;
   }
@@ -170,8 +172,8 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
         o_data,
         nullptr,
         false,
-        false,
         scale_.data(),
+        act_param,
         &ctx);
   if (param.bias) {
     CHECK_EQ(param.bias->numel(), n_);
@@ -210,8 +212,12 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
     b_data = bias_.data<float>();
   }
   bool flag_relu = false;
+  operators::ActivationParam act_param;
+  act_param.has_active = false;
   if (param.activation_type == "relu") {
     flag_relu = true;
+    act_param.has_active = true;
+    act_param.active_type = lite_api::ActivationType::kRelu;
   }
   if (flag_gemm_) {
     CHECK(!param.bias) << "fc int8 kernel with int8 output using gemm kernel "
@@ -226,8 +232,8 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
         o_data,
         nullptr,
         false,
-        flag_relu,
         scale_.data(),
+        act_param,
         &ctx);
   } else {
     for (int i = 0; i < m_; ++i) {
...
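Both call sites above now pass `act_param` where a boolean (`false` / `flag_relu`) used to go. Conceptually, the kernel epilogue consumes the parameter roughly as below. This is a self-contained stand-in sketch, not Paddle-Lite's actual implementation; the enum values and the `Leaky_relu_alpha` field name are assumptions, while `has_active`, `active_type`, and `kRelu` do appear in this diff:

    #include <algorithm>

    // Stand-ins for the Paddle-Lite types used in this diff.
    enum class ActivationType { kIndentity = 0, kRelu = 1, kLeakyRelu = 4 };
    struct ActivationParam {
      bool has_active{false};
      ActivationType active_type{ActivationType::kIndentity};
      float Leaky_relu_alpha{0.f};  // assumed field name
    };

    // Roughly what a gemm/gemv epilogue does with act_param after the matmul.
    void apply_activation(float* out, int size, const ActivationParam& act) {
      if (!act.has_active) return;
      for (int i = 0; i < size; ++i) {
        if (act.active_type == ActivationType::kRelu) {
          out[i] = std::max(out[i], 0.f);
        } else if (act.active_type == ActivationType::kLeakyRelu) {
          out[i] = out[i] >= 0.f ? out[i] : out[i] * act.Leaky_relu_alpha;
        }
      }
    }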
@@ -26,6 +26,7 @@
 typedef paddle::lite::Tensor Tensor;
 using paddle::lite::profile::Timer;
+typedef paddle::lite::operators::ActivationParam ActivationParam;
 DEFINE_int32(power_mode,
              3,
@@ -92,6 +93,11 @@ bool test_gemm_int8(bool tra,
   std::vector<float> scale_c = {k / 127.f};
   std::vector<float> scale_merge_fp32(static_cast<size_t>(m));
   std::vector<float> scale_merge_int8(static_cast<size_t>(m));
+  ActivationParam act_param;
+  act_param.has_active = has_relu;
+  if (has_relu) {
+    act_param.active_type = paddle::lite_api::ActivationType::kRelu;
+  }
   for (int j = 0; j < m; ++j) {
     scale_merge_fp32[j] = scale_a[j] * scale_b[0];
     scale_merge_int8[j] = scale_merge_fp32[j] / scale_c[0];
@@ -178,9 +184,9 @@ bool test_gemm_int8(bool tra,
           n,
           k,
           has_bias,
-          has_relu,
           trb,
           scale_merge_fp32.data(),
+          act_param,
           &ctx);
   }
@@ -202,9 +208,9 @@ bool test_gemm_int8(bool tra,
           n,
           k,
           has_bias,
-          has_relu,
           trb,
           scale_merge_int8.data(),
+          act_param,
           &ctx);
     t0.Stop();
   }
@@ -229,9 +235,9 @@ bool test_gemm_int8(bool tra,
           n,
           k,
           has_bias,
-          has_relu,
           trb,
           scale_merge_fp32.data(),
+          act_param,
           &ctx);
     t0.Stop();
   }
...
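With the activation carried in `act_param`, the unit test's reference path needs a matching naive activation to compare against, which is what the "fix ut conv+leakyRelu" title points at. A hedged sketch of such a reference check; the function name, signature, and tolerance are hypothetical, not from this commit:

    #include <cmath>

    // Hypothetical reference check: apply leaky ReLU to a naive result
    // (basic) and compare against the kernel output element-wise.
    bool check_leaky_relu(const float* basic, const float* out, int size,
                          float alpha, float tol = 1e-4f) {
      for (int i = 0; i < size; ++i) {
        float expect = basic[i] >= 0.f ? basic[i] : basic[i] * alpha;
        if (std::fabs(expect - out[i]) > tol) return false;
      }
      return true;
    }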