Commit e5066279 authored by chenjiaoAngel

fix format. test=develop

Parent 3aec2316
@@ -4222,15 +4222,48 @@ void gemm_prepack_int8(const int8_t* A_packed,
   }
 #if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
   if (ctx->has_dot()) {
-    gemm_prepack_sdot_int8<float32_t>(
-        A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx);
+    gemm_prepack_sdot_int8<float32_t>(A_packed,
+                                      B,
+                                      bias,
+                                      C,
+                                      M,
+                                      N,
+                                      K,
+                                      is_bias,
+                                      flag_act,
+                                      is_transB,
+                                      scale,
+                                      alpha,
+                                      ctx);
   } else {
-    gemm_prepack_oth_int8<float32_t>(
-        A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx);
+    gemm_prepack_oth_int8<float32_t>(A_packed,
+                                     B,
+                                     bias,
+                                     C,
+                                     M,
+                                     N,
+                                     K,
+                                     is_bias,
+                                     flag_act,
+                                     is_transB,
+                                     scale,
+                                     alpha,
+                                     ctx);
   }
 #else
-  gemm_prepack_oth_int8<float32_t>(
-      A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx);
+  gemm_prepack_oth_int8<float32_t>(A_packed,
+                                   B,
+                                   bias,
+                                   C,
+                                   M,
+                                   N,
+                                   K,
+                                   is_bias,
+                                   flag_act,
+                                   is_transB,
+                                   scale,
+                                   alpha,
+                                   ctx);
 #endif
 }
@@ -4271,11 +4304,33 @@ void gemm_prepack_int8(const int8_t* A_packed,
   }
 #if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
   if (ctx->has_dot()) {
-    gemm_prepack_sdot_int8<int8_t>(
-        A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx);
+    gemm_prepack_sdot_int8<int8_t>(A_packed,
+                                   B,
+                                   bias,
+                                   C,
+                                   M,
+                                   N,
+                                   K,
+                                   is_bias,
+                                   flag_act,
+                                   is_transB,
+                                   scale,
+                                   alpha,
+                                   ctx);
   } else {
-    gemm_prepack_oth_int8<int8_t>(
-        A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx);
+    gemm_prepack_oth_int8<int8_t>(A_packed,
+                                  B,
+                                  bias,
+                                  C,
+                                  M,
+                                  N,
+                                  K,
+                                  is_bias,
+                                  flag_act,
+                                  is_transB,
+                                  scale,
+                                  alpha,
+                                  ctx);
   }
 #else
   gemm_prepack_oth_int8<int8_t>(
......
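Note: both hunks above are a pure clang-format reflow of the argument lists (one
parameter per line); the dispatch logic is untouched. On an AArch64 build with
WITH_ARM_DOTPROD, the SDOT kernel is selected at runtime via ctx->has_dot(), with
the generic kernel as the fallback; other builds compile only the generic path. A
minimal sketch of that pattern, with hypothetical kernel names standing in for the
real gemm_prepack_sdot_int8 / gemm_prepack_oth_int8 templates:

    #include <cstdio>

    // Hypothetical stand-ins for the real packed-GEMM kernels.
    void gemm_sdot_kernel() { std::puts("SDOT (ARMv8.2 dot-product) kernel"); }
    void gemm_generic_kernel() { std::puts("generic int8 kernel"); }

    struct Context {
      // In Paddle-Lite this flag is probed from the CPU at startup.
      bool has_dot() const { return has_dot_; }
      bool has_dot_ = false;
    };

    void gemm_dispatch(const Context* ctx) {
    #if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
      // SDOT code was compiled in; choose per-CPU at runtime.
      if (ctx->has_dot()) {
        gemm_sdot_kernel();
      } else {
        gemm_generic_kernel();
      }
    #else
      // SDOT path was not compiled in; only the generic kernel exists.
      (void)ctx;
      gemm_generic_kernel();
    #endif
    }

Keeping the runtime check separate from the compile-time guard lets a single
binary serve CPUs both with and without the dot-product extension.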
@@ -311,7 +311,8 @@ bool gemv_int8_oth(const int8_t* A,
       ptr_out[7] += ptr_in[i] * ptr_w7[i];
     }
-    write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 8, flag_act, act, six, alpha);
+    write_gemv_out(
+        ptr_out, out_ptr, scale_ptr, bias_ptr, 8, flag_act, act, six, alpha);
   }
   //! deal with remains
@@ -355,7 +356,8 @@ bool gemv_int8_oth(const int8_t* A,
     for (int i = 0; i < tail; ++i) {
       ptr_out[0] += ptr_in[i] * ptr_w0[i];
     }
-    write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, flag_act, act, six, alpha);
+    write_gemv_out(
+        ptr_out, out_ptr, scale_ptr, bias_ptr, 1, flag_act, act, six, alpha);
   }
 #else   // __aarch64__
   int out_cnt = M >> 2;
@@ -449,7 +451,8 @@ bool gemv_int8_oth(const int8_t* A,
       ptr_out[2] += ptr_in[i] * ptr_w2[i];
       ptr_out[3] += ptr_in[i] * ptr_w3[i];
     }
-    write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 4, flag_act, act, six, alpha);
+    write_gemv_out(
+        ptr_out, out_ptr, scale_ptr, bias_ptr, 4, flag_act, act, six, alpha);
   }
   //! deal with remains
 #pragma omp parallel for
@@ -490,7 +493,8 @@ bool gemv_int8_oth(const int8_t* A,
     for (int i = 0; i < tail; ++i) {
       ptr_out[0] += ptr_in[i] * ptr_w0[i];
     }
-    write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, flag_act, act, six, alpha);
+    write_gemv_out(
+        ptr_out, out_ptr, scale_ptr, bias_ptr, 1, flag_act, act, six, alpha);
   }
 #endif  // __aarch64__
   return true;
@@ -648,7 +652,8 @@ bool gemv_int8_sdot(const int8_t* A,
       ptr_out[6] += ptr_in[i] * ptr_w6[i];
       ptr_out[7] += ptr_in[i] * ptr_w7[i];
     }
-    write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 8, flag_act, act, six, alpha);
+    write_gemv_out(
+        ptr_out, out_ptr, scale_ptr, bias_ptr, 8, flag_act, act, six, alpha);
   }
   //! deal with remains
 #pragma omp parallel for
@@ -688,7 +693,8 @@ bool gemv_int8_sdot(const int8_t* A,
     for (int i = 0; i < tail; ++i) {
       ptr_out[0] += ptr_in[i] * ptr_w0[i];
     }
-    write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, flag_act, act, six, alpha);
+    write_gemv_out(
+        ptr_out, out_ptr, scale_ptr, bias_ptr, 1, flag_act, act, six, alpha);
   }
   return true;
 }
......
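Note: all six hunks in this file are the same mechanical change: the
write_gemv_out call is broken after the opening parenthesis so the argument list
fits the 80-column limit. For context, each call flushes one block of int32
accumulators (8, 4, or 1 rows) to the output, applying scale, bias, and the
activation selected by flag_act. A reduced sketch of the scalar tail path, using
a hypothetical simplified write_gemv_out (the real helper also takes scale_ptr,
bias_ptr, flag_act, act, six, and alpha):

    #include <cstdint>

    // Hypothetical reduced helper; the real write_gemv_out also applies
    // per-row scale, bias, and the requested activation.
    void write_gemv_out(const int32_t* in, float* out, int size) {
      for (int i = 0; i < size; ++i) {
        out[i] = static_cast<float>(in[i]);
      }
    }

    // Scalar tail of one GEMV row: out[0] = sum_i in[i] * w0[i].
    void gemv_row_tail(const int8_t* ptr_in,
                       const int8_t* ptr_w0,
                       int tail,
                       float* out_ptr) {
      int32_t ptr_out[1] = {0};
      for (int i = 0; i < tail; ++i) {
        ptr_out[0] += ptr_in[i] * ptr_w0[i];  // int8 products widen to int32
      }
      write_gemv_out(ptr_out, out_ptr, 1);
    }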
@@ -221,11 +221,11 @@ void test_conv_fp32(const std::vector<DDim>& input_dims,
                  flag_act,
                  six,
                  leakey_relu_scale);
-      if (flag_act == 2) { // relu6
+      if (flag_act == 2) {  // relu6
         for (int i = 0; i < dim_out.production(); i++) {
           dout_basic[i] = dout_basic[i] > six ? six : dout_basic[i];
         }
       }
     }
     /// warm up
     for (int i = 0; i < FLAGS_warmup; ++i) {
......
@@ -321,10 +321,12 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
                  1,
                  1,
                  dim_out.production());
-      if (flag_act == 2) { // relu6
+      if (flag_act == 2) {  // relu6
         for (int i = 0; i < dim_out.production(); i++) {
-          dout_basic_int8[i] = dout_basic_int8[i] > six ? six : dout_basic_int8[i];
-          dout_basic_fp32[i] = dout_basic_fp32[i] > six ? six : dout_basic_fp32[i];
+          dout_basic_int8[i] =
+              dout_basic_int8[i] > six ? six : dout_basic_int8[i];
+          dout_basic_fp32[i] =
+              dout_basic_fp32[i] > six ? six : dout_basic_fp32[i];
         }
       }
     }
@@ -526,7 +528,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
   for (auto& stride : {1, 2}) {
     for (auto& pad : {0, 1, 2, 3, 4}) {
       for (auto& flag_bias : {false, true}) {
-        for (auto& flag_act: {0, 1}) {
+        for (auto& flag_act : {0, 1}) {
           for (auto& c : {1, 5, 15, 33}) {
             std::vector<DDim> dims;
             DDim weights_dim({c, 1, 5, 5});
......
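Note on the relu6 post-clamp seen above (and again in the gemv and sgemv test
hunks below): the naive reference only applies the lower relu bound, with the
upper clamp commented out (see the conv_basic hunk at the end of this diff), so
the tests clamp the reference output to six themselves when flag_act == 2. A
minimal sketch of that clamp, assuming a float buffer whose lower bound was
already applied:

    #include <algorithm>
    #include <cstdint>

    // Clamp a relu'd reference buffer to the relu6 upper bound.
    void clamp_relu6(float* data, int64_t size, float six) {
      for (int64_t i = 0; i < size; ++i) {
        data[i] = std::min(data[i], six);  // lower bound 0 applied earlier
      }
    }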
@@ -95,8 +95,8 @@ bool test_gemm_int8(bool tra,
   std::vector<float> scale_merge_int8(static_cast<size_t>(m));
   ActivationParam act_param;
   act_param.has_active = has_relu;
-  if (has_relu){
+  if (has_relu) {
     act_param.active_type = (paddle::lite_api::ActivationType)1;
   }
   for (int j = 0; j < m; ++j) {
     scale_merge_fp32[j] = scale_a[j] * scale_b[0];
......
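The only change here is brace spacing. As an aside, the cast
(paddle::lite_api::ActivationType)1 selects relu; assuming the lite_api enum
defines kRelu as enumerator 1 (as in paddle_place.h), an equivalent spelling
without the magic number would be:

    // assumes paddle::lite_api::ActivationType::kRelu == 1
    act_param.active_type = paddle::lite_api::ActivationType::kRelu;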
@@ -23,7 +23,6 @@
 #include "lite/core/profile/timer.h"
 #include "lite/core/tensor.h"
 #include "lite/tests/utils/tensor_utils.h"
-#include "lite/backends/arm/math/saturate.h"
 typedef paddle::lite::Tensor Tensor;
 using paddle::lite::profile::Timer;
@@ -99,8 +98,7 @@ bool test_gemv_int8(bool tra,
   }
   LOG(INFO) << "gemv_int8 M: " << m << ", N: " << n
-            << ", transA: " << (tra ? "true" : "false")
-            << ", act: " << flag_act
+            << ", transA: " << (tra ? "true" : "false") << ", act: " << flag_act
             << ", bias: " << (has_bias ? "true" : "false");
 #ifdef LITE_WITH_ARM
   auto da = ta.mutable_data<int8_t>();
@@ -155,7 +153,7 @@ bool test_gemv_int8(bool tra,
                  1,
                  1,
                  tc_basic_fp32.numel());
-    if (flag_act == 2) { // relu6
+    if (flag_act == 2) {  // relu6
       for (int i = 0; i < tc_basic_int8.numel(); i++) {
         dc_basic_fp32[i] = dc_basic_fp32[i] > six ? six : dc_basic_fp32[i];
         dc_basic_int8[i] = dc_basic_int8[i] > six ? six : dc_basic_int8[i];
......
@@ -108,7 +108,7 @@ bool test_sgemv(bool tra,
               flag_act,
               six,
               alpha);
-    if (flag_act == 2) { // relu6
+    if (flag_act == 2) {  // relu6
       for (int i = 0; i < tc_basic.numel(); i++) {
         dc_basic[i] = dc_basic[i] > six ? six : dc_basic[i];
       }
......
@@ -301,7 +301,7 @@ static void conv_basic(const Dtype1* din,
             dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0
                                         ? dst_data_ref[out_idx]
                                         : (Dtype2)0;
-            //dst_data_ref[out_idx] = dst_data_ref[out_idx] < (Dtype2)six
+            // dst_data_ref[out_idx] = dst_data_ref[out_idx] < (Dtype2)six
             //    ? dst_data_ref[out_idx]
             //    : (Dtype2)six;
           } else if (act_type == 4) {
......
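The change here only adds a space after // to satisfy the linter, but the
commented-out lines are the relu6 upper clamp mentioned earlier: with them
disabled, this branch of conv_basic applies only the relu lower bound, which is
why the tests post-clamp their reference buffers to six. Had the clamp been kept,
the branch would read (same ternary style as the surrounding code; Dtype2 is the
output element type):

    // lower bound: relu
    dst_data_ref[out_idx] =
        dst_data_ref[out_idx] > (Dtype2)0 ? dst_data_ref[out_idx] : (Dtype2)0;
    // upper bound: relu6 (disabled in the reference implementation)
    dst_data_ref[out_idx] =
        dst_data_ref[out_idx] < (Dtype2)six ? dst_data_ref[out_idx] : (Dtype2)six;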