From 9577c8cbce50e4d3e3c3668378681c904becb229 Mon Sep 17 00:00:00 2001
From: Bin Li
Date: Mon, 27 Aug 2018 17:13:45 +0800
Subject: [PATCH] Add quantized matmul

---
 mace/kernels/gemmlowp_util.h     |  21 ++++++
 mace/kernels/matmul.h            |  96 +++++++++++++++++++++++++++
 mace/kernels/matmul_benchmark.cc |   6 +-
 mace/ops/matmul.cc               |   6 ++
 mace/ops/matmul_benchmark.cc     |  42 +++++++++---
 mace/ops/matmul_test.cc          | 110 +++++++++++++++++++++++++++++++
 mace/public/mace_runtime.h       |  36 ++++++----
 7 files changed, 291 insertions(+), 26 deletions(-)

diff --git a/mace/kernels/gemmlowp_util.h b/mace/kernels/gemmlowp_util.h
index 8ce20d38..28d45d3a 100644
--- a/mace/kernels/gemmlowp_util.h
+++ b/mace/kernels/gemmlowp_util.h
@@ -31,6 +31,10 @@ struct GemmlowpOutputPipeline {
       gemmlowp::OutputStageBiasAddition<ColVectorMap>,
       gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint,
       gemmlowp::OutputStageSaturatingCastToUint8> Pipeline;
+  typedef std::tuple<
+      gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint,
+      gemmlowp::OutputStageSaturatingCastToUint8> NoBiasPipeline;
+
   static Pipeline Make(
       const int32_t *bias_data, const index_t channels,
       const float lhs_scale, const float rhs_scale, const float output_scale,
@@ -52,6 +56,23 @@ struct GemmlowpOutputPipeline {
     return std::make_tuple(bias_addition_stage, quantize_down_stage,
                            saturating_cast_stage);
   }
+
+  static NoBiasPipeline MakeNoBias(
+      const float lhs_scale, const float rhs_scale, const float output_scale,
+      const int32_t output_zero_point) {
+    int32_t quantized_multiplier;
+    int32_t right_shift;
+    kernels::GetOutputMultiplierAndShift(lhs_scale, rhs_scale, output_scale,
+                                         &quantized_multiplier, &right_shift);
+    gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint
+        quantize_down_stage;
+    quantize_down_stage.result_offset_after_shift = output_zero_point;
+    quantize_down_stage.result_fixedpoint_multiplier = quantized_multiplier;
+    quantize_down_stage.result_shift = right_shift;
+
+    gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage;
+    return std::make_tuple(quantize_down_stage, saturating_cast_stage);
+  }
 };

 }  // namespace mace
diff --git a/mace/kernels/matmul.h b/mace/kernels/matmul.h
index 7b54aa1b..736fdd34 100644
--- a/mace/kernels/matmul.h
+++ b/mace/kernels/matmul.h
@@ -30,6 +30,7 @@
 #include "mace/core/tensor.h"
 #include "mace/kernels/gemm.h"
 #include "mace/utils/utils.h"
+#include "mace/kernels/gemmlowp_util.h"

 #ifdef MACE_ENABLE_OPENCL
 #include "mace/core/runtime/opencl/cl2_header.h"
@@ -92,6 +93,101 @@ struct MatMulFunctor {
   }
 };

+template <>
+struct MatMulFunctor<DeviceType::CPU, uint8_t> {
+  template <gemmlowp::MapOrder AOrder, gemmlowp::MapOrder BOrder>
+  void MatMulImpl(const Tensor *A,
+                  const Tensor *B,
+                  const index_t height,
+                  const index_t K,
+                  const index_t width,
+                  Tensor *C) {
+    gemmlowp::GemmContext& gemm_context = GetGemmlowpContext();
+
+    Tensor::MappingGuard guarda(A);
+    Tensor::MappingGuard guardb(B);
+    Tensor::MappingGuard guardc(C);
+    auto a_ptr_base = A->data<uint8_t>();
+    auto b_ptr_base = B->data<uint8_t>();
+    auto c_ptr_base = C->mutable_data<uint8_t>();
+    index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1,
+                                    std::multiplies<index_t>());
+    index_t a_size = height * K;
+    index_t b_size = K * width;
+    index_t c_size = height * width;
+
+    const auto &output_pipeline = GemmlowpOutputPipeline::MakeNoBias(
+        A->scale(), B->scale(), C->scale(), C->zero_point());
+
+    for (index_t i = 0; i < batch; ++i) {
+      gemmlowp::MatrixMap<const uint8_t, AOrder>
+          a_matrix(a_ptr_base + i * a_size, height, K);
+      gemmlowp::MatrixMap<const uint8_t, BOrder>
+          b_matrix(b_ptr_base + i * b_size, K, width);
+      gemmlowp::MatrixMap<uint8_t, gemmlowp::MapOrder::RowMajor>
+          c_matrix(c_ptr_base + i * c_size, height, width);
+
+      using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams;
+      gemmlowp::GemmWithOutputPipeline<uint8_t, uint8_t, BitDepthParams>(
+          &gemm_context, a_matrix, b_matrix, &c_matrix, -A->zero_point(),
+          -B->zero_point(), output_pipeline);
+    }
+  }
+
+  MaceStatus operator()(const Tensor *A,
+                        const Tensor *B,
+                        Tensor *C,
+                        bool transpose_a,
+                        bool transpose_b,
+                        StatsFuture *future) {
+    MACE_UNUSED(future);
+
+    index_t rank = A->dim_size();
+    index_t height = A->dim(rank - 2);
+    index_t K = A->dim(rank - 1);
+    index_t width;
+
+    if (transpose_a) {
+      std::swap(height, K);
+    }
+    if (transpose_b) {
+      width = B->dim(rank - 2);
+    } else {
+      width = B->dim(rank - 1);
+    }
+
+    std::vector<index_t> c_shape = A->shape();
+    c_shape[rank - 2] = height;
+    c_shape[rank - 1] = width;
+
+    MACE_RETURN_IF_ERROR(C->Resize(c_shape));
+
+    constexpr gemmlowp::MapOrder kRowMajor = gemmlowp::MapOrder::RowMajor;
+    constexpr gemmlowp::MapOrder kColMajor = gemmlowp::MapOrder::ColMajor;
+
+#define MATMUL_IMPL(AOrder, BOrder) \
+  MatMulImpl<AOrder, BOrder>(A, B, height, K, width, C);
+
+    if (transpose_a) {
+      if (transpose_b) {
+        MATMUL_IMPL(kColMajor, kColMajor);
+      } else {
+        MATMUL_IMPL(kColMajor, kRowMajor);
+      }
+    } else {
+      if (transpose_b) {
+        MATMUL_IMPL(kRowMajor, kColMajor);
+      } else {
+        MATMUL_IMPL(kRowMajor, kRowMajor);
+      }
+    }
+
+#undef MATMUL_IMPL
+
+    return MACE_SUCCESS;
+  }
+};
+
 #ifdef MACE_ENABLE_OPENCL
 template <typename T>
 struct MatMulFunctor<DeviceType::GPU, T> {
diff --git a/mace/kernels/matmul_benchmark.cc b/mace/kernels/matmul_benchmark.cc
index 1771b2bb..d109de47 100644
--- a/mace/kernels/matmul_benchmark.cc
+++ b/mace/kernels/matmul_benchmark.cc
@@ -109,9 +109,9 @@ void MatmulBenchmark_Mace(int iters, int m, int k, int n) {
 void MatmulBenchmark_Eigen(int iters, int m, int k, int n) {
   mace::testing::StopTiming();
-  Eigen::MatrixXd lhs = Eigen::MatrixXd::Random(m, k);
-  Eigen::MatrixXd rhs = Eigen::MatrixXd::Random(k, n);
-  Eigen::MatrixXd result = Eigen::MatrixXd::Zero(m, n);
+  Eigen::MatrixXf lhs = Eigen::MatrixXf::Random(m, k);
+  Eigen::MatrixXf rhs = Eigen::MatrixXf::Random(k, n);
+  Eigen::MatrixXf result = Eigen::MatrixXf::Zero(m, n);
   // warm up
   result = lhs * rhs;
   mace::testing::StartTiming();
diff --git a/mace/ops/matmul.cc b/mace/ops/matmul.cc
index e1c5932c..ca0b68e5 100644
--- a/mace/ops/matmul.cc
+++ b/mace/ops/matmul.cc
@@ -24,6 +24,12 @@ void Register_MatMul(OperatorRegistryBase *op_registry) {
                          .Build(),
                          MatMulOp<DeviceType::CPU, float>);

+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("MatMul")
+                         .Device(DeviceType::CPU)
+                         .TypeConstraint<uint8_t>("T")
+                         .Build(),
+                         MatMulOp<DeviceType::CPU, uint8_t>);
+
 #ifdef MACE_ENABLE_OPENCL
   MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("MatMul")
                          .Device(DeviceType::GPU)
diff --git a/mace/ops/matmul_benchmark.cc b/mace/ops/matmul_benchmark.cc
index 3e3327d8..146f8d1c 100644
--- a/mace/ops/matmul_benchmark.cc
+++ b/mace/ops/matmul_benchmark.cc
@@ -31,8 +31,12 @@ void MatMulBenchmark(
   OpsTestNet net;

   // Add input data
-  net.AddRandomInput<D, float>("A", {batch, height, channels});
-  net.AddRandomInput<D, float>("B", {batch, channels, out_width});
+  net.AddRandomInput<D, T>("A", {batch, height, channels});
+  net.AddRandomInput<D, T>("B", {batch, channels, out_width});
+  if (DataTypeToEnum<T>::value == DT_UINT8) {
+    net.GetTensor("A")->SetScale(0.1);
+    net.GetTensor("B")->SetScale(0.1);
+  }

   if (D == DeviceType::GPU) {
     BufferToImage<D, T>(&net, "A", "AImage",
                         kernels::BufferType::IN_OUT_WIDTH);
@@ -50,12 +54,18 @@
         .Input("A")
         .Input("B")
         .Output("Output")
+        .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
         .Finalize(net.NewOperatorDef());
   }

+  net.Setup(D);
+  if (DataTypeToEnum<T>::value == DT_UINT8) {
+    net.GetTensor("Output")->SetScale(0.1);
+  }
+
   // Warm-up
-  for (int i = 0; i < 5; ++i) {
-    net.RunOp(D);
+  for (int i = 0; i < 2; ++i) {
+    net.Run();
   }
   net.Sync();
@@ -74,8 +84,12 @@ void MatMulTransposeBenchmark(
   OpsTestNet net;

   // Add input data
-  net.AddRandomInput<D, float>("A", {batch, height, channels});
-  net.AddRandomInput<D, float>("B", {batch, out_width, channels});
+  net.AddRandomInput<D, T>("A", {batch, height, channels});
+  net.AddRandomInput<D, T>("B", {batch, out_width, channels});
+  if (DataTypeToEnum<T>::value == DT_UINT8) {
+    net.GetTensor("A")->SetScale(0.1);
+    net.GetTensor("B")->SetScale(0.1);
+  }

   if (D == DeviceType::CPU) {
     OpDefBuilder("MatMul", "MatMulBM")
@@ -83,14 +97,20 @@ void MatMulTransposeBenchmark(
         .Input("B")
         .AddIntArg("transpose_b", 1)
         .Output("Output")
+        .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
         .Finalize(net.NewOperatorDef());
   } else {
     MACE_NOT_IMPLEMENTED;
   }

+  net.Setup(D);
+  if (DataTypeToEnum<T>::value == DT_UINT8) {
+    net.GetTensor("Output")->SetScale(0.1);
+  }
+
   // Warm-up
-  for (int i = 0; i < 5; ++i) {
-    net.RunOp(D);
+  for (int i = 0; i < 2; ++i) {
+    net.Run();
   }
   net.Sync();
@@ -116,7 +136,8 @@ void MatMulTransposeBenchmark(
 #define MACE_BM_MATMUL(N, H, C, W)                \
   MACE_BM_MATMUL_MACRO(N, H, C, W, float, CPU);   \
   MACE_BM_MATMUL_MACRO(N, H, C, W, float, GPU);   \
-  MACE_BM_MATMUL_MACRO(N, H, C, W, half, GPU);
+  MACE_BM_MATMUL_MACRO(N, H, C, W, half, GPU);    \
+  MACE_BM_MATMUL_MACRO(N, H, C, W, uint8_t, CPU);

 #define MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, TYPE, DEVICE)             \
   static void MACE_BM_MATMUL_##T_##N##_##H##_##C##_##W##_##TYPE##_##DEVICE(  \
@@ -130,7 +151,8 @@ void MatMulTransposeBenchmark(
   MACE_BENCHMARK(MACE_BM_MATMUL_##T_##N##_##H##_##C##_##W##_##TYPE##_##DEVICE)

 #define MACE_BM_MATMUL_TRANPOSE(N, H, C, W)                 \
-  MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, float, CPU);
+  MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, float, CPU);   \
+  MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, uint8_t, CPU);

 MACE_BM_MATMUL(16, 32, 128, 49);
 MACE_BM_MATMUL(16, 32, 128, 961);
diff --git a/mace/ops/matmul_test.cc b/mace/ops/matmul_test.cc
index 397b00fe..18a9ddc8 100644
--- a/mace/ops/matmul_test.cc
+++ b/mace/ops/matmul_test.cc
@@ -213,6 +213,116 @@ TEST_F(MatMulOpTest, OPENCLHalfUnAlignedWithBatch) {
   Complex({2, 3}, 31, 61, 67);
 }

+namespace {
+void Quant(const std::vector<index_t> &batch,
+           const index_t height,
+           const index_t channels,
+           const index_t out_width,
+           const bool transpose_a,
+           const bool transpose_b) {
+  // Construct graph
+  OpsTestNet net;
+
+  // Add input data
+  index_t batch_count = std::accumulate(batch.begin(), batch.end(), 1,
+                                        std::multiplies<index_t>());
+  if (transpose_a) {
+    net.AddRandomInput<CPU, float>("A", {batch_count, channels, height});
+  } else {
+    net.AddRandomInput<CPU, float>("A", {batch_count, height, channels});
+  }
+  if (transpose_b) {
+    net.AddRandomInput<CPU, float>("B", {batch_count, out_width, channels});
+  } else {
+    net.AddRandomInput<CPU, float>("B", {batch_count, channels, out_width});
+  }
+
+  OpDefBuilder("MatMul", "MatMulTest")
+      .Input("A")
+      .AddIntArg("transpose_a", transpose_a ? 1 : 0)
+      .Input("B")
+      .AddIntArg("transpose_b", transpose_b ? 1 : 0)
+      .Output("Output")
+      .AddIntArg("T", DT_FLOAT)
+      .Finalize(net.NewOperatorDef());
+  net.RunOp(CPU);
+
+  OpDefBuilder("Quantize", "QuantizeA")
+      .Input("A")
+      .Output("QuantizedA")
+      .OutputType({DT_UINT8})
+      .AddIntArg("T", DT_UINT8)
+      .AddIntArg("non_zero", true)
+      .Finalize(net.NewOperatorDef());
+  net.RunOp();
+
+  OpDefBuilder("Quantize", "QuantizeB")
+      .Input("B")
+      .Output("QuantizedB")
+      .OutputType({DT_UINT8})
+      .AddIntArg("T", DT_UINT8)
+      .AddIntArg("non_zero", true)
+      .Finalize(net.NewOperatorDef());
+  net.RunOp();
+
+  OpDefBuilder("Quantize", "QuantizeOutput")
+      .Input("Output")
+      .Output("ExpectedQuantizedOutput")
+      .OutputType({DT_UINT8})
+      .AddIntArg("T", DT_UINT8)
+      .AddIntArg("non_zero", true)
+      .Finalize(net.NewOperatorDef());
+  net.RunOp();
+
+  OpDefBuilder("MatMul", "QuantizeMatMulTest")
+      .Input("QuantizedA")
+      .AddIntArg("transpose_a", transpose_a ? 1 : 0)
+      .Input("QuantizedB")
+      .AddIntArg("transpose_b", transpose_b ? 1 : 0)
+      .Output("QuantizedOutput")
+      .AddIntArg("T", DT_UINT8)
+      .Finalize(net.NewOperatorDef());
+  net.Setup(DeviceType::CPU);
+  Tensor *eq_output = net.GetTensor("ExpectedQuantizedOutput");
+  Tensor *q_output = net.GetTensor("QuantizedOutput");
+  q_output->SetScale(eq_output->scale());
+  q_output->SetZeroPoint(eq_output->zero_point());
+  net.Run();
+
+  OpDefBuilder("Dequantize", "DeQuantizeTest")
+      .Input("QuantizedOutput")
+      .Output("DequantizedOutput")
+      .OutputType({DT_FLOAT})
+      .AddIntArg("T", DT_UINT8)
+      .Finalize(net.NewOperatorDef());
+  net.RunOp();
+
+  // Check
+  ExpectTensorSimilar<float>(*net.GetOutput("Output"),
+                             *net.GetTensor("DequantizedOutput"), 0.01);
+}
+}  // namespace
+
+TEST_F(MatMulOpTest, Quant) {
+  Quant({1}, 64, 128, 32, false, false);
+  Quant({1}, 64, 32, 128, false, false);
+  Quant({2, 3}, 64, 32, 128, false, false);
+  Quant({1}, 64, 128, 32, false, true);
+  Quant({1}, 64, 32, 128, false, true);
+  Quant({2, 3}, 64, 32, 128, false, true);
+  Quant({1}, 64, 128, 32, true, false);
+  Quant({1}, 64, 32, 128, true, false);
+  Quant({2, 3}, 64, 32, 128, true, false);
+  Quant({1}, 64, 128, 32, true, true);
+  Quant({1}, 64, 32, 128, true, true);
+  Quant({2, 3}, 64, 32, 128, true, true);
+  // UnAligned
+  Quant({2}, 3, 3, 3, false, false);
+  Quant({16}, 31, 61, 67, false, true);
+  Quant({31}, 31, 61, 67, true, false);
+  Quant({2, 3}, 31, 61, 67, true, true);
+}
+
 // TODO(liyin): test transpose after implementing gpu runtime
 // now transpose test is in kernels_test
diff --git a/mace/public/mace_runtime.h b/mace/public/mace_runtime.h
index c1e158ab..f97a2cf7 100644
--- a/mace/public/mace_runtime.h
+++ b/mace/public/mace_runtime.h
@@ -136,14 +136,13 @@ void SetGPUHints(GPUPerfHint perf_hint, GPUPriorityHint priority_hint);
 /// also be truncated to the corresponding cores number when num_threads_hint
 /// is larger than it.
 /// The OpenMP threads will be bind to (via sched_setaffinity) big cores
-/// (AFFINITY_BIG_ONLY) and little cores (AFFINITY_LITTLE_ONLY).
+/// (AFFINITY_BIG_ONLY) or little cores (AFFINITY_LITTLE_ONLY).
 ///
 /// \param num_threads_hint it is only a hint.
 /// \param policy one of CPUAffinityPolicy
-/// \param status MACE_SUCCESS for successful, or it can't reliabley
-/// detect big-LITTLE cores (see GetBigLittleCoreIDs). In such cases, it's
-/// suggested to use AFFINITY_NONE to use all cores.
-/// \return
+/// \return MACE_SUCCESS for success, or an error if big-LITTLE cores can't be
+/// reliably detected (see GetBigLittleCoreIDs). In such cases, it's suggested
+/// to use AFFINITY_NONE to use all cores.
 __attribute__((visibility("default")))
 MaceStatus SetOpenMPThreadPolicy(int num_threads_hint,
                                  CPUAffinityPolicy policy);
@@ -159,7 +158,6 @@ MaceStatus SetOpenMPThreadPolicy(int num_threads_hint,
 ///
 /// \param num_threads
 /// \param cpu_ids
-/// \param status
 /// \return
 __attribute__((visibility("default")))
 MaceStatus SetOpenMPThreadAffinity(int num_threads,
@@ -180,13 +178,25 @@
 __attribute__((visibility("default")))
 MaceStatus GetBigLittleCoreIDs(std::vector<int> *big_core_ids,
                                std::vector<int> *little_core_ids);

-// Set gemmlowp threads number and processor affinity.
-// gemmlowp is used by mace for quantization.
-// Caution: this function may hurt performance if improper parameters provided.
-//
-// This function may not work well on some chips (e.g. MTK). Setting thread
-// affinity to offline cores may run very slow or unexpectedly. In such cases,
-// please use SetGemmlowpThreadPolicy with default policy instead.
+/// \brief Set gemmlowp threads number and affinity policy for quantization.
+///
+/// Caution: improper parameters may hurt performance. gemmlowp shares its
+/// threads with OpenMP, which are set by SetOpenMPThreadPolicy, so the
+/// affinity policy set by these two functions should be the same.
+/// When num_threads_hint is zero or negative,
+/// the function will set the threads number equaling to the number of
+/// big (AFFINITY_BIG_ONLY), little (AFFINITY_LITTLE_ONLY) or all
+/// (AFFINITY_NONE) cores according to the policy. The threads number will
+/// also be truncated to the corresponding cores number when num_threads_hint
+/// is larger than it.
+/// The gemmlowp threads will be bound to (via sched_setaffinity) big cores
+/// (AFFINITY_BIG_ONLY) or little cores (AFFINITY_LITTLE_ONLY).
+///
+/// \param num_threads_hint it is only a hint.
+/// \param policy one of CPUAffinityPolicy
+/// \return MACE_SUCCESS for success, or an error if big-LITTLE cores can't be
+/// reliably detected (see GetBigLittleCoreIDs). In such cases, it's suggested
+/// to use AFFINITY_NONE to use all cores.
 __attribute__((visibility("default")))
 MaceStatus SetGemmlowpThreadPolicy(int num_threads_hint,
                                    CPUAffinityPolicy policy);
-- 
GitLab
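Usage sketch (illustrative, not part of the patch): the new SetGemmlowpThreadPolicy documentation says the quantized (gemmlowp) kernels share threads with OpenMP and that both functions should receive the same affinity policy. A minimal caller-side sketch of that convention follows; the header paths and the CPUAffinityPolicy / MaceStatus spellings are taken from the mace/public headers referenced in this patch and should be treated as assumptions, not a guaranteed API.

// usage_sketch.cc -- illustrative only, not part of the patch above.
// Configure OpenMP (float kernels) and gemmlowp (quantized uint8 MatMul)
// with the same affinity policy, as the documentation in this patch advises.
// Header names and enum values are assumed from mace/public of this era.
#include "mace/public/mace.h"
#include "mace/public/mace_runtime.h"

void ConfigureQuantizedThreading(int num_threads_hint) {
  // Prefer big cores; AFFINITY_BIG_ONLY is the value named in the docs above.
  mace::CPUAffinityPolicy policy = mace::AFFINITY_BIG_ONLY;
  mace::MaceStatus status =
      mace::SetOpenMPThreadPolicy(num_threads_hint, policy);
  if (status != mace::MACE_SUCCESS) {
    // big-LITTLE cores could not be detected reliably; fall back to all cores,
    // as the \return note suggests.
    policy = mace::AFFINITY_NONE;
    mace::SetOpenMPThreadPolicy(num_threads_hint, policy);
  }
  // Keep gemmlowp on the same policy, since it shares threads with OpenMP.
  mace::SetGemmlowpThreadPolicy(num_threads_hint, policy);
}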