diff --git a/WORKSPACE b/WORKSPACE index 1176c1ba39895a74f1571bb64a0b5f62a216fbce..932d487092213cc1b0847436093e4589df116c08 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -78,11 +78,10 @@ new_http_archive( http_archive( name = "gemmlowp", - sha256 = "b87faa7294dfcc5d678f22a59d2c01ca94ea1e2a3b488c38a95a67889ed0a658", - strip_prefix = "gemmlowp-38ebac7b059e84692f53e5938f97a9943c120d98", + sha256 = "5941b50afb7f43f96a2afaa101e024b5a1c6b0b4e4f110688fefa083bbdd652d", + strip_prefix = "gemmlowp-master-3559cf6e2a21a15b5bd8133bb632da6050aa8b8d", urls = [ - "http://cnbj1.fds.api.xiaomi.com/mace/third-party/gemmlowp/38ebac7b059e84692f53e5938f97a9943c120d98.zip", - "https://github.com/google/gemmlowp/archive/38ebac7b059e84692f53e5938f97a9943c120d98.zip", + "https://cnbj1.fds.api.xiaomi.com/mace/third-party/gemmlowp/gemmlowp-master-3559cf6e2a21a15b5bd8133bb632da6050aa8b8d.zip", ], ) diff --git a/mace/kernels/BUILD b/mace/kernels/BUILD index efd6f8e88a02c5e602c8af91e89b789d1f3985d6..a8991f472af6d90da8c9b3e499a0572873826025 100644 --- a/mace/kernels/BUILD +++ b/mace/kernels/BUILD @@ -70,6 +70,7 @@ cc_library( deps = [ "//mace/core", "//mace/utils", + "@gemmlowp", ], ) @@ -104,7 +105,7 @@ cc_test( deps = [ ":kernels", "//mace/ops", - "@gtest//:gtest", + "@gtest", "@gtest//:gtest_main", ], ) @@ -133,8 +134,9 @@ cc_test( linkstatic = 1, deps = [ ":kernels", - "//mace/ops", "//mace/core:test_benchmark_main", + "//mace/ops", "//third_party/eigen3", + "@gemmlowp", ], ) diff --git a/mace/kernels/matmul_benchmark.cc b/mace/kernels/matmul_benchmark.cc index 32ab8b4b66b459f34d54e9114046c1b0828e9f29..fdf9962604d1571cca37144e9fbd140fac8ea51d 100644 --- a/mace/kernels/matmul_benchmark.cc +++ b/mace/kernels/matmul_benchmark.cc @@ -15,10 +15,74 @@ #include #include #include +#include #include #include "mace/core/testing/test_benchmark.h" #include "mace/kernels/gemm.h" +#include "public/gemmlowp.h" + +namespace gemmlowp { + +template +class Matrix : public MatrixMap { + public: + typedef MatrixMap Map; + typedef MatrixMap ConstMap; + typedef typename Map::Scalar Scalar; + static const MapOrder Order = tOrder; + using Map::cols_; + using Map::data_; + using Map::kOrder; + using Map::rows_; + using Map::stride_; + + public: + Matrix() : Map(nullptr, 0, 0, 0) {} + + Matrix(int rows, int cols) : Map(nullptr, 0, 0, 0) { Resize(rows, cols); } + + Matrix(const Matrix &other) : Map(nullptr, 0, 0, 0) { *this = other; } + + Matrix &operator=(const Matrix &other) { + Resize(other.rows_, other.cols_); + std::memcpy(data_, other.data_, size() * sizeof(Scalar)); + return *this; + } + + friend bool operator==(const Matrix &a, const Matrix &b) { + return a.rows_ == b.rows_ && a.cols_ == b.cols_ && + !std::memcmp(a.data_, b.data_, a.size()); + } + + void Resize(int rows, int cols) { + rows_ = rows; + cols_ = cols; + stride_ = kOrder == gemmlowp::MapOrder::ColMajor ? rows : cols; + storage.resize(size()); + data_ = storage.data(); + } + + int size() const { return rows_ * cols_; } + + Map &map() { return *static_cast(this); } + + ConstMap const_map() const { return ConstMap(data_, rows_, cols_, stride_); } + + protected: + std::vector storage; +}; + +template +void MakeZero(MatrixType *m) { + for (int c = 0; c < m->cols(); c++) { + for (int r = 0; r < m->rows(); r++) { + (*m)(r, c) = 128; + } + } +} + +} // namespace gemmlowp namespace mace { namespace kernels { @@ -55,6 +119,76 @@ void MatmulBenchmark_Eigen(int iters, int m, int k, int n) { } } +void MatmulBenchmark_gemmlowp_uint8(int iters, int rows, int depth, int cols) { + mace::testing::StopTiming(); + + gemmlowp::Matrix lhs; + gemmlowp::Matrix rhs; + gemmlowp::Matrix result; + lhs.Resize(rows, depth); + rhs.Resize(depth, cols); + result.Resize(rows, cols); + gemmlowp::MakeZero(&lhs); + gemmlowp::MakeZero(&rhs); + gemmlowp::MakeZero(&result); + + gemmlowp::OutputStageQuantizeDownInt32ByFixedPoint quantize_down_stage; + quantize_down_stage.result_offset_after_shift = 128; + quantize_down_stage.result_fixedpoint_multiplier = 1234567890; + quantize_down_stage.result_shift = 16; + gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage; + const auto output_pipeline = + std::make_tuple(quantize_down_stage, saturating_cast_stage); + + gemmlowp::GemmContext gemm_context; + gemm_context.set_max_num_threads(4); + using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams; + + gemmlowp::GemmWithOutputPipeline( + &gemm_context, lhs.const_map(), rhs.const_map(), &result.map(), -128, + -128, output_pipeline); + + mace::testing::StartTiming(); + while (iters--) { + gemmlowp::GemmWithOutputPipeline( + &gemm_context, lhs.const_map(), rhs.const_map(), &result.map(), -128, + -128, output_pipeline); + } +} + +void MatmulBenchmark_gemmlowp_int32(int iters, int rows, int depth, int cols) { + mace::testing::StopTiming(); + + gemmlowp::Matrix lhs; + gemmlowp::Matrix rhs; + gemmlowp::Matrix result; + lhs.Resize(rows, depth); + rhs.Resize(depth, cols); + result.Resize(rows, cols); + gemmlowp::MakeZero(&lhs); + gemmlowp::MakeZero(&rhs); + gemmlowp::MakeZero(&result); + + const auto output_pipeline = std::make_tuple(); + + gemmlowp::GemmContext gemm_context; + gemm_context.set_max_num_threads(4); + using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams; + + gemmlowp::GemmWithOutputPipeline( + &gemm_context, lhs.const_map(), rhs.const_map(), &result.map(), -128, + -128, output_pipeline); + + mace::testing::StartTiming(); + while (iters--) { + gemmlowp::GemmWithOutputPipeline( + &gemm_context, lhs.const_map(), rhs.const_map(), &result.map(), -128, + -128, output_pipeline); + } +} + } // namespace #define MACE_BM_MATMUL_FUNC(M, K, N, FUNC) \ @@ -67,9 +201,11 @@ void MatmulBenchmark_Eigen(int iters, int m, int k, int n) { } \ MACE_BENCHMARK(MACE_BM_MATMUL_##M##_##K##_##N##_##FUNC) -#define MACE_BM_MATMUL(M, K, N) \ - MACE_BM_MATMUL_FUNC(M, K, N, Mace); \ - MACE_BM_MATMUL_FUNC(M, K, N, Eigen); +#define MACE_BM_MATMUL(M, K, N) \ + MACE_BM_MATMUL_FUNC(M, K, N, Mace); \ + MACE_BM_MATMUL_FUNC(M, K, N, Eigen); \ + MACE_BM_MATMUL_FUNC(M, K, N, gemmlowp_uint8); \ + MACE_BM_MATMUL_FUNC(M, K, N, gemmlowp_int32); // Embedding size 384 MACE_BM_MATMUL(7, 384, 384);