提交 9a273b81 编写于 作者: C chenjiaoAngel

add gemv+relu6/lleakyRelu

上级 8c1cb2af
......@@ -284,8 +284,11 @@ void conv1x1s1_gemm_int8(const int8_t* i_data,
scale_group,
flag_bias,
bias_group,
flag_relu,
ctx);
act_param.has_active,
act_param.active_type,
ctx,
act_param.Relu_clipped_coef,
act_param.Leaky_relu_alpha);
} else {
gemm_prepack_int8(weights_group,
din_group,
......@@ -526,8 +529,11 @@ void conv_im2col_gemm_int8(const int8_t* i_data,
scale_group,
flag_bias,
bias_group,
flag_relu,
ctx);
act_param.has_active,
act_param.active_type,
ctx,
act_param.Relu_clipped_coef,
act_param.Leaky_relu_alpha);
} else {
gemm_prepack_int8(weights_group,
dB,
......
......@@ -673,6 +673,7 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
#define GEMM_INT8_INT8_OUT \
GEMM_TRANS_INT32_TO_FP32 \
GEMM_INT8_RELU \
GEMM_INT8_RELU6 \
GEMM_INT8_LEAKY_RELU \
"ld1 {v8.4s}, [%[vmax]] \n" /* v8 = -127 */ \
......
......@@ -27,7 +27,10 @@ inline void write_gemv_out(const int* in,
const float* scale,
const float* bias,
int size,
bool is_relu);
bool flag_act,
lite_api::ActivationType act,
float six,
float alpha);
template <>
inline void write_gemv_out(const int* in,
......@@ -35,7 +38,10 @@ inline void write_gemv_out(const int* in,
const float* scale,
const float* bias,
int size,
bool is_relu) {
bool flag_act,
lite_api::ActivationType act,
float six,
float alpha) {
int i = 0;
float32x4_t vzero = vdupq_n_f32(0.f);
for (; i < size - 7; i += 8) {
......@@ -49,9 +55,27 @@ inline void write_gemv_out(const int* in,
float32x4_t vinf1 = vcvtq_f32_s32(vin1);
vout0 = vmlaq_f32(vout0, vinf0, vscale0);
vout1 = vmlaq_f32(vout1, vinf1, vscale1);
if (is_relu) {
if (flag_act) {
if (act == lite_api::ActivationType::kRelu) {
vout0 = vmaxq_f32(vout0, vzero);
vout1 = vmaxq_f32(vout1, vzero);
} else if (act == lite_api::ActivationType::kRelu6) {
float32x4_t vsix = vdupq_n_f32(six);
vout0 = vmaxq_f32(vout0, vzero);
vout1 = vmaxq_f32(vout1, vzero);
vout0 = vminq_f32(vout0, vsix);
vout1 = vminq_f32(vout1, vsix);
}
vout0 = vmaxq_f32(vout0, vzero);
vout1 = vmaxq_f32(vout1, vzero);
} else if (act == lite_api::ActivationType::kLeakyRelu) {
float32x4_t valpha = vdupq_n_f32(alpha);
uint32x4_t maska = vcgeq_f32(vout0, vzero);
uint32x4_t maskb = vcgeq_f32(vout1, vzero);
float32x4_t suma = vmulq_f32(vout0, valpha);
float32x4_t sumb = vmulq_f32(vout1, valpha);
vout0 = vbslq_f32(maska, vout0, suma);
vout1 = vbslq_f32(maskb, vout1, sumb);
}
vst1q_f32(out, vout0);
vst1q_f32(out + 4, vout1);
......@@ -63,7 +87,15 @@ inline void write_gemv_out(const int* in,
for (; i < size; ++i) {
out[0] = *(in++) * *(scale)++;
out[0] += bias ? *(bias++) : 0.f;
out[0] = is_relu ? (out[0] > 0.f ? out[0] : 0.f) : out[0];
if (flag_act) {
if (act == lite_api::ActivationType::kRelu) {
out[0] = out[0] > 0.f ? out[0] : 0.f;
} else if (act == lite_api::ActivationType::kRelu6) {
out[0] = out[0] > 0.f ? (out[0] > six ? six : out[0]) : 0.f;
} else if (act == lite_api::ActivationType::kLeakyRelu) {
out[0] = out[0] > 0.f ? out[0] : out[0] * alpha;
}
}
out++;
}
}
......@@ -74,14 +106,23 @@ inline void write_gemv_out(const int* in,
const float* scale,
const float* bias,
int size,
bool flag_relu) {
bool flag_act,
lite_api::ActivationType act,
float six,
float alpha) {
if (bias) {
for (int i = 0; i < size; ++i) {
out[0] =
saturate_cast<signed char>(roundf(*(in++) * *(scale++) + *(bias++)));
out[0] = out[0] < -127 ? -127 : out[0]; // -127 - 127
if (flag_relu) {
out[0] = out[0] > 0 ? out[0] : 0;
if (flag_act) {
if (act == lite_api::ActivationType::kRelu) {
out[0] = out[0] > 0.f ? out[0] : 0.f;
} else if (act == lite_api::ActivationType::kRelu6) {
out[0] = out[0] > 0.f ? (out[0] > six ? six : out[0]) : 0.f;
} else if (act == lite_api::ActivationType::kLeakyRelu) {
out[0] = out[0] > 0.f ? out[0] : out[0] * alpha;
}
}
out++;
}
......@@ -89,8 +130,14 @@ inline void write_gemv_out(const int* in,
for (int i = 0; i < size; ++i) {
out[0] = saturate_cast<signed char>(roundf(*(in++) * *(scale++)));
out[0] = out[0] < -127 ? -127 : out[0]; // -127 - 127
if (flag_relu) {
out[0] = out[0] > 0 ? out[0] : 0;
if (flag_act) {
if (act == lite_api::ActivationType::kRelu) {
out[0] = out[0] > 0.f ? out[0] : 0.f;
} else if (act == lite_api::ActivationType::kRelu6) {
out[0] = out[0] > 0.f ? (out[0] > six ? six : out[0]) : 0.f;
} else if (act == lite_api::ActivationType::kLeakyRelu) {
out[0] = out[0] > 0.f ? out[0] : out[0] * alpha;
}
}
out++;
}
......@@ -107,7 +154,10 @@ bool gemv_int8_oth(const int8_t* A,
const float* scale,
bool is_bias,
const float* bias,
bool is_relu) {
bool flag_act,
lite_api::ActivationType act,
float six,
float alpha) {
if (transA) {
LOG(ERROR) << "ERROR: sgemv, transA is not supported now";
return false;
......@@ -260,7 +310,7 @@ bool gemv_int8_oth(const int8_t* A,
ptr_out[7] += ptr_in[i] * ptr_w7[i];
}
write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 8, is_relu);
write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 8, flag_act, act, six, alpha);
}
//! deal with remains
......@@ -304,7 +354,7 @@ bool gemv_int8_oth(const int8_t* A,
for (int i = 0; i < tail; ++i) {
ptr_out[0] += ptr_in[i] * ptr_w0[i];
}
write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, is_relu);
write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, flag_act, act, six, alpha);
}
#else // __aarch64__
int out_cnt = M >> 2;
......@@ -398,7 +448,7 @@ bool gemv_int8_oth(const int8_t* A,
ptr_out[2] += ptr_in[i] * ptr_w2[i];
ptr_out[3] += ptr_in[i] * ptr_w3[i];
}
write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 4, is_relu);
write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 4, flag_act, act, six, alpha);
}
//! deal with remains
#pragma omp parallel for
......@@ -439,7 +489,7 @@ bool gemv_int8_oth(const int8_t* A,
for (int i = 0; i < tail; ++i) {
ptr_out[0] += ptr_in[i] * ptr_w0[i];
}
write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, is_relu);
write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, flag_act, act, six, alpha);
}
#endif // __aarch64__
return true;
......@@ -456,7 +506,10 @@ bool gemv_int8_sdot(const int8_t* A,
const float* scale,
bool is_bias,
const float* bias,
bool is_relu) {
bool flag_act,
lite_api::ActivationType act,
float six,
float alpha) {
if (transA) {
LOG(ERROR) << "ERROR: sgemv, transA is not supported now";
return false;
......@@ -594,7 +647,7 @@ bool gemv_int8_sdot(const int8_t* A,
ptr_out[6] += ptr_in[i] * ptr_w6[i];
ptr_out[7] += ptr_in[i] * ptr_w7[i];
}
write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 8, is_relu);
write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 8, flag_act, act, six, alpha);
}
//! deal with remains
#pragma omp parallel for
......@@ -634,7 +687,7 @@ bool gemv_int8_sdot(const int8_t* A,
for (int i = 0; i < tail; ++i) {
ptr_out[0] += ptr_in[i] * ptr_w0[i];
}
write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, is_relu);
write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, flag_act, act, six, alpha);
}
return true;
}
......@@ -650,19 +703,22 @@ bool gemv_int8<float>(const int8_t* A,
const float* scale,
bool is_bias,
const float* bias,
bool is_relu,
const ARMContext* ctx) {
bool flag_act,
lite_api::ActivationType act,
const ARMContext* ctx,
float six,
float alpha) {
#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
if (ctx->has_dot()) {
return gemv_int8_sdot<float>(
A, x, y, transA, M, N, scale, is_bias, bias, is_relu);
A, x, y, transA, M, N, scale, is_bias, bias, flag_act, six, alpha);
} else {
return gemv_int8_oth<float>(
A, x, y, transA, M, N, scale, is_bias, bias, is_relu);
A, x, y, transA, M, N, scale, is_bias, bias, flag_act, six, alpha);
}
#else
return gemv_int8_oth<float>(
A, x, y, transA, M, N, scale, is_bias, bias, is_relu);
A, x, y, transA, M, N, scale, is_bias, bias, flag_act, six, alpha);
#endif
}
......@@ -676,19 +732,22 @@ bool gemv_int8<int8_t>(const int8_t* A,
const float* scale,
bool is_bias,
const float* bias,
bool is_relu,
const ARMContext* ctx) {
bool flag_act,
lite_api::ActivationType act,
const ARMContext* ctx,
float six,
float alpha) {
#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
if (ctx->has_dot()) {
return gemv_int8_sdot<int8_t>(
A, x, y, transA, M, N, scale, is_bias, bias, is_relu);
A, x, y, transA, M, N, scale, is_bias, bias, flag_act, six, alpha);
} else {
return gemv_int8_oth<int8_t>(
A, x, y, transA, M, N, scale, is_bias, bias, is_relu);
A, x, y, transA, M, N, scale, is_bias, bias, flag_act, six, alpha);
}
#else
return gemv_int8_oth<int8_t>(
A, x, y, transA, M, N, scale, is_bias, bias, is_relu);
A, x, y, transA, M, N, scale, is_bias, bias, flag_act, six, alpha);
#endif
}
......
......@@ -32,8 +32,11 @@ bool gemv_int8(const int8_t* A,
const float* scale,
bool is_bias,
const float* bias,
bool is_relu,
const ARMContext* ctx);
bool flag_act,
lite_api::ActivationType act,
const ARMContext* ctx,
float six = 6.f,
float alpha = 1.f);
} // namespace math
} // namespace arm
......
......@@ -157,8 +157,10 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
}
bool flag_relu = false;
operators::ActivationParam act_param;
lite_api::ActivationType act;
act_param.has_active = false;
if (param.activation_type == "relu") {
act = lite_api::ActivationType::kRelu;
flag_relu = true;
}
if (flag_gemm_) {
......@@ -193,6 +195,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
param.bias != nullptr,
b_data,
flag_relu,
act,
&ctx);
}
}
......@@ -214,10 +217,12 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
bool flag_relu = false;
operators::ActivationParam act_param;
act_param.has_active = false;
lite_api::ActivationType act;
if (param.activation_type == "relu") {
flag_relu = true;
act_param.has_active = true;
act_param.active_type = lite_api::ActivationType::kRelu;
act = lite_api::ActivationType::kRelu;
}
if (flag_gemm_) {
CHECK(!param.bias) << "fc int8 kernel with int8 output using gemm kernel "
......@@ -249,6 +254,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
param.bias != nullptr,
b_data,
flag_relu,
act,
&ctx);
}
}
......
......@@ -53,12 +53,15 @@ DEFINE_int32(stride_w, 1, "stride width");
DEFINE_int32(dila_h, 1, "dilation height");
DEFINE_int32(dila_w, 1, "dilation width");
DEFINE_bool(flag_relu, true, "do relu");
DEFINE_bool(flag_act, true, "do act");
DEFINE_bool(flag_bias, true, "with bias");
DEFINE_double(clipped_coef, 1.0, "clipped relu coef");
DEFINE_double(leakey_relu_alpha, 8.88, "leakey relu alpha");
typedef paddle::lite::DDim DDim;
typedef paddle::lite::Tensor Tensor;
typedef paddle::lite::operators::ConvParam ConvParam;
typedef paddle::lite::operators::ActivationParam ActivationParam;
using paddle::lite::profile::Timer;
DDim compute_out_dim(const DDim& dim_in,
......@@ -129,9 +132,11 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
const std::vector<int>& pads,
const std::vector<int>& dilas,
bool flag_bias,
bool flag_relu,
int flag_act,
const std::vector<int>& thread_num,
const std::vector<int>& power_mode) {
const std::vector<int>& power_mode,
const float six = 6.f,
const float alpha = 1.f) {
paddle::lite::DeviceInfo::Init();
ConvParam param_int8_out;
ConvParam param_fp32_out;
......@@ -142,7 +147,7 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
pads,
dilas,
flag_bias,
flag_relu,
flag_act > 0,
&param_int8_out);
get_conv_param<PRECISION(kFloat)>(weight_dim,
......@@ -151,7 +156,7 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
pads,
dilas,
flag_bias,
flag_relu,
flag_act > 0,
&param_fp32_out);
Tensor weight_fp32;
Tensor bias_fp32;
......@@ -165,6 +170,20 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
param_fp32_out.bias->CopyDataFrom(*param_int8_out.bias);
bias_fp32.CopyDataFrom(*param_int8_out.bias);
}
if (flag_act > 0) {
ActivationParam act_param;
act_param.has_active = true;
act_param.active_type = (paddle::lite_api::ActivationType)
flag_act; // 1-relu, 2-relu6, 4-leakyrelu
if (flag_act == 1) {
param.fuse_relu = true;
} else if (flag_act == 2) {
act_param.Relu_clipped_coef = six;
} else if (flag_act == 4) {
act_param.Leaky_relu_alpha = leakey_relu_scale;
}
param.activation_param = act_param;
}
std::vector<float> scale_in{1.f / 127};
std::vector<float> scale_out{weight_dim.count(1, 4) / 127.f};
......@@ -291,7 +310,9 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
pads[2],
pads[0],
flag_bias,
static_cast<int>(flag_relu));
flag_act,
six,
alpha);
paddle::lite::arm::math::fp32_to_int8(dout_basic_fp32,
dout_basic_int8,
scale_out.data(),
......@@ -364,9 +385,8 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
<< ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", group: " << group
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false")
<< ", threads: " << th << ", power_mode: " << cls
<< " failed!!\n";
<< ", act: " << flag_act << ", threads: " << th
<< ", power_mode: " << cls << " failed!!\n";
}
}
}
......@@ -423,9 +443,8 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
<< ", stride: " << strides[0] << ", " << strides[1]
<< ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false")
<< ", threads: " << th << ", power_mode: " << cls
<< " failed!!\n";
<< ", act: " << flag_act << ", threads: " << th
<< ", power_mode: " << cls << " failed!!\n";
}
}
}
......@@ -435,9 +454,8 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
<< ", " << pads[3] << ", stride: " << strides[0] << ", "
<< strides[1] << ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false")
<< ", threads: " << th << ", power_mode: " << cls
<< " successed!!\n";
<< ", act: " << flag_act << ", threads: " << th
<< ", power_mode: " << cls << " successed!!\n";
}
}
}
......@@ -452,9 +470,11 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
const std::vector<int>& pads,
const std::vector<int>& dilas,
bool flag_bias,
bool flag_relu,
int flag_act,
const std::vector<int>& thread_num,
const std::vector<int>& power_mode) {}
const std::vector<int>& power_mode,
float six = 6.f,
float alpha = 1.f) {}
#endif // LITE_WITH_ARM
#if 1 /// 3x3dw
......@@ -463,7 +483,7 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
for (auto& stride : {1, 2}) {
for (auto& pad : {0, 1}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
for (auto& flag_act : {0, 1, 2, 4}) {
for (auto& c : {1, 3, 5, 8, 16, 32}) {
std::vector<DDim> dims;
DDim weights_dim({c, 1, 3, 3});
......@@ -479,9 +499,11 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
{pad, pad, pad, pad},
{1, 1},
flag_bias,
flag_relu,
flag_act,
{4},
{FLAGS_power_mode});
{FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
}
}
}
......@@ -497,7 +519,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
for (auto& stride : {1, 2}) {
for (auto& pad : {0, 1, 2, 3, 4}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
for (auto& flag_act: {0, 1, 2, 4}) {
for (auto& c : {1, 5, 15, 33}) {
std::vector<DDim> dims;
DDim weights_dim({c, 1, 5, 5});
......@@ -513,9 +535,11 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
{pad, pad, pad, pad},
{1, 1},
flag_bias,
flag_relu,
flag_act,
{1, 4},
{FLAGS_power_mode});
{FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
}
}
}
......@@ -532,7 +556,7 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) {
for (auto& cout : {1, 5, 17}) {
for (auto& g : {1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
for (auto& flag_act : {0, 1, 2, 4}) {
std::vector<DDim> dims;
if (cin % g != 0 || cout % g != 0) {
continue;
......@@ -550,9 +574,11 @@ TEST(TestConv1x1s1Int8, test_conv1x1s1) {
{0, 0, 0, 0},
{1, 1},
flag_bias,
flag_relu,
flag_act,
{4},
{FLAGS_power_mode});
{FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
}
}
}
......@@ -572,7 +598,7 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
for (auto& pad_left : {1, 2}) {
for (auto& pad_right : {1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
for (auto& flag_act : {0, 1, 2, 4}) {
std::vector<DDim> dims;
DDim weights_dim({cout, cin, 3, 3});
for (auto& batch : {1, 2}) {
......@@ -587,9 +613,11 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
{pad_top, pad_bottom, pad_left, pad_right},
{1, 1},
flag_bias,
flag_relu,
flag_act,
{4},
{FLAGS_power_mode});
{FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
}
}
}
......@@ -612,7 +640,7 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
for (auto& pad_left : {1, 2}) {
for (auto& pad_right : {1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
for (auto& flag_act : {0, 1, 2, 4}) {
std::vector<DDim> dims;
DDim weights_dim({cout, cin, 3, 3});
for (auto& batch : {1, 2}) {
......@@ -627,9 +655,11 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
{pad_top, pad_bottom, pad_left, pad_right},
{1, 1},
flag_bias,
flag_relu,
flag_act,
{4},
{FLAGS_power_mode});
{FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
}
}
}
......@@ -657,7 +687,7 @@ TEST(TestConvRandInt8, test_conv_rand) {
for (auto& pad_right : {0, 1, 2}) {
for (auto& dila : {1, 2}) {
for (auto& flag_bias : {false, true}) {
for (auto& flag_relu : {false, true}) {
for (auto& flag_act : {0, 1, 2, 4}) {
if (cin % g != 0 || cout % g != 0) {
break;
}
......@@ -676,9 +706,11 @@ TEST(TestConvRandInt8, test_conv_rand) {
{pad_top, pad_bottom, pad_left, pad_right},
{dila, dila},
flag_bias,
flag_relu,
flag_act,
{4},
{FLAGS_power_mode});
{FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
}
}
}
......@@ -713,8 +745,10 @@ TEST(TestConvCustomInt8, test_conv_custom_size) {
{FLAGS_pad_h, FLAGS_pad_h, FLAGS_pad_w, FLAGS_pad_w},
{FLAGS_dila_h, FLAGS_dila_w},
FLAGS_flag_bias,
FLAGS_flag_relu,
FLAGS_flag_act,
{FLAGS_threads},
{FLAGS_power_mode});
{FLAGS_power_mode},
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
}
#endif // custom
......@@ -45,11 +45,20 @@ DEFINE_int32(N, 512, "gemv: N");
DEFINE_bool(traA, false, "gemv: A transpose");
DEFINE_bool(flag_relu, false, "do relu");
DEFINE_int32(flag_act, 0, "do act");
DEFINE_bool(flag_bias, false, "with bias");
DEFINE_double(leakey_relu_alpha, 1.0, "leakey relu alpha");
DEFINE_double(clipped_coef, 6.0, "clipped relu coef");
bool test_gemv_int8(
bool tra, int m, int n, bool has_bias, bool has_relu, int cls, int ths) {
bool test_gemv_int8(bool tra,
int m,
int n,
bool has_bias,
int flag_act,
int cls,
int ths,
float six = 6.f,
float alpha = 1.f) {
Tensor ta;
Tensor tb;
Tensor tc_int8;
......@@ -90,7 +99,7 @@ bool test_gemv_int8(
LOG(INFO) << "gemv_int8 M: " << m << ", N: " << n
<< ", transA: " << (tra ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< ", act: " << flag_act
<< ", bias: " << (has_bias ? "true" : "false");
#ifdef LITE_WITH_ARM
auto da = ta.mutable_data<int8_t>();
......@@ -101,6 +110,16 @@ bool test_gemv_int8(
auto dc_basic_fp32 = tc_basic_fp32.mutable_data<float>();
auto dbias = tbias.mutable_data<float>();
paddle::lite_api::ActivationType act =
paddle::lite_api::ActivationType::kIndentity;
if (flag_act == 1) {
act = paddle::lite_api::ActivationType::kRelu;
} else if (flag_act == 2) {
act = paddle::lite_api::ActivationType::kRelu6;
} else if (flag_act == 4) {
act = paddle::lite_api::ActivationType::kLeakyRelu;
}
if (FLAGS_check_result) {
Tensor ta_fp32;
Tensor tb_fp32;
......@@ -126,7 +145,9 @@ bool test_gemv_int8(
0.f,
false,
has_bias,
has_relu);
flag_act,
six,
alpha);
paddle::lite::arm::math::fp32_to_int8(dc_basic_fp32,
dc_basic_int8,
scale_c.data(),
......@@ -152,8 +173,11 @@ bool test_gemv_int8(
scale_merge_fp32.data(),
has_bias,
dbias,
has_relu,
&ctx);
flag_act > 0,
act,
&ctx,
six,
alpha);
}
/// int8 output compute
......@@ -175,8 +199,11 @@ bool test_gemv_int8(
scale_merge_fp32.data(),
has_bias,
dbias,
has_relu,
&ctx);
flag_act > 0,
act,
&ctx,
six,
alpha);
t0.Stop();
}
LOG(INFO) << "gemv_int8_int8 output: M: " << m << ", N: " << n
......@@ -201,8 +228,11 @@ bool test_gemv_int8(
scale_merge_int8.data(),
has_bias,
dbias_int8,
has_relu,
&ctx);
flag_act > 0,
act,
&ctx,
six,
alpha);
t0.Stop();
}
LOG(INFO) << "gemm_int8_fp32 output: M: " << m << ", N: " << n
......@@ -289,20 +319,29 @@ TEST(TestLiteGemvInt8, gemv_prepacked_int8) {
for (auto& n : {1, 3, 13, 141, 512, 789}) {
for (auto& tra : {false}) {
for (auto& has_bias : {false, true}) {
for (auto& has_relu : {false, true}) {
for (auto& flag_act : {0, 1, 2, 4}) {
for (auto& th : {1, 2, 4}) {
auto flag = test_gemv_int8(
tra, m, n, has_bias, has_relu, FLAGS_power_mode, th);
float six = 6.f;
float alpha = 8.88f;
auto flag = test_gemv_int8(tra,
m,
n,
has_bias,
flag_act,
FLAGS_power_mode,
th,
six,
alpha);
if (flag) {
LOG(INFO) << "test m = " << m << ", n=" << n
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< ", act: " << flag_act
<< ", trans A: " << (tra ? "true" : "false")
<< " passed\n";
} else {
LOG(FATAL) << "test m = " << m << ", n=" << n
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< ", act: " << flag_act
<< ", trans A: " << (tra ? "true" : "false")
<< " failed\n";
}
......@@ -323,15 +362,17 @@ TEST(TestGemvInt8Custom, gemv_prepacked_int8_custom) {
FLAGS_M,
FLAGS_N,
FLAGS_flag_bias,
FLAGS_flag_relu,
FLAGS_flag_act,
FLAGS_power_mode,
FLAGS_threads);
FLAGS_threads,
FLAGS_clipped_coef,
FLAGS_leakey_relu_alpha);
if (!flag) {
LOG(FATAL) << "test m = " << FLAGS_M << ", n=" << FLAGS_N
<< ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " failed!!";
<< ", act: " << FLAGS_flag_act << " failed!!";
}
LOG(INFO) << "test m = " << FLAGS_M << ", n=" << FLAGS_N
<< ", trans A: " << FLAGS_traA << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " passed!!";
<< ", act: " << FLAGS_flag_act << " passed!!";
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册