diff --git a/mace/kernels/matmul_benchmark.cc b/mace/kernels/matmul_benchmark.cc index fdf9962604d1571cca37144e9fbd140fac8ea51d..1771b2bb8903a65533ab1010e4b198285d6f5fd3 100644 --- a/mace/kernels/matmul_benchmark.cc +++ b/mace/kernels/matmul_benchmark.cc @@ -18,9 +18,10 @@ #include #include +#include "public/gemmlowp.h" #include "mace/core/testing/test_benchmark.h" #include "mace/kernels/gemm.h" -#include "public/gemmlowp.h" +#include "mace/kernels/gemmlowp_util.h" namespace gemmlowp { @@ -140,8 +141,7 @@ void MatmulBenchmark_gemmlowp_uint8(int iters, int rows, int depth, int cols) { const auto output_pipeline = std::make_tuple(quantize_down_stage, saturating_cast_stage); - gemmlowp::GemmContext gemm_context; - gemm_context.set_max_num_threads(4); + gemmlowp::GemmContext& gemm_context = GetGemmlowpContext(); using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams; gemmlowp::GemmWithOutputPipeline( @@ -172,8 +172,7 @@ void MatmulBenchmark_gemmlowp_int32(int iters, int rows, int depth, int cols) { const auto output_pipeline = std::make_tuple(); - gemmlowp::GemmContext gemm_context; - gemm_context.set_max_num_threads(4); + gemmlowp::GemmContext& gemm_context = GetGemmlowpContext(); using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams; gemmlowp::GemmWithOutputPipeline( @@ -191,21 +190,21 @@ void MatmulBenchmark_gemmlowp_int32(int iters, int rows, int depth, int cols) { } // namespace -#define MACE_BM_MATMUL_FUNC(M, K, N, FUNC) \ +#define MACE_BM_MATMUL_FUNC(M, K, N, FUNC, TYPE) \ static void MACE_BM_MATMUL_##M##_##K##_##N##_##FUNC(int iters) { \ const int64_t macc = static_cast(iters) * M * K * N; \ const int64_t tot = static_cast(iters) * (M + N) * K; \ mace::testing::MaccProcessed(macc); \ - mace::testing::BytesProcessed(tot * sizeof(float)); \ + mace::testing::BytesProcessed(tot * sizeof(TYPE)); \ MatmulBenchmark_##FUNC(iters, M, K, N); \ } \ MACE_BENCHMARK(MACE_BM_MATMUL_##M##_##K##_##N##_##FUNC) -#define MACE_BM_MATMUL(M, K, N) \ - MACE_BM_MATMUL_FUNC(M, K, N, Mace); \ - MACE_BM_MATMUL_FUNC(M, K, N, Eigen); \ - MACE_BM_MATMUL_FUNC(M, K, N, gemmlowp_uint8); \ - MACE_BM_MATMUL_FUNC(M, K, N, gemmlowp_int32); +#define MACE_BM_MATMUL(M, K, N) \ + MACE_BM_MATMUL_FUNC(M, K, N, Mace, float); \ + MACE_BM_MATMUL_FUNC(M, K, N, Eigen, float); \ + MACE_BM_MATMUL_FUNC(M, K, N, gemmlowp_uint8, uint8_t); \ + MACE_BM_MATMUL_FUNC(M, K, N, gemmlowp_int32, uint8_t); // Embedding size 384 MACE_BM_MATMUL(7, 384, 384);