Commit 310d287b authored by 李滨

Merge branch 'u8gemm' into 'master'

Optimize low precision gemv

See merge request !882
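As background for the kernel added below: the optimized gemv keeps zero-point handling out of the hot loop by expanding dot(lhs - lhs_zero, rhs - rhs_zero) into dot(lhs, rhs) - rhs_zero * sum(lhs) - lhs_zero * sum(rhs) + lhs_zero * rhs_zero * width. A minimal scalar sketch of that identity follows; the helper name QuantizedDotReference is illustrative only and does not appear in the commit.

#include <cstdint>

// Illustrative helper (not from the MACE sources): scalar reference of the
// per-row computation that FixPointGemv below performs with NEON.
int32_t QuantizedDotReference(const uint8_t *lhs, const uint8_t *rhs,
                              int lhs_zero, int rhs_zero, int64_t width) {
  int64_t dot = 0, sum_lhs = 0, sum_rhs = 0;
  for (int64_t i = 0; i < width; ++i) {
    dot += lhs[i] * rhs[i];  // raw uint8 product, no zero-point math inside
    sum_lhs += lhs[i];
    sum_rhs += rhs[i];
  }
  // Fold the zero-point corrections back in once per row.
  return static_cast<int32_t>(dot - rhs_zero * sum_lhs - lhs_zero * sum_rhs +
                              lhs_zero * rhs_zero * width);
}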
......@@ -64,7 +64,8 @@ cc_library(
"ops_test_util.h",
"fixpoint.h",
"gemmlowp_util.h",
]
"arm/fixpoint_*.h",
],
) + if_opencl_enabled(glob([
"opencl/*.h",
"opencl/image/*.h",
......@@ -72,6 +73,7 @@ cc_library(
])) + if_quantize_enabled(glob([
"fixpoint.h",
"gemmlowp_util.h",
"arm/fixpoint_*.h",
])),
copts = [
"-Werror",
......@@ -101,11 +103,10 @@ cc_library(
]),
)
cc_library(
name = "ops",
srcs = [
"ops_registry.cc"
"ops_registry.cc",
],
hdrs = [
"ops_registry.h",
......@@ -138,12 +139,12 @@ cc_library(
cc_library(
name = "test",
testonly = 1,
hdrs = glob([
"*_test_util.h",
]),
srcs = [
"ops_test_util.cc",
],
hdrs = glob([
"*_test_util.h",
]),
copts = [
"-Werror",
"-Wextra",
......@@ -174,13 +175,12 @@ cc_test(
"opencl/*_test.cc",
],
exclude = [
"fixpoint_test.cc"
"fixpoint_test.cc",
],
) + if_quantize_enabled(glob(
[
"fixpoint_test.cc"
"fixpoint_test.cc",
],
)),
copts = [
"-Werror",
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_FIXPOINT_GEMM_H_
#define MACE_OPS_ARM_FIXPOINT_GEMM_H_
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
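// vaddvq_u32 (add across the lanes of a vector) is only provided by the
// AArch64 NEON intrinsics, so emulate it on 32-bit ARM.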
#define vaddvq_u32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
#endif
namespace mace {
namespace ops {
template<typename INPUT_TYPE, typename OUTPUT_TYPE>
void FixPointGemv(const INPUT_TYPE *lhs,
const INPUT_TYPE *rhs,
const int lhs_zero_point,
const int rhs_zero_point,
const index_t lhs_height,
const index_t lhs_width,
OUTPUT_TYPE *result);
template<>
void FixPointGemv<uint8_t, int32_t>(const uint8_t *lhs,
const uint8_t *rhs,
const int lhs_zero_point,
const int rhs_zero_point,
const index_t lhs_height,
const index_t lhs_width,
int32_t *result) {
int32_t zero_point_dot = lhs_zero_point * rhs_zero_point * lhs_width;
uint32_t sum_rhs = 0;
for (index_t i = 0; i < lhs_width; ++i) {
sum_rhs += rhs[i];
}
#pragma omp parallel for
for (index_t h = 0; h < lhs_height; ++h) {
const uint8_t *lhs_ptr = lhs + h * lhs_width;
const uint8_t *rhs_ptr = rhs;
int32_t *ret_ptr = result + h;
uint32_t dot = 0;
uint32_t sum_lhs = 0;
index_t w = 0;
#if defined(MACE_ENABLE_NEON)
uint32x4_t vo0_high_u32, vo0_low_u32, vo1_high_u32, vo1_low_u32;
vo0_high_u32 = vdupq_n_u32(0);
vo0_low_u32 = vdupq_n_u32(0);
vo1_high_u32 = vdupq_n_u32(0);
vo1_low_u32 = vdupq_n_u32(0);
uint32x4_t sum_lhs_low_u32, sum_lhs_high_u32;
sum_lhs_low_u32 = vdupq_n_u32(0);
sum_lhs_high_u32 = vdupq_n_u32(0);
for (; w <= lhs_width - 16; w += 16) {
uint8x8_t vl0_u8, vl1_u8;
uint8x8_t vr0_u8, vr1_u8;
uint16x8_t vl0_u16, vl1_u16;
uint16x8_t vr0_u16, vr1_u16;
vl0_u8 = vld1_u8(lhs_ptr);
vl1_u8 = vld1_u8(lhs_ptr + 8);
vr0_u8 = vld1_u8(rhs_ptr);
vr1_u8 = vld1_u8(rhs_ptr + 8);
vl0_u16 = vmovl_u8(vl0_u8);
vl1_u16 = vmovl_u8(vl1_u8);
vr0_u16 = vmovl_u8(vr0_u8);
vr1_u16 = vmovl_u8(vr1_u8);
vo0_high_u32 = vmlal_u16(vo0_high_u32,
vget_high_u16(vl0_u16),
vget_high_u16(vr0_u16));
vo0_low_u32 = vmlal_u16(vo0_low_u32,
vget_low_u16(vl0_u16),
vget_low_u16(vr0_u16));
vo1_high_u32 = vmlal_u16(vo1_high_u32,
vget_high_u16(vl1_u16),
vget_high_u16(vr1_u16));
vo1_low_u32 = vmlal_u16(vo1_low_u32,
vget_low_u16(vl1_u16),
vget_low_u16(vr1_u16));
// It can be precalculated if lhs is const, but in this case
// the computation is not the bottleneck.
sum_lhs_high_u32 += vaddl_u16(vget_high_u16(vl0_u16),
vget_high_u16(vl1_u16));
sum_lhs_low_u32 += vaddl_u16(vget_low_u16(vl0_u16),
vget_low_u16(vl1_u16));
lhs_ptr += 16;
rhs_ptr += 16;
}
vo0_low_u32 = vaddq_u32(vo0_high_u32, vo0_low_u32);
vo1_low_u32 = vaddq_u32(vo1_high_u32, vo1_low_u32);
vo0_low_u32 = vaddq_u32(vo0_low_u32, vo1_low_u32);
dot += vaddvq_u32(vo0_low_u32);
sum_lhs_low_u32 = vaddq_u32(sum_lhs_high_u32, sum_lhs_low_u32);
sum_lhs = vaddvq_u32(sum_lhs_low_u32);
#endif // MACE_ENABLE_NEON
for (; w < lhs_width; ++w) {
dot += (*lhs_ptr) * (*rhs_ptr);
sum_lhs += (*lhs_ptr);
++lhs_ptr;
++rhs_ptr;
}
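// Apply the zero-point corrections once per row (see zero_point_dot above).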
int32_t ret = dot - sum_lhs * rhs_zero_point - sum_rhs * lhs_zero_point
+ zero_point_dot;
*ret_ptr = ret;
} // h
}
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_FIXPOINT_GEMM_H_
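A minimal usage sketch for the new header (the buffer contents are placeholders, and it assumes mace::index_t is reachable through the header's own includes; the real caller is the matmul dispatch further down in this merge request):

#include <cstdint>
#include <vector>

#include "mace/ops/arm/fixpoint_gemm.h"

void ExampleGemv() {
  const mace::index_t height = 4;
  const mace::index_t width = 16;
  const int lhs_zero = 128, rhs_zero = 128;
  std::vector<uint8_t> lhs(height * width, 1);  // row-major quantized matrix
  std::vector<uint8_t> rhs(width, 1);           // quantized vector
  std::vector<int32_t> out(height);             // one int32 result per row
  mace::ops::FixPointGemv<uint8_t, int32_t>(
      lhs.data(), rhs.data(), lhs_zero, rhs_zero, height, width, out.data());
}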
......@@ -27,6 +27,7 @@
#ifdef MACE_ENABLE_QUANTIZE
#include "mace/ops/gemmlowp_util.h"
#include "mace/ops/arm/fixpoint_gemm.h"
#endif // MACE_ENABLE_QUANTIZE
#ifdef MACE_ENABLE_OPENCL
......@@ -169,9 +170,6 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
const index_t K,
const index_t width,
Tensor *C) {
auto gemm_context = context->device()->cpu_runtime()->GetGemmlowpContext();
MACE_CHECK_NOTNULL(gemm_context);
Tensor::MappingGuard guarda(A);
Tensor::MappingGuard guardb(B);
Tensor::MappingGuard guardc(C);
......@@ -180,6 +178,9 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
auto c_ptr_base = C->mutable_data<uint8_t>();
index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1,
std::multiplies<index_t>());
auto gemm_context = context->device()->cpu_runtime()->GetGemmlowpContext();
MACE_CHECK_NOTNULL(gemm_context);
index_t a_size = height * K;
index_t b_size = K * width;
index_t c_size = height * width;
......@@ -213,9 +214,6 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
const index_t K,
const index_t width,
Tensor *C) {
auto gemm_context = context->device()->cpu_runtime()->GetGemmlowpContext();
MACE_CHECK_NOTNULL(gemm_context);
Tensor::MappingGuard guarda(A);
Tensor::MappingGuard guardb(B);
Tensor::MappingGuard guardc(C);
......@@ -224,24 +222,53 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
auto c_ptr_base = C->mutable_data<int32_t>();
index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1,
std::multiplies<index_t>());
index_t a_size = height * K;
index_t b_size = K * width;
index_t c_size = height * width;
const auto output_pipeline = std::make_tuple();
for (index_t i = 0; i < batch; ++i) {
gemmlowp::MatrixMap<const uint8_t, AOrder>
a_matrix(a_ptr_base + i * a_size, height, K);
gemmlowp::MatrixMap<const uint8_t, BOrder>
b_matrix(b_ptr_base + i * b_size, K, width);
gemmlowp::MatrixMap <int32_t, gemmlowp::MapOrder::RowMajor>
c_matrix(c_ptr_base + i * c_size, height, width);
using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams;
gemmlowp::GemmWithOutputPipeline<uint8_t, int32_t, BitDepthParams>(
gemm_context, a_matrix, b_matrix, &c_matrix, -A->zero_point(),
-B->zero_point(), output_pipeline);
if (width == 1 && AOrder == gemmlowp::MapOrder::RowMajor) {
// gemv
for (index_t i = 0; i < batch; ++i) {
FixPointGemv(a_ptr_base + i * height * K,
b_ptr_base + i * K,
A->zero_point(),
B->zero_point(),
height,
K,
c_ptr_base + i * height);
}
} else if (height == 1 && BOrder == gemmlowp::MapOrder::ColMajor) {
// gevm
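// A (1 x K) times column-major B (K x width) has the same memory layout as a
// row-major (width x K) matrix times the vector A, so reuse the gemv kernel
// with the operands swapped.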
for (index_t i = 0; i < batch; ++i) {
FixPointGemv(b_ptr_base + i * K * width,
a_ptr_base + i * K,
B->zero_point(),
A->zero_point(),
width,
K,
c_ptr_base + i * width);
}
} else {
auto
gemm_context = context->device()->cpu_runtime()->GetGemmlowpContext();
MACE_CHECK_NOTNULL(gemm_context);
index_t a_size = height * K;
index_t b_size = K * width;
index_t c_size = height * width;
const auto output_pipeline = std::make_tuple();
for (index_t i = 0; i < batch; ++i) {
gemmlowp::MatrixMap<const uint8_t, AOrder>
a_matrix(a_ptr_base + i * a_size, height, K);
gemmlowp::MatrixMap<const uint8_t, BOrder>
b_matrix(b_ptr_base + i * b_size, K, width);
gemmlowp::MatrixMap <int32_t, gemmlowp::MapOrder::RowMajor>
c_matrix(c_ptr_base + i * c_size, height, width);
using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams;
gemmlowp::GemmWithOutputPipeline<uint8_t, int32_t, BitDepthParams>(
gemm_context, a_matrix, b_matrix, &c_matrix, -A->zero_point(),
-B->zero_point(), output_pipeline);
}
}
C->SetScale(A->scale() * B->scale());
......
......@@ -315,14 +315,20 @@ void QuantOutputInt32(const std::vector<index_t> &batch,
index_t batch_count = std::accumulate(batch.begin(), batch.end(), 1,
std::multiplies<index_t>());
if (transpose_a) {
net.AddRandomInput<CPU, float>("A", {batch_count, channels, height});
net.AddRandomInput<CPU, float>("A", {batch_count, channels, height},
false);
} else {
net.AddRandomInput<CPU, float>("A", {batch_count, height, channels});
net.AddRandomInput<CPU, float>("A", {batch_count, height, channels},
false);
}
if (transpose_b) {
net.AddRandomInput<CPU, float>("B", {batch_count, out_width, channels});
net.AddRandomInput<CPU, float>("B",
{batch_count, out_width, channels},
false);
} else {
net.AddRandomInput<CPU, float>("B", {batch_count, channels, out_width});
net.AddRandomInput<CPU, float>("B",
{batch_count, channels, out_width},
false);
}
OpDefBuilder("MatMul", "MatMulTest")
......@@ -411,11 +417,18 @@ TEST_F(MatMulOpTest, QuantOutputInt32) {
QuantOutputInt32({1}, 64, 128, 32, true, true);
QuantOutputInt32({1}, 64, 32, 128, true, true);
QuantOutputInt32({2, 3}, 64, 32, 128, true, true);
QuantOutputInt32({1}, 1, 30000, 256, false, true);
QuantOutputInt32({1}, 30000, 256, 1, false, false);
QuantOutputInt32({2}, 1, 256, 128, false, true);
QuantOutputInt32({3}, 128, 256, 1, false, false);
// UnAligned
QuantOutputInt32({2}, 3, 3, 3, false, false);
QuantOutputInt32({16}, 31, 61, 67, false, true);
QuantOutputInt32({31}, 31, 61, 67, true, false);
QuantOutputInt32({2, 3}, 31, 61, 67, true, true);
QuantOutputInt32({1}, 1, 30001, 253, false, true);
QuantOutputInt32({2}, 253, 300, 1, false, false);
}
// TODO(liyin): test transpose after implementing gpu runtime
......