Commit e5066279 authored by chenjiaoAngel

fix format. test=develop

Parent 3aec2316
@@ -4222,15 +4222,48 @@ void gemm_prepack_int8(const int8_t* A_packed,
}
#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
if (ctx->has_dot()) {
-    gemm_prepack_sdot_int8<float32_t>(
-        A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx);
+    gemm_prepack_sdot_int8<float32_t>(A_packed,
+                                      B,
+                                      bias,
+                                      C,
+                                      M,
+                                      N,
+                                      K,
+                                      is_bias,
+                                      flag_act,
+                                      is_transB,
+                                      scale,
+                                      alpha,
+                                      ctx);
} else {
-    gemm_prepack_oth_int8<float32_t>(
-        A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx);
+    gemm_prepack_oth_int8<float32_t>(A_packed,
+                                     B,
+                                     bias,
+                                     C,
+                                     M,
+                                     N,
+                                     K,
+                                     is_bias,
+                                     flag_act,
+                                     is_transB,
+                                     scale,
+                                     alpha,
+                                     ctx);
}
#else
-  gemm_prepack_oth_int8<float32_t>(
-      A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx);
+  gemm_prepack_oth_int8<float32_t>(A_packed,
+                                   B,
+                                   bias,
+                                   C,
+                                   M,
+                                   N,
+                                   K,
+                                   is_bias,
+                                   flag_act,
+                                   is_transB,
+                                   scale,
+                                   alpha,
+                                   ctx);
#endif
}
@@ -4271,11 +4304,33 @@ void gemm_prepack_int8(const int8_t* A_packed,
}
#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
if (ctx->has_dot()) {
-    gemm_prepack_sdot_int8<int8_t>(
-        A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx);
+    gemm_prepack_sdot_int8<int8_t>(A_packed,
+                                   B,
+                                   bias,
+                                   C,
+                                   M,
+                                   N,
+                                   K,
+                                   is_bias,
+                                   flag_act,
+                                   is_transB,
+                                   scale,
+                                   alpha,
+                                   ctx);
} else {
-    gemm_prepack_oth_int8<int8_t>(
-        A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx);
+    gemm_prepack_oth_int8<int8_t>(A_packed,
+                                  B,
+                                  bias,
+                                  C,
+                                  M,
+                                  N,
+                                  K,
+                                  is_bias,
+                                  flag_act,
+                                  is_transB,
+                                  scale,
+                                  alpha,
+                                  ctx);
}
#else
gemm_prepack_oth_int8<int8_t>(
......
@@ -311,7 +311,8 @@ bool gemv_int8_oth(const int8_t* A,
ptr_out[7] += ptr_in[i] * ptr_w7[i];
}
-      write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 8, flag_act, act, six, alpha);
+      write_gemv_out(
+          ptr_out, out_ptr, scale_ptr, bias_ptr, 8, flag_act, act, six, alpha);
}
//! deal with remains
@@ -355,7 +356,8 @@ bool gemv_int8_oth(const int8_t* A,
for (int i = 0; i < tail; ++i) {
ptr_out[0] += ptr_in[i] * ptr_w0[i];
}
-      write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, flag_act, act, six, alpha);
+      write_gemv_out(
+          ptr_out, out_ptr, scale_ptr, bias_ptr, 1, flag_act, act, six, alpha);
}
#else // __aarch64__
int out_cnt = M >> 2;
@@ -449,7 +451,8 @@ bool gemv_int8_oth(const int8_t* A,
ptr_out[2] += ptr_in[i] * ptr_w2[i];
ptr_out[3] += ptr_in[i] * ptr_w3[i];
}
-      write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 4, flag_act, act, six, alpha);
+      write_gemv_out(
+          ptr_out, out_ptr, scale_ptr, bias_ptr, 4, flag_act, act, six, alpha);
}
//! deal with remains
#pragma omp parallel for
@@ -490,7 +493,8 @@ bool gemv_int8_oth(const int8_t* A,
for (int i = 0; i < tail; ++i) {
ptr_out[0] += ptr_in[i] * ptr_w0[i];
}
-      write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, flag_act, act, six, alpha);
+      write_gemv_out(
+          ptr_out, out_ptr, scale_ptr, bias_ptr, 1, flag_act, act, six, alpha);
}
#endif // __aarch64__
return true;
@@ -648,7 +652,8 @@ bool gemv_int8_sdot(const int8_t* A,
ptr_out[6] += ptr_in[i] * ptr_w6[i];
ptr_out[7] += ptr_in[i] * ptr_w7[i];
}
-      write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 8, flag_act, act, six, alpha);
+      write_gemv_out(
+          ptr_out, out_ptr, scale_ptr, bias_ptr, 8, flag_act, act, six, alpha);
}
//! deal with remains
#pragma omp parallel for
@@ -688,7 +693,8 @@ bool gemv_int8_sdot(const int8_t* A,
for (int i = 0; i < tail; ++i) {
ptr_out[0] += ptr_in[i] * ptr_w0[i];
}
-      write_gemv_out(ptr_out, out_ptr, scale_ptr, bias_ptr, 1, flag_act, act, six, alpha);
+      write_gemv_out(
+          ptr_out, out_ptr, scale_ptr, bias_ptr, 1, flag_act, act, six, alpha);
}
return true;
}
......
@@ -323,8 +323,10 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
dim_out.production());
if (flag_act == 2) { // relu6
for (int i = 0; i < dim_out.production(); i++) {
-        dout_basic_int8[i] = dout_basic_int8[i] > six ? six : dout_basic_int8[i];
-        dout_basic_fp32[i] = dout_basic_fp32[i] > six ? six : dout_basic_fp32[i];
+        dout_basic_int8[i] =
+            dout_basic_int8[i] > six ? six : dout_basic_int8[i];
+        dout_basic_fp32[i] =
+            dout_basic_fp32[i] > six ? six : dout_basic_fp32[i];
}
}
}
@@ -526,7 +528,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
for (auto& stride : {1, 2}) {
for (auto& pad : {0, 1, 2, 3, 4}) {
for (auto& flag_bias : {false, true}) {
-        for (auto& flag_act: {0, 1}) {
+        for (auto& flag_act : {0, 1}) {
for (auto& c : {1, 5, 15, 33}) {
std::vector<DDim> dims;
DDim weights_dim({c, 1, 5, 5});
......
@@ -95,7 +95,7 @@ bool test_gemm_int8(bool tra,
std::vector<float> scale_merge_int8(static_cast<size_t>(m));
ActivationParam act_param;
act_param.has_active = has_relu;
-  if (has_relu){
+  if (has_relu) {
act_param.active_type = (paddle::lite_api::ActivationType)1;
}
for (int j = 0; j < m; ++j) {
......
@@ -23,7 +23,6 @@
#include "lite/core/profile/timer.h"
#include "lite/core/tensor.h"
#include "lite/tests/utils/tensor_utils.h"
#include "lite/backends/arm/math/saturate.h"
typedef paddle::lite::Tensor Tensor;
using paddle::lite::profile::Timer;
@@ -99,8 +98,7 @@ bool test_gemv_int8(bool tra,
}
LOG(INFO) << "gemv_int8 M: " << m << ", N: " << n
<< ", transA: " << (tra ? "true" : "false")
<< ", act: " << flag_act
<< ", transA: " << (tra ? "true" : "false") << ", act: " << flag_act
<< ", bias: " << (has_bias ? "true" : "false");
#ifdef LITE_WITH_ARM
auto da = ta.mutable_data<int8_t>();
......
@@ -301,7 +301,7 @@ static void conv_basic(const Dtype1* din,
dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0
? dst_data_ref[out_idx]
: (Dtype2)0;
-            //dst_data_ref[out_idx] = dst_data_ref[out_idx] < (Dtype2)six
+            // dst_data_ref[out_idx] = dst_data_ref[out_idx] < (Dtype2)six
// ? dst_data_ref[out_idx]
// : (Dtype2)six;
} else if (act_type == 4) {
......