提交 23571a48 编写于 作者: Z ZhenWang

Add more unit tests (UT) for pool_int8.

上级 85b4921c
...@@ -31,7 +31,7 @@ void Gemm::AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, ...@@ -31,7 +31,7 @@ void Gemm::AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc) { int32_t ldc) {
#if __ARM_NEON #if __ARM_NEON
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO()
#else #else
const int8_t *a_ptr, *b_ptr; const int8_t *a_ptr, *b_ptr;
a_ptr = a; a_ptr = a;
...@@ -249,7 +249,7 @@ void Gemm::AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, ...@@ -249,7 +249,7 @@ void Gemm::AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc) { int32_t ldc) {
#if __ARM_NEON #if __ARM_NEON
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO
#else #else
#define PADDLE_LABEL_LOOP "1" #define PADDLE_LABEL_LOOP "1"
#define PADDLE_LABEL_AFTER_LOOP "2" #define PADDLE_LABEL_AFTER_LOOP "2"
...@@ -376,7 +376,7 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c, ...@@ -376,7 +376,7 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc) { int32_t ldc) {
#if __ARM_NEON #if __ARM_NEON
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO
#else #else
const int8_t *a_ptr, *b_ptr; const int8_t *a_ptr, *b_ptr;
a_ptr = a; a_ptr = a;
...@@ -681,7 +681,7 @@ void Gemm::InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a, ...@@ -681,7 +681,7 @@ void Gemm::InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a,
for (int32_t j = 0; j < nc; j += NR_INT8) { for (int32_t j = 0; j < nc; j += NR_INT8) {
for (int32_t i = 0; i < mc; i += MR_INT8) { for (int32_t i = 0; i < mc; i += MR_INT8) {
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO
#else #else
// AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); // AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
// AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); // AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
...@@ -704,7 +704,7 @@ void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, float alpha, ...@@ -704,7 +704,7 @@ void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, float alpha,
for (int32_t j = 0; j < nc; j += NR_INT8) { for (int32_t j = 0; j < nc; j += NR_INT8) {
for (int32_t i = 0; i < mc; i += MR_INT8) { for (int32_t i = 0; i < mc; i += MR_INT8) {
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO
#else #else
// AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); // AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
// AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC); // AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
...@@ -742,7 +742,7 @@ void Gemm::PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, ...@@ -742,7 +742,7 @@ void Gemm::PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail,
for (int32_t j = 0; j < k_count; ++j) { for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON #if __ARM_NEON
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO
#else #else
asm volatile( asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t" "vld1.s8 {d0, d1}, [%[a0]]! \n\t"
...@@ -822,7 +822,7 @@ void Gemm::PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail, ...@@ -822,7 +822,7 @@ void Gemm::PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail,
for (int32_t j = 0; j < k_count; ++j) { for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON #if __ARM_NEON
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO
#else #else
asm volatile( asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t" "vld1.s8 {d0, d1}, [%[a0]]! \n\t"
...@@ -1058,7 +1058,7 @@ void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B, ...@@ -1058,7 +1058,7 @@ void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
const int8_t *b0 = &B(i, j); const int8_t *b0 = &B(i, j);
#if __ARM_NEON #if __ARM_NEON
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO
#else #else
asm volatile( asm volatile(
// "pld [%[b0]] \n\t" // "pld [%[b0]] \n\t"
...@@ -1100,7 +1100,7 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C, ...@@ -1100,7 +1100,7 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc) { int32_t ldc) {
#if __ARM_NEON #if __ARM_NEON
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO
#else #else
int32_t nc1 = nc >> 4; int32_t nc1 = nc >> 4;
int32_t _nc1 = nc & 15; int32_t _nc1 = nc & 15;
...@@ -1164,7 +1164,7 @@ void Gemm::WriteWithAddScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C, ...@@ -1164,7 +1164,7 @@ void Gemm::WriteWithAddScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
int32_t ldc, int32_t *bias, float scale) { int32_t ldc, int32_t *bias, float scale) {
#if __ARM_NEON #if __ARM_NEON
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO
#else #else
int32_t zero = 0; int32_t zero = 0;
int8_t narrow = -128; int8_t narrow = -128;
...@@ -1292,7 +1292,7 @@ void Gemm::WriteWithAddReluScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C, ...@@ -1292,7 +1292,7 @@ void Gemm::WriteWithAddReluScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
int32_t ldc, int32_t *bias, float scale) { int32_t ldc, int32_t *bias, float scale) {
#if __ARM_NEON #if __ARM_NEON
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO
#else #else
int32_t zero = 0; int32_t zero = 0;
int32_t nc1 = nc >> 3; int32_t nc1 = nc >> 3;
......
...@@ -60,7 +60,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, ...@@ -60,7 +60,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
packedB_int8 = static_cast<int8_t *>( packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC)); paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO
#else #else
PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8); PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8);
#endif #endif
...@@ -82,7 +82,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, ...@@ -82,7 +82,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
packedA_int8 = static_cast<int8_t *>( packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC)); paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC));
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO
#else #else
PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8); PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8);
#endif #endif
...@@ -106,7 +106,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, ...@@ -106,7 +106,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
int8_t *local_A = packedA_int8 + MC * KC * local_threads; int8_t *local_A = packedA_int8 + MC * KC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads; int32_t *local_C = packedC_int32 + MC * NC * local_threads;
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO
#else #else
PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A); PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A);
#endif #endif
...@@ -131,7 +131,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha, ...@@ -131,7 +131,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
int8_t *local_B = packedB_int8 + KC * NC * local_threads; int8_t *local_B = packedB_int8 + KC * NC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads; int32_t *local_C = packedC_int32 + MC * NC * local_threads;
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO
#else #else
PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B); PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B);
#endif #endif
...@@ -161,7 +161,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail, ...@@ -161,7 +161,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
const int8_t *b0 = &B(i, j); const int8_t *b0 = &B(i, j);
#if __ARM_NEON #if __ARM_NEON
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO
#else #else
asm volatile( asm volatile(
// "pld [%[b0]] \n\t" // "pld [%[b0]] \n\t"
...@@ -257,7 +257,7 @@ void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail, ...@@ -257,7 +257,7 @@ void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail,
for (int32_t j = 0; j < k_count; ++j) { for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON #if __ARM_NEON
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO
#else #else
asm volatile( asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t" "vld1.s8 {d0, d1}, [%[a0]]! \n\t"
...@@ -337,7 +337,7 @@ void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail, ...@@ -337,7 +337,7 @@ void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail,
for (int32_t j = 0; j < k_count; ++j) { for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON #if __ARM_NEON
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO
#else #else
asm volatile( asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t" "vld1.s8 {d0, d1}, [%[a0]]! \n\t"
......
...@@ -153,7 +153,7 @@ void Pool3x3Maxs1_int8(const Tensor *input, Tensor *output, int32_t pad_h, ...@@ -153,7 +153,7 @@ void Pool3x3Maxs1_int8(const Tensor *input, Tensor *output, int32_t pad_h,
int32_t nw1 = left_w >> 3; int32_t nw1 = left_w >> 3;
int32_t left_w1 = left_w & 0x7; int32_t left_w1 = left_w & 0x7;
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO
#else #else
if (nw > 0) { if (nw > 0) {
#define LOOP_LABEL "1" #define LOOP_LABEL "1"
...@@ -334,7 +334,7 @@ void Pool3x3Maxs2_int8(const Tensor *input, Tensor *output, int32_t pad_h, ...@@ -334,7 +334,7 @@ void Pool3x3Maxs2_int8(const Tensor *input, Tensor *output, int32_t pad_h,
int32_t nw1 = left_w >> 3; int32_t nw1 = left_w >> 3;
int32_t left_w1 = left_w & 0x7; int32_t left_w1 = left_w & 0x7;
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO
#else #else
if (nw > 0) { if (nw > 0) {
#define LOOP_LABEL "1" #define LOOP_LABEL "1"
...@@ -527,7 +527,7 @@ void Pool3x3Max_int8(const vector<int> &strides, const vector<int> &paddings, ...@@ -527,7 +527,7 @@ void Pool3x3Max_int8(const vector<int> &strides, const vector<int> &paddings,
} else { } else {
#if __ARM_NEON #if __ARM_NEON
#if __aarch64__ #if __aarch64__
// TODO(wzzju) // TODO
#else #else
asm volatile( asm volatile(
"vld1.8 {d0}, [%[pos1]] \n\t" "vld1.8 {d0}, [%[pos1]] \n\t"
......
...@@ -244,58 +244,49 @@ int main(int argc, char *argv[]) { ...@@ -244,58 +244,49 @@ int main(int argc, char *argv[]) {
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=2, stride=2"; << "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=2, stride=2";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 2, 2>(in_channels, in_height, paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 2, 2>(in_channels, in_height,
in_width); in_width);
// // kernel = 3, pad = 3, stride = 3 // kernel = 3, pad = 3, stride = 3
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=3, << "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=3, stride=3";
// stride=3"; paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 3, 3>(in_channels, in_height,
// paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 3, 3>(in_channels, in_height, in_width);
// in_width); // kernel = 7, pad = 0, stride = 1
// // kernel = 7, pad = 0, stride = 1 LOG(paddle_mobile::kLOG_INFO)
// LOG(paddle_mobile::kLOG_INFO) << "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=1";
// << "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 1>(in_channels, in_height,
// stride=1"; in_width);
// paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 1>(in_channels, in_height, // kernel = 7, pad = 0, stride = 2
// in_width); LOG(paddle_mobile::kLOG_INFO)
// // kernel = 7, pad = 0, stride = 2 << "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=2";
// LOG(paddle_mobile::kLOG_INFO) paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 2>(in_channels, in_height,
// << "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, in_width);
// stride=2"; // kernel = 7, pad = 0, stride = 3
// paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 2>(in_channels, in_height, LOG(paddle_mobile::kLOG_INFO)
// in_width); << "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=3";
// // kernel = 7, pad = 0, stride = 3 paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 3>(in_channels, in_height,
// LOG(paddle_mobile::kLOG_INFO) in_width);
// << "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, // kernel = 3, pad = 0, stride = 1
// stride=3"; LOG(paddle_mobile::kLOG_INFO)
// paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 3>(in_channels, in_height, << "int8_t, ceil_mode=false, pooling_type=avg, kernel=3, pad=0, stride=1";
// in_width); paddle_mobile::TestPoolOp<int8_t, 0, 1, 3, 0, 1>(in_channels, in_height,
// // kernel = 3, pad = 0, stride = 1 in_width);
// LOG(paddle_mobile::kLOG_INFO) // kernel = 3, pad = 0, stride = 3
// << "int8_t, ceil_mode=false, pooling_type=avg, kernel=3, pad=0, LOG(paddle_mobile::kLOG_INFO)
// stride=1"; << "int8_t, ceil_mode=false, pooling_type=avg, kernel=3, pad=0, stride=3";
// paddle_mobile::TestPoolOp<int8_t, 0, 1, 3, 0, 1>(in_channels, in_height, paddle_mobile::TestPoolOp<int8_t, 0, 1, 3, 0, 3>(in_channels, in_height,
// in_width); in_width);
// // kernel = 3, pad = 0, stride = 3 // kernel = 7, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO) LOG(paddle_mobile::kLOG_INFO)
// << "int8_t, ceil_mode=false, pooling_type=avg, kernel=3, pad=0, << "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=1";
// stride=3"; paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 1>(in_channels, in_height,
// paddle_mobile::TestPoolOp<int8_t, 0, 1, 3, 0, 3>(in_channels, in_height, in_width);
// in_width); // kernel = 7, pad = 0, stride = 4
// // kernel = 7, pad = 0, stride = 1 LOG(paddle_mobile::kLOG_INFO)
// LOG(paddle_mobile::kLOG_INFO) << "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=4";
// << "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 4>(in_channels, in_height,
// stride=1"; in_width);
// paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 1>(in_channels, in_height, // kernel = 5, pad = 0, stride = 1
// in_width); LOG(paddle_mobile::kLOG_INFO)
// // kernel = 7, pad = 0, stride = 4 << "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0, stride=1";
// LOG(paddle_mobile::kLOG_INFO) paddle_mobile::TestPoolOp<float, 0, 1, 5, 0, 1>(in_channels, in_height,
// << "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, in_width);
// stride=4";
// paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 4>(in_channels, in_height,
// in_width);
// // kernel = 5, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO)
// << "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0,
// stride=1";
// paddle_mobile::TestPoolOp<float, 0, 1, 5, 0, 1>(in_channels, in_height,
// in_width);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册