提交 fad46f53 编写于 作者: C chenjiaoAngel

fix ut conv+leakyRelu

上级 09820843
......@@ -156,6 +156,8 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
b_data = bias_.data<float>();
}
bool flag_relu = false;
operators::ActivationParam act_param;
act_param.has_active = false;
if (param.activation_type == "relu") {
flag_relu = true;
}
......@@ -170,8 +172,8 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
o_data,
nullptr,
false,
false,
scale_.data(),
act_param,
&ctx);
if (param.bias) {
CHECK_EQ(param.bias->numel(), n_);
......@@ -210,8 +212,12 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
b_data = bias_.data<float>();
}
bool flag_relu = false;
operators::ActivationParam act_param;
act_param.has_active = false;
if (param.activation_type == "relu") {
flag_relu = true;
act_param.has_active = true;
act_param.active_type = lite_api::ActivationType::kRelu;
}
if (flag_gemm_) {
CHECK(!param.bias) << "fc int8 kernel with int8 output using gemm kernel "
......@@ -226,8 +232,8 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
o_data,
nullptr,
false,
flag_relu,
scale_.data(),
act_param,
&ctx);
} else {
for (int i = 0; i < m_; ++i) {
......
......@@ -26,6 +26,7 @@
typedef paddle::lite::Tensor Tensor;
using paddle::lite::profile::Timer;
typedef paddle::lite::operators::ActivationParam ActivationParam;
DEFINE_int32(power_mode,
3,
......@@ -92,6 +93,11 @@ bool test_gemm_int8(bool tra,
std::vector<float> scale_c = {k / 127.f};
std::vector<float> scale_merge_fp32(static_cast<size_t>(m));
std::vector<float> scale_merge_int8(static_cast<size_t>(m));
ActivationParam act_param;
act_param.has_active = has_relu;
if (has_relu){
act_param.active_type = (paddle::lite_api::ActivationType)1;
}
for (int j = 0; j < m; ++j) {
scale_merge_fp32[j] = scale_a[j] * scale_b[0];
scale_merge_int8[j] = scale_merge_fp32[j] / scale_c[0];
......@@ -178,9 +184,9 @@ bool test_gemm_int8(bool tra,
n,
k,
has_bias,
has_relu,
trb,
scale_merge_fp32.data(),
act_param,
&ctx);
}
......@@ -202,9 +208,9 @@ bool test_gemm_int8(bool tra,
n,
k,
has_bias,
has_relu,
trb,
scale_merge_int8.data(),
act_param,
&ctx);
t0.Stop();
}
......@@ -229,9 +235,9 @@ bool test_gemm_int8(bool tra,
n,
k,
has_bias,
has_relu,
trb,
scale_merge_fp32.data(),
act_param,
&ctx);
t0.Stop();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册