提交 e62ea9b1 编写于 作者: 李寅

Merge branch 'gemmlowp' into 'master'

Fix incorrect throughput stats for gemmlowp

See merge request !745
@@ -18,9 +18,10 @@
 #include <tuple>
 #include <vector>
+#include "public/gemmlowp.h"
 #include "mace/core/testing/test_benchmark.h"
 #include "mace/kernels/gemm.h"
-#include "public/gemmlowp.h"
+#include "mace/kernels/gemmlowp_util.h"
 namespace gemmlowp {
@@ -140,8 +141,7 @@ void MatmulBenchmark_gemmlowp_uint8(int iters, int rows, int depth, int cols) {
 const auto output_pipeline =
 std::make_tuple(quantize_down_stage, saturating_cast_stage);
-gemmlowp::GemmContext gemm_context;
-gemm_context.set_max_num_threads(4);
+gemmlowp::GemmContext& gemm_context = GetGemmlowpContext();
 using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams;
 gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::uint8_t, BitDepthParams>(
@@ -172,8 +172,7 @@ void MatmulBenchmark_gemmlowp_int32(int iters, int rows, int depth, int cols) {
 const auto output_pipeline = std::make_tuple();
-gemmlowp::GemmContext gemm_context;
-gemm_context.set_max_num_threads(4);
+gemmlowp::GemmContext& gemm_context = GetGemmlowpContext();
 using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams;
 gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t, BitDepthParams>(
@@ -191,21 +190,21 @@ void MatmulBenchmark_gemmlowp_int32(int iters, int rows, int depth, int cols) {
 } // namespace
-#define MACE_BM_MATMUL_FUNC(M, K, N, FUNC) \
+#define MACE_BM_MATMUL_FUNC(M, K, N, FUNC, TYPE) \
 static void MACE_BM_MATMUL_##M##_##K##_##N##_##FUNC(int iters) { \
 const int64_t macc = static_cast<int64_t>(iters) * M * K * N; \
 const int64_t tot = static_cast<int64_t>(iters) * (M + N) * K; \
 mace::testing::MaccProcessed(macc); \
-mace::testing::BytesProcessed(tot * sizeof(float)); \
+mace::testing::BytesProcessed(tot * sizeof(TYPE)); \
 MatmulBenchmark_##FUNC(iters, M, K, N); \
 } \
 MACE_BENCHMARK(MACE_BM_MATMUL_##M##_##K##_##N##_##FUNC)
 #define MACE_BM_MATMUL(M, K, N) \
-MACE_BM_MATMUL_FUNC(M, K, N, Mace); \
-MACE_BM_MATMUL_FUNC(M, K, N, Eigen); \
-MACE_BM_MATMUL_FUNC(M, K, N, gemmlowp_uint8); \
-MACE_BM_MATMUL_FUNC(M, K, N, gemmlowp_int32);
+MACE_BM_MATMUL_FUNC(M, K, N, Mace, float); \
+MACE_BM_MATMUL_FUNC(M, K, N, Eigen, float); \
+MACE_BM_MATMUL_FUNC(M, K, N, gemmlowp_uint8, uint8_t); \
+MACE_BM_MATMUL_FUNC(M, K, N, gemmlowp_int32, uint8_t);
 // Embedding size 384
 MACE_BM_MATMUL(7, 384, 384);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册