From d87285bf8607d8fee6b32c1daba7f4096ffc8ebc Mon Sep 17 00:00:00 2001
From: 李寅
Date: Thu, 22 Nov 2018 14:35:44 +0800
Subject: [PATCH] Optimize low-precision gemv

---
 mace/ops/BUILD               |  18 ++---
 mace/ops/arm/fixpoint_gemm.h | 141 +++++++++++++++++++++++++++++++++++
 mace/ops/matmul.cc           |  73 ++++++++++++------
 mace/ops/matmul_test.cc      |  21 +++++-
 4 files changed, 217 insertions(+), 36 deletions(-)
 create mode 100644 mace/ops/arm/fixpoint_gemm.h

diff --git a/mace/ops/BUILD b/mace/ops/BUILD
index c0d4fbed..1e1efc89 100644
--- a/mace/ops/BUILD
+++ b/mace/ops/BUILD
@@ -64,7 +64,8 @@ cc_library(
         "ops_test_util.h",
         "fixpoint.h",
         "gemmlowp_util.h",
-    ] + "arm/fixpoint_*.h",
+    ],
     ) + if_opencl_enabled(glob([
         "opencl/*.h",
         "opencl/image/*.h",
@@ -72,6 +73,7 @@ cc_library(
     ])) + if_quantize_enabled(glob([
         "fixpoint.h",
         "gemmlowp_util.h",
+        "arm/fixpoint_*.h",
     ])),
     copts = [
         "-Werror",
@@ -101,11 +103,10 @@ cc_library(
     ]),
 )
-
 cc_library(
     name = "ops",
     srcs = [
-        "ops_registry.cc"
+        "ops_registry.cc",
     ],
     hdrs = [
         "ops_registry.h",
@@ -138,12 +139,12 @@ cc_library(
 cc_library(
     name = "test",
     testonly = 1,
-    hdrs = glob([
-        "*_test_util.h",
-    ]),
     srcs = [
         "ops_test_util.cc",
     ],
+    hdrs = glob([
+        "*_test_util.h",
+    ]),
     copts = [
         "-Werror",
         "-Wextra",
@@ -174,13 +175,12 @@ cc_test(
         "opencl/*_test.cc",
     ],
     exclude = [
-        "fixpoint_test.cc"
+        "fixpoint_test.cc",
     ],
 ) + if_quantize_enabled(glob(
     [
-        "fixpoint_test.cc"
+        "fixpoint_test.cc",
     ],
-
 )),
     copts = [
         "-Werror",
diff --git a/mace/ops/arm/fixpoint_gemm.h b/mace/ops/arm/fixpoint_gemm.h
new file mode 100644
index 00000000..921f7e75
--- /dev/null
+++ b/mace/ops/arm/fixpoint_gemm.h
@@ -0,0 +1,141 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_OPS_ARM_FIXPOINT_GEMM_H_
+#define MACE_OPS_ARM_FIXPOINT_GEMM_H_
+
+#if defined(MACE_ENABLE_NEON)
+#include <arm_neon.h>
+#endif
+
+#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
+#define vaddvq_u32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
+#endif
+
+namespace mace {
+namespace ops {
+
+template<typename INPUT_TYPE, typename OUTPUT_TYPE>
+void FixPointGemv(const INPUT_TYPE *lhs,
+                  const INPUT_TYPE *rhs,
+                  const int lhs_zero_point,
+                  const int rhs_zero_point,
+                  const index_t lhs_height,
+                  const index_t lhs_width,
+                  OUTPUT_TYPE *result);
+
+template<>
+void FixPointGemv<uint8_t, int32_t>(const uint8_t *lhs,
+                                    const uint8_t *rhs,
+                                    const int lhs_zero_point,
+                                    const int rhs_zero_point,
+                                    const index_t lhs_height,
+                                    const index_t lhs_width,
+                                    int32_t *result) {
+  int32_t zero_point_dot = lhs_zero_point * rhs_zero_point * lhs_width;
+
+  uint32_t sum_rhs = 0;
+  for (index_t i = 0; i < lhs_width; ++i) {
+    sum_rhs += rhs[i];
+  }
+
+#pragma omp parallel for
+  for (index_t h = 0; h < lhs_height; ++h) {
+    const uint8_t *lhs_ptr = lhs + h * lhs_width;
+    const uint8_t *rhs_ptr = rhs;
+    int32_t *ret_ptr = result + h;
+
+    uint32_t dot = 0;
+    uint32_t sum_lhs = 0;
+    index_t w = 0;
+
+#if defined(MACE_ENABLE_NEON)
+    uint32x4_t vo0_high_u32, vo0_low_u32, vo1_high_u32, vo1_low_u32;
+    vo0_high_u32 = vdupq_n_u32(0);
+    vo0_low_u32 = vdupq_n_u32(0);
+    vo1_high_u32 = vdupq_n_u32(0);
+    vo1_low_u32 = vdupq_n_u32(0);
+
+    uint32x4_t sum_lhs_low_u32, sum_lhs_high_u32;
+    sum_lhs_low_u32 = vdupq_n_u32(0);
+    sum_lhs_high_u32 = vdupq_n_u32(0);
+
+    for (; w <= lhs_width - 16; w += 16) {
+      uint8x8_t vl0_u8, vl1_u8;
+      uint8x8_t vr0_u8, vr1_u8;
+      uint16x8_t vl0_u16, vl1_u16;
+      uint16x8_t vr0_u16, vr1_u16;
+
+      vl0_u8 = vld1_u8(lhs_ptr);
+      vl1_u8 = vld1_u8(lhs_ptr + 8);
+
+      vr0_u8 = vld1_u8(rhs_ptr);
+      vr1_u8 = vld1_u8(rhs_ptr + 8);
+
+      vl0_u16 = vmovl_u8(vl0_u8);
+      vl1_u16 = vmovl_u8(vl1_u8);
+
+      vr0_u16 = vmovl_u8(vr0_u8);
+      vr1_u16 = vmovl_u8(vr1_u8);
+
+      vo0_high_u32 = vmlal_u16(vo0_high_u32,
+                               vget_high_u16(vl0_u16),
+                               vget_high_u16(vr0_u16));
+      vo0_low_u32 = vmlal_u16(vo0_low_u32,
+                              vget_low_u16(vl0_u16),
+                              vget_low_u16(vr0_u16));
+      vo1_high_u32 = vmlal_u16(vo1_high_u32,
+                               vget_high_u16(vl1_u16),
+                               vget_high_u16(vr1_u16));
+      vo1_low_u32 = vmlal_u16(vo1_low_u32,
+                              vget_low_u16(vl1_u16),
+                              vget_low_u16(vr1_u16));
+
+      // It can be precalculated if lhs is const, but in this case
+      // the computation is not the bottleneck.
+      sum_lhs_high_u32 += vaddl_u16(vget_high_u16(vl0_u16),
+                                    vget_high_u16(vl1_u16));
+      sum_lhs_low_u32 += vaddl_u16(vget_low_u16(vl0_u16),
+                                   vget_low_u16(vl1_u16));
+
+      lhs_ptr += 16;
+      rhs_ptr += 16;
+    }
+    vo0_low_u32 = vaddq_u32(vo0_high_u32, vo0_low_u32);
+    vo1_low_u32 = vaddq_u32(vo1_high_u32, vo1_low_u32);
+    vo0_low_u32 = vaddq_u32(vo0_low_u32, vo1_low_u32);
+    dot += vaddvq_u32(vo0_low_u32);
+
+    sum_lhs_low_u32 = vaddq_u32(sum_lhs_high_u32, sum_lhs_low_u32);
+    sum_lhs = vaddvq_u32(sum_lhs_low_u32);
+#endif  // MACE_ENABLE_NEON
+
+    for (; w < lhs_width; ++w) {
+      dot += (*lhs_ptr) * (*rhs_ptr);
+      sum_lhs += (*lhs_ptr);
+      ++lhs_ptr;
+      ++rhs_ptr;
+    }
+
+    int32_t ret = dot - sum_lhs * rhs_zero_point - sum_rhs * lhs_zero_point +
+        zero_point_dot;
+
+    *ret_ptr = ret;
+  }  // h
+}
+
+}  // namespace ops
+}  // namespace mace
+
+#endif  // MACE_OPS_ARM_FIXPOINT_GEMM_H_
diff --git a/mace/ops/matmul.cc b/mace/ops/matmul.cc
index af88dd85..614788d8 100644
--- a/mace/ops/matmul.cc
+++ b/mace/ops/matmul.cc
@@ -27,6 +27,7 @@
 
 #ifdef MACE_ENABLE_QUANTIZE
 #include "mace/ops/gemmlowp_util.h"
+#include "mace/ops/arm/fixpoint_gemm.h"
 #endif  // MACE_ENABLE_QUANTIZE
 
 #ifdef MACE_ENABLE_OPENCL
@@ -169,9 +170,6 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
                   const index_t K,
                   const index_t width,
                   Tensor *C) {
-    auto gemm_context = context->device()->cpu_runtime()->GetGemmlowpContext();
-    MACE_CHECK_NOTNULL(gemm_context);
-
     Tensor::MappingGuard guarda(A);
     Tensor::MappingGuard guardb(B);
     Tensor::MappingGuard guardc(C);
@@ -180,6 +178,9 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
     auto c_ptr_base = C->mutable_data<uint8_t>();
     index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2,
                                     1, std::multiplies<index_t>());
+    auto gemm_context = context->device()->cpu_runtime()->GetGemmlowpContext();
+    MACE_CHECK_NOTNULL(gemm_context);
+
     index_t a_size = height * K;
     index_t b_size = K * width;
     index_t c_size = height * width;
@@ -213,9 +214,6 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
                   const index_t K,
                   const index_t width,
                   Tensor *C) {
-    auto gemm_context = context->device()->cpu_runtime()->GetGemmlowpContext();
-    MACE_CHECK_NOTNULL(gemm_context);
-
     Tensor::MappingGuard guarda(A);
     Tensor::MappingGuard guardb(B);
     Tensor::MappingGuard guardc(C);
@@ -224,24 +222,53 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
     auto c_ptr_base = C->mutable_data<int32_t>();
     index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2,
                                     1, std::multiplies<index_t>());
-    index_t a_size = height * K;
-    index_t b_size = K * width;
-    index_t c_size = height * width;
 
-    const auto output_pipeline = std::make_tuple();
-
-    for (index_t i = 0; i < batch; ++i) {
-      gemmlowp::MatrixMap<const uint8_t, AOrder>
-          a_matrix(a_ptr_base + i * a_size, height, K);
-      gemmlowp::MatrixMap<const uint8_t, BOrder>
-          b_matrix(b_ptr_base + i * b_size, K, width);
-      gemmlowp::MatrixMap<int32_t, gemmlowp::MapOrder::RowMajor>
-          c_matrix(c_ptr_base + i * c_size, height, width);
-
-      using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams;
-      gemmlowp::GemmWithOutputPipeline<uint8_t, int32_t, BitDepthParams>(
-          gemm_context, a_matrix, b_matrix, &c_matrix, -A->zero_point(),
-          -B->zero_point(), output_pipeline);
+    if (width == 1 && AOrder == gemmlowp::MapOrder::RowMajor) {
+      // gemv
+      for (index_t i = 0; i < batch; ++i) {
+        FixPointGemv<uint8_t, int32_t>(a_ptr_base + i * height * K,
+                                       b_ptr_base + i * K,
+                                       A->zero_point(),
+                                       B->zero_point(),
+                                       height,
+                                       K,
+                                       c_ptr_base + i * height);
+      }
+    } else if (height == 1 && BOrder == gemmlowp::MapOrder::ColMajor) {
+      // gevm
+      for (index_t i = 0; i < batch; ++i) {
+        FixPointGemv<uint8_t, int32_t>(b_ptr_base + i * K * width,
+                                       a_ptr_base + i * K,
+                                       B->zero_point(),
+                                       A->zero_point(),
+                                       width,
+                                       K,
+                                       c_ptr_base + i * width);
+      }
+    } else {
+      auto
+          gemm_context = context->device()->cpu_runtime()->GetGemmlowpContext();
+      MACE_CHECK_NOTNULL(gemm_context);
+
+      index_t a_size = height * K;
+      index_t b_size = K * width;
+      index_t c_size = height * width;
+
+      const auto output_pipeline = std::make_tuple();
+
+      for (index_t i = 0; i < batch; ++i) {
+        gemmlowp::MatrixMap<const uint8_t, AOrder>
+            a_matrix(a_ptr_base + i * a_size, height, K);
+        gemmlowp::MatrixMap<const uint8_t, BOrder>
+            b_matrix(b_ptr_base + i * b_size, K, width);
+        gemmlowp::MatrixMap<int32_t, gemmlowp::MapOrder::RowMajor>
+            c_matrix(c_ptr_base + i * c_size, height, width);
+
+        using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams;
+        gemmlowp::GemmWithOutputPipeline<uint8_t, int32_t, BitDepthParams>(
+            gemm_context, a_matrix, b_matrix, &c_matrix, -A->zero_point(),
+            -B->zero_point(), output_pipeline);
+      }
     }
 
     C->SetScale(A->scale() * B->scale());
diff --git a/mace/ops/matmul_test.cc b/mace/ops/matmul_test.cc
index d2d95874..82187b8b 100644
--- a/mace/ops/matmul_test.cc
+++ b/mace/ops/matmul_test.cc
@@ -315,14 +315,20 @@ void QuantOutputInt32(const std::vector<index_t> &batch,
   index_t batch_count = std::accumulate(batch.begin(), batch.end(), 1,
                                         std::multiplies<index_t>());
   if (transpose_a) {
-    net.AddRandomInput<CPU, uint8_t>("A", {batch_count, channels, height});
+    net.AddRandomInput<CPU, uint8_t>("A", {batch_count, channels, height},
+                                     false);
   } else {
-    net.AddRandomInput<CPU, uint8_t>("A", {batch_count, height, channels});
+    net.AddRandomInput<CPU, uint8_t>("A", {batch_count, height, channels},
+                                     false);
   }
   if (transpose_b) {
-    net.AddRandomInput<CPU, uint8_t>("B", {batch_count, out_width, channels});
+    net.AddRandomInput<CPU, uint8_t>("B",
+                                     {batch_count, out_width, channels},
+                                     false);
   } else {
-    net.AddRandomInput<CPU, uint8_t>("B", {batch_count, channels, out_width});
+    net.AddRandomInput<CPU, uint8_t>("B",
+                                     {batch_count, channels, out_width},
+                                     false);
   }
 
   OpDefBuilder("MatMul", "MatMulTest")
@@ -411,11 +417,18 @@ TEST_F(MatMulOpTest, QuantOutputInt32) {
   QuantOutputInt32({1}, 64, 128, 32, true, true);
   QuantOutputInt32({1}, 64, 32, 128, true, true);
   QuantOutputInt32({2, 3}, 64, 32, 128, true, true);
+  QuantOutputInt32({1}, 1, 30000, 256, false, true);
+  QuantOutputInt32({1}, 30000, 256, 1, false, false);
+  QuantOutputInt32({2}, 1, 256, 128, false, true);
+  QuantOutputInt32({3}, 128, 256, 1, false, false);
+
   // UnAligned
   QuantOutputInt32({2}, 3, 3, 3, false, false);
   QuantOutputInt32({16}, 31, 61, 67, false, true);
   QuantOutputInt32({31}, 31, 61, 67, true, false);
   QuantOutputInt32({2, 3}, 31, 61, 67, true, true);
+  QuantOutputInt32({1}, 1, 30001, 253, false, true);
+  QuantOutputInt32({2}, 253, 300, 1, false, false);
 }
 
 // TODO(liyin): test transpose after implementing gpu runtime
-- 
GitLab
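
For reference, the arithmetic that FixPointGemv relies on can be checked in
isolation. The kernel accumulates a raw uint8 dot product plus row/column sums
and then applies the correction `dot - sum_lhs * rhs_zero_point - sum_rhs *
lhs_zero_point + zero_point_dot`, which is the expansion of
sum_i (a_i - z_a) * (b_i - z_b). The following standalone sketch (not part of
the patch; plain C++ with illustrative values) verifies that identity against
the naive dequantize-then-multiply reference:

// Standalone sketch: verifies the zero-point decomposition used by
// FixPointGemv, with no NEON or MACE dependencies.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int z_a = 3, z_b = 7;  // example zero points (illustrative)
  std::vector<uint8_t> a = {12, 0, 255, 9, 88, 3};
  std::vector<uint8_t> b = {5, 200, 17, 1, 42, 250};
  const int64_t K = static_cast<int64_t>(a.size());

  // Reference: subtract zero points first, then take the dot product.
  int64_t ref = 0;
  for (size_t i = 0; i < a.size(); ++i) {
    ref += (static_cast<int32_t>(a[i]) - z_a) *
           (static_cast<int32_t>(b[i]) - z_b);
  }

  // Decomposed form used by the kernel: one raw uint8 dot product plus
  // per-operand sums, corrected by the zero points afterwards.
  int64_t dot = 0, sum_a = 0, sum_b = 0;
  for (size_t i = 0; i < a.size(); ++i) {
    dot += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
    sum_a += a[i];
    sum_b += b[i];
  }
  const int64_t decomposed = dot - sum_a * z_b - sum_b * z_a + z_a * z_b * K;

  std::printf("reference=%lld decomposed=%lld\n",
              static_cast<long long>(ref),
              static_cast<long long>(decomposed));
  return ref == decomposed ? 0 : 1;  // exits 0: the two forms agree
}

This is why the kernel never needs to widen and subtract zero points inside
the hot loop: the raw dot product and the two sums are all the loop computes,
and the correction is a handful of scalar operations per output element.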
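The dispatch added to matmul.cc handles the `height == 1` case (vector times
matrix) with the same gemv kernel by swapping operands: a column-major
K x width matrix has the identical memory layout as a row-major width x K
matrix, so x * B equals a gemv with lhs = B and rhs = x. A scalar sketch of
that equivalence (not part of the patch; `RefGemv` is an illustrative
stand-in for the real kernel):

// Standalone sketch: shows why the gevm branch can reuse the gemv kernel.
#include <cstdint>
#include <cstdio>
#include <vector>

// Scalar stand-in for FixPointGemv: lhs is row-major lhs_height x lhs_width,
// result[h] = sum_w (lhs[h][w] - lhs_zero) * (rhs[w] - rhs_zero).
void RefGemv(const std::vector<uint8_t> &lhs, const std::vector<uint8_t> &rhs,
             int lhs_zero, int rhs_zero, int lhs_height, int lhs_width,
             std::vector<int32_t> *result) {
  for (int h = 0; h < lhs_height; ++h) {
    int32_t acc = 0;
    for (int w = 0; w < lhs_width; ++w) {
      acc += (lhs[h * lhs_width + w] - lhs_zero) * (rhs[w] - rhs_zero);
    }
    (*result)[h] = acc;
  }
}

int main() {
  const int K = 3, width = 2;
  const int za = 1, zb = 2;                // zero points (illustrative)
  std::vector<uint8_t> x = {10, 20, 30};   // 1 x K row vector
  // B stored column-major (K x width): column 0, then column 1.
  std::vector<uint8_t> B = {1, 2, 3, /*col 1:*/ 4, 5, 6};

  // gevm via gemv: treat column-major B as a row-major width x K matrix,
  // i.e. lhs = B, rhs = x, exactly as the dispatch in matmul.cc does.
  std::vector<int32_t> out(width);
  RefGemv(B, x, zb, za, width, K, &out);

  // Direct reference: out[j] = sum_k (x[k] - za) * (B[k][j] - zb),
  // where column j of B is contiguous at B[j * K ... j * K + K - 1].
  for (int j = 0; j < width; ++j) {
    int32_t ref = 0;
    for (int k = 0; k < K; ++k) {
      ref += (x[k] - za) * (B[j * K + k] - zb);
    }
    std::printf("col %d: gemv=%d ref=%d\n", j, out[j], ref);
  }
  return 0;
}

The symmetric trick covers the `width == 1` branch directly: a row-major
lhs with a length-K rhs vector is already the shape the kernel expects,
which is why only these two order combinations bypass gemmlowp.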