提交 23571a48 编写于 作者: Z ZhenWang

add more UT for pool_int8.

上级 85b4921c
......@@ -31,7 +31,7 @@ void Gemm::AddDot4x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
// TODO()
#else
const int8_t *a_ptr, *b_ptr;
a_ptr = a;
......@@ -249,7 +249,7 @@ void Gemm::AddDot4x2(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
// TODO
#else
#define PADDLE_LABEL_LOOP "1"
#define PADDLE_LABEL_AFTER_LOOP "2"
......@@ -376,7 +376,7 @@ void Gemm::AddDot6x8(int32_t k, const int8_t *a, const int8_t *b, int32_t *c,
int32_t ldc) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
// TODO
#else
const int8_t *a_ptr, *b_ptr;
a_ptr = a;
......@@ -681,7 +681,7 @@ void Gemm::InnerKernel(int32_t mc, int32_t nc, float alpha, const int8_t *a,
for (int32_t j = 0; j < nc; j += NR_INT8) {
for (int32_t i = 0; i < mc; i += MR_INT8) {
#if __aarch64__
// TODO(wzzju)
// TODO
#else
// AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
// AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
......@@ -704,7 +704,7 @@ void Gemm::InnerKernelWithBias(int32_t mc, int32_t nc, float alpha,
for (int32_t j = 0; j < nc; j += NR_INT8) {
for (int32_t i = 0; i < mc; i += MR_INT8) {
#if __aarch64__
// TODO(wzzju)
// TODO
#else
// AddDot6x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
// AddDot4x8(KC, a + i * KC, b + j * KC, c + i * NC + j, NC);
......@@ -742,7 +742,7 @@ void Gemm::PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail,
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
// TODO
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
......@@ -822,7 +822,7 @@ void Gemm::PackMatrixA_4r_16(int32_t m, int32_t k, int32_t m_tail,
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
// TODO
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
......@@ -1058,7 +1058,7 @@ void Gemm::PackMatrixB_8c(int32_t k, int32_t n, int32_t n_tail, const int8_t *B,
const int8_t *b0 = &B(i, j);
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
// TODO
#else
asm volatile(
// "pld [%[b0]] \n\t"
......@@ -1100,7 +1100,7 @@ void Gemm::WriteBasic(int32_t mc, int32_t nc, int32_t *c, int32_t *C,
int32_t ldc) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
// TODO
#else
int32_t nc1 = nc >> 4;
int32_t _nc1 = nc & 15;
......@@ -1164,7 +1164,7 @@ void Gemm::WriteWithAddScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
int32_t ldc, int32_t *bias, float scale) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
// TODO
#else
int32_t zero = 0;
int8_t narrow = -128;
......@@ -1292,7 +1292,7 @@ void Gemm::WriteWithAddReluScale(int32_t mc, int32_t nc, int32_t *c, int8_t *C,
int32_t ldc, int32_t *bias, float scale) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
// TODO
#else
int32_t zero = 0;
int32_t nc1 = nc >> 3;
......
......@@ -60,7 +60,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
packedB_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * KC * NC));
#if __aarch64__
// TODO(wzzju)
// TODO
#else
PackMatrixB_omp_2c_16(k, n, n % NR_INT8, B, ldb, packedB_int8);
#endif
......@@ -82,7 +82,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
packedA_int8 = static_cast<int8_t *>(
paddle_mobile::memory::Alloc(sizeof(int8_t) * MC * KC));
#if __aarch64__
// TODO(wzzju)
// TODO
#else
PackMatrixA_omp_4r_16(m, k, m % MR_INT8, A, lda, packedA_int8);
#endif
......@@ -106,7 +106,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
int8_t *local_A = packedA_int8 + MC * KC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads;
#if __aarch64__
// TODO(wzzju)
// TODO
#else
PackMatrixA_4r_16(mc, k, mc % MR_INT8, &A(i, 0), lda, local_A);
#endif
......@@ -131,7 +131,7 @@ void Gemm::Sgemm_omp(int32_t m, int32_t n, int32_t k, float alpha,
int8_t *local_B = packedB_int8 + KC * NC * local_threads;
int32_t *local_C = packedC_int32 + MC * NC * local_threads;
#if __aarch64__
// TODO(wzzju)
// TODO
#else
PackMatrixB_2c_16(k, nc, nc % NR_INT8, &B(0, j), ldb, local_B);
#endif
......@@ -161,7 +161,7 @@ void Gemm::PackMatrixB_omp_8c(int32_t k, int32_t n, int32_t n_tail,
const int8_t *b0 = &B(i, j);
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
// TODO
#else
asm volatile(
// "pld [%[b0]] \n\t"
......@@ -257,7 +257,7 @@ void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail,
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
// TODO
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
......@@ -337,7 +337,7 @@ void Gemm::PackMatrixA_omp_4r_16(int32_t m, int32_t k, int32_t m_tail,
for (int32_t j = 0; j < k_count; ++j) {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
// TODO
#else
asm volatile(
"vld1.s8 {d0, d1}, [%[a0]]! \n\t"
......
......@@ -153,7 +153,7 @@ void Pool3x3Maxs1_int8(const Tensor *input, Tensor *output, int32_t pad_h,
int32_t nw1 = left_w >> 3;
int32_t left_w1 = left_w & 0x7;
#if __aarch64__
// TODO(wzzju)
// TODO
#else
if (nw > 0) {
#define LOOP_LABEL "1"
......@@ -334,7 +334,7 @@ void Pool3x3Maxs2_int8(const Tensor *input, Tensor *output, int32_t pad_h,
int32_t nw1 = left_w >> 3;
int32_t left_w1 = left_w & 0x7;
#if __aarch64__
// TODO(wzzju)
// TODO
#else
if (nw > 0) {
#define LOOP_LABEL "1"
......@@ -527,7 +527,7 @@ void Pool3x3Max_int8(const vector<int> &strides, const vector<int> &paddings,
} else {
#if __ARM_NEON
#if __aarch64__
// TODO(wzzju)
// TODO
#else
asm volatile(
"vld1.8 {d0}, [%[pos1]] \n\t"
......
......@@ -244,58 +244,49 @@ int main(int argc, char *argv[]) {
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=2, stride=2";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 2, 2>(in_channels, in_height,
in_width);
// // kernel = 3, pad = 3, stride = 3
// LOG(paddle_mobile::kLOG_INFO)
// << "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=3,
// stride=3";
// paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 3, 3>(in_channels, in_height,
// in_width);
// // kernel = 7, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO)
// << "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0,
// stride=1";
// paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 1>(in_channels, in_height,
// in_width);
// // kernel = 7, pad = 0, stride = 2
// LOG(paddle_mobile::kLOG_INFO)
// << "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0,
// stride=2";
// paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 2>(in_channels, in_height,
// in_width);
// // kernel = 7, pad = 0, stride = 3
// LOG(paddle_mobile::kLOG_INFO)
// << "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0,
// stride=3";
// paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 3>(in_channels, in_height,
// in_width);
// // kernel = 3, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO)
// << "int8_t, ceil_mode=false, pooling_type=avg, kernel=3, pad=0,
// stride=1";
// paddle_mobile::TestPoolOp<int8_t, 0, 1, 3, 0, 1>(in_channels, in_height,
// in_width);
// // kernel = 3, pad = 0, stride = 3
// LOG(paddle_mobile::kLOG_INFO)
// << "int8_t, ceil_mode=false, pooling_type=avg, kernel=3, pad=0,
// stride=3";
// paddle_mobile::TestPoolOp<int8_t, 0, 1, 3, 0, 3>(in_channels, in_height,
// in_width);
// // kernel = 7, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO)
// << "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0,
// stride=1";
// paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 1>(in_channels, in_height,
// in_width);
// // kernel = 7, pad = 0, stride = 4
// LOG(paddle_mobile::kLOG_INFO)
// << "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0,
// stride=4";
// paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 4>(in_channels, in_height,
// in_width);
// // kernel = 5, pad = 0, stride = 1
// LOG(paddle_mobile::kLOG_INFO)
// << "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0,
// stride=1";
// paddle_mobile::TestPoolOp<float, 0, 1, 5, 0, 1>(in_channels, in_height,
// in_width);
// kernel = 3, pad = 3, stride = 3
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=max, kernel=3, pad=3, stride=3";
paddle_mobile::TestPoolOp<int8_t, 0, 0, 3, 3, 3>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 1>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 2
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=2";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 2>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 3
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=3";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 7, 0, 3>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=3, pad=0, stride=1";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 3, 0, 1>(in_channels, in_height,
in_width);
// kernel = 3, pad = 0, stride = 3
LOG(paddle_mobile::kLOG_INFO)
<< "int8_t, ceil_mode=false, pooling_type=avg, kernel=3, pad=0, stride=3";
paddle_mobile::TestPoolOp<int8_t, 0, 1, 3, 0, 3>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=1";
paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 1>(in_channels, in_height,
in_width);
// kernel = 7, pad = 0, stride = 4
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=avg, kernel=7, pad=0, stride=4";
paddle_mobile::TestPoolOp<float, 0, 1, 7, 0, 4>(in_channels, in_height,
in_width);
// kernel = 5, pad = 0, stride = 1
LOG(paddle_mobile::kLOG_INFO)
<< "float, ceil_mode=false, pooling_type=avg, kernel=5, pad=0, stride=1";
paddle_mobile::TestPoolOp<float, 0, 1, 5, 0, 1>(in_channels, in_height,
in_width);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册