Commit 03e21542 authored by: 李寅

Optimize gemv u8 and refactor ops.

Parent 5efbfbff
......@@ -11,6 +11,7 @@ load(
"//mace:mace.bzl",
"if_android",
"if_neon_enabled",
"if_neon_enabled_str",
"if_openmp_enabled",
"if_android_armv7",
"if_hexagon_enabled",
......@@ -41,8 +42,8 @@ cc_library(
"-DMACE_ENABLE_HEXAGON",
]),
deps = [
"//mace/public",
"//mace/ops",
"//mace/public",
],
alwayslink = 1,
)
......@@ -59,8 +60,8 @@ cc_binary(
linkshared = 1,
linkstatic = 0,
deps = [
"//mace/libmace:mace_version_script.lds",
"//mace/libmace",
"//mace/libmace:mace_version_script.lds",
],
)
......@@ -81,6 +82,8 @@ genrule(
srcs = [
"//mace/codegen:generated_version",
"//mace/core",
"//mace/ops:common",
"//mace/ops:ref_kernels",
"//mace/ops:internal_ops",
"//mace/ops",
"//mace/libmace",
......@@ -88,13 +91,20 @@ genrule(
"//mace/proto:mace_cc",
"@com_google_protobuf//:protobuf_lite",
] + if_opencl_enabled([
"//mace/ops:opencl_kernels",
"//mace/codegen:generated_opencl",
]) + if_neon_enabled([
"//mace/ops:arm_neon_kernels",
]),
outs = ["libmace.a"],
cmd = "tmp_mri_file=$$(mktemp mace-static-lib-mri.XXXXXXXXXX);" +
"mri_stream=$$(python $(location //mace/python/tools:archive_static_lib) " +
"$(locations //mace/codegen:generated_version) " +
"$(locations //mace/core:core) " +
"$(locations //mace/ops:common) " +
"$(locations //mace/ops:ref_kernels) " +
if_neon_enabled_str("$(locations //mace/ops:arm_neon_kernels) ") +
if_opencl_enabled_str("$(locations //mace/ops:opencl_kernels) ") +
"$(locations //mace/ops:internal_ops) " +
"$(locations //mace/ops:ops) " +
"$(locations //mace/libmace:libmace) " +
......
......@@ -17,7 +17,7 @@
#include <string>
#include <vector>
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/public/mace.h"
namespace mace {
......
......@@ -26,7 +26,7 @@
#include "mace/core/memory_optimizer.h"
#include "mace/core/net.h"
#include "mace/ops/ops_registry.h"
#include "mace/ops/transpose.h"
#include "mace/ops/common/transpose.h"
#include "mace/public/mace.h"
#ifdef MACE_ENABLE_OPENCL
......
......@@ -42,6 +42,12 @@ def if_neon_enabled(a):
"//conditions:default": [],
})
def if_neon_enabled_str(a):
return select({
"//mace:neon_enabled": a,
"//conditions:default": "",
})
def if_hexagon_enabled(a):
return select({
"//mace:hexagon_enabled": a,
......
......@@ -18,61 +18,323 @@ load(
)
cc_library(
name = "internal_ops",
name = "common",
srcs = glob(
[
"*.cc",
"arm/*.cc",
"common/*.cc",
],
),
hdrs = glob(
[
"common/*.h",
],
),
copts = [
"-Werror",
"-Wextra",
"-Wno-missing-field-initializers",
] + if_openmp_enabled([
"-fopenmp",
]) + if_neon_enabled([
"-DMACE_ENABLE_NEON",
]) + if_android_armv7([
"-mfpu=neon",
"-mfloat-abi=softfp",
]) + if_opencl_enabled([
"-DMACE_ENABLE_OPENCL",
]) + if_quantize_enabled([
"-DMACE_ENABLE_QUANTIZE",
]) + if_hexagon_enabled([
"-DMACE_ENABLE_HEXAGON",
]),
deps = [
"//mace/core",
],
)
cc_library(
name = "testing",
srcs = glob(
[
"testing/*.cc",
],
),
hdrs = glob(
[
"testing/*.h",
],
),
copts = [
"-Werror",
"-Wextra",
"-Wno-missing-field-initializers",
] + if_openmp_enabled([
"-fopenmp",
]) + if_neon_enabled([
"-DMACE_ENABLE_NEON",
]) + if_android_armv7([
"-mfpu=neon",
"-mfloat-abi=softfp",
]) + if_opencl_enabled([
"-DMACE_ENABLE_OPENCL",
]) + if_quantize_enabled([
"-DMACE_ENABLE_QUANTIZE",
]) + if_hexagon_enabled([
"-DMACE_ENABLE_HEXAGON",
]),
deps = [
"//mace/core",
"@gtest//:gtest",
],
)
cc_library(
name = "ref_kernels",
srcs = glob(
[
"ref/*.cc",
],
),
hdrs = glob(
[
"ref/*.h",
],
),
copts = [
"-Werror",
"-Wextra",
"-Wno-missing-field-initializers",
] + if_openmp_enabled([
"-fopenmp",
]) + if_neon_enabled([
"-DMACE_ENABLE_NEON",
]) + if_android_armv7([
"-mfpu=neon",
"-mfloat-abi=softfp",
]) + if_opencl_enabled([
"-DMACE_ENABLE_OPENCL",
]) + if_quantize_enabled([
"-DMACE_ENABLE_QUANTIZE",
]) + if_hexagon_enabled([
"-DMACE_ENABLE_HEXAGON",
]),
deps = [
":common",
"//mace/core",
],
)
# After the refactor, all ARM NEON kernels go here.
# They can be shipped for use in other products.
cc_library(
name = "arm_neon_kernels",
srcs = glob(
[
"arm/fp32/*.cc",
],
exclude = [
"*_test.cc",
"*_benchmark.cc",
"arm/*_test.cc",
"ops_registry.cc",
"ops_test_util.cc",
"buffer_transform.cc",
"lstm_cell.cc",
"quantize.cc",
"quantization_util.cc",
"arm/fp32/*_test.cc",
],
) + if_quantize_enabled(glob(
[
"arm/q8/*.cc",
],
) + if_opencl_enabled(glob(
exclude = [
"arm/q8/*_test.cc",
],
)),
hdrs = glob(
[
"arm/fp32/*.h",
],
) + if_quantize_enabled(glob(
[
"arm/q8/*.h",
],
)),
copts = [
"-Werror",
"-Wextra",
"-Wno-missing-field-initializers",
] + if_openmp_enabled([
"-fopenmp",
]) + if_neon_enabled([
"-DMACE_ENABLE_NEON",
]) + if_android_armv7([
"-mfpu=neon",
"-mfloat-abi=softfp",
]) + if_opencl_enabled([
"-DMACE_ENABLE_OPENCL",
]) + if_quantize_enabled([
"-DMACE_ENABLE_QUANTIZE",
]) + if_hexagon_enabled([
"-DMACE_ENABLE_HEXAGON",
]),
deps = [
":common",
"//mace/core",
],
)
# After the refactor, all GPU OpenCL kernels go here.
# They can be shipped for use in other products.
cc_library(
name = "opencl_kernels",
srcs = glob(
[
"opencl/*.cc",
"opencl/image/*.cc",
"opencl/buffer/*.cc",
"opencl/**/*.cc",
"buffer_transform.cc",
"lstm_cell.cc",
],
exclude = [
"opencl/*_test.cc",
],
)) + if_quantize_enabled(glob(
),
hdrs = glob(
[
"opencl/*.h",
"opencl/**/*.h",
],
),
copts = [
"-Werror",
"-Wextra",
"-Wno-missing-field-initializers",
] + if_openmp_enabled([
"-fopenmp",
]) + if_neon_enabled([
"-DMACE_ENABLE_NEON",
]) + if_android_armv7([
"-mfpu=neon",
"-mfloat-abi=softfp",
]) + if_opencl_enabled([
"-DMACE_ENABLE_OPENCL",
]) + if_quantize_enabled([
"-DMACE_ENABLE_QUANTIZE",
]) + if_hexagon_enabled([
"-DMACE_ENABLE_HEXAGON",
]),
deps = [
":common",
"//mace/core",
],
)
cc_library(
name = "arm_neon_kernels_test",
srcs = glob(
[
"arm/fp32/*_test.cc",
],
) + if_quantize_enabled(glob(
[
"arm/q8/*_test.cc",
],
)),
copts = [
"-Werror",
"-Wextra",
"-Wno-missing-field-initializers",
] + if_openmp_enabled([
"-fopenmp",
]) + if_neon_enabled([
"-DMACE_ENABLE_NEON",
]) + if_android_armv7([
"-mfpu=neon",
"-mfloat-abi=softfp",
]) + if_opencl_enabled([
"-DMACE_ENABLE_OPENCL",
]) + if_quantize_enabled([
"-DMACE_ENABLE_QUANTIZE",
]) + if_hexagon_enabled([
"-DMACE_ENABLE_HEXAGON",
]),
deps = [
":arm_neon_kernels",
":ref_kernels",
":testing",
"@gtest//:gtest",
],
alwayslink = 1,
)
cc_library(
name = "opencl_kernels_test",
srcs = glob(
[
"opencl/*_test.cc",
"opencl/**/*_test.cc",
],
),
copts = [
"-Werror",
"-Wextra",
"-Wno-missing-field-initializers",
] + if_openmp_enabled([
"-fopenmp",
]) + if_neon_enabled([
"-DMACE_ENABLE_NEON",
]) + if_android_armv7([
"-mfpu=neon",
"-mfloat-abi=softfp",
]) + if_opencl_enabled([
"-DMACE_ENABLE_OPENCL",
]) + if_quantize_enabled([
"-DMACE_ENABLE_QUANTIZE",
]) + if_hexagon_enabled([
"-DMACE_ENABLE_HEXAGON",
]),
deps = [
":opencl_kernels",
":ref_kernels",
":testing",
"@gtest//:gtest",
],
alwayslink = 1,
)
cc_library(
name = "internal_ops",
srcs = glob(
[
"*.cc",
"arm/*.cc", # remove it after refactor
],
exclude = [
"*_test.cc",
"*_benchmark.cc",
"ops_registry.cc",
"ops_test_util.cc",
"lstm_cell.cc", # TODO: move it into opencl
"buffer_transform.cc", # TODO: move it into opencl
"quantize.cc",
"quantization_util.cc",
"arm/*_test.cc", # remove it after refactor
],
)),
) + if_quantize_enabled(
glob(
[
"quantize.cc",
"quantization_util.cc",
],
),
),
hdrs = glob(
[
"*.h",
"arm/*.h",
"arm/*.h", # remove it after refactor
],
exclude = [
"ops_registry.h",
"ops_test_util.h",
"fixpoint.h",
"gemmlowp_util.h",
"arm/fixpoint_*.h",
"quantization_util.h",
],
) + if_opencl_enabled(glob([
"opencl/*.h",
"opencl/image/*.h",
"opencl/buffer/*.h",
])) + if_quantize_enabled(glob([
) + if_quantize_enabled(glob([
"fixpoint.h",
"gemmlowp_util.h",
"arm/fixpoint_*.h",
"quantization_util.h",
])),
copts = [
......@@ -85,7 +347,6 @@ cc_library(
"-DMACE_ENABLE_NEON",
]) + if_android_armv7([
"-mfpu=neon",
]) + if_android_armv7([
"-mfloat-abi=softfp",
]) + if_opencl_enabled([
"-DMACE_ENABLE_OPENCL",
......@@ -96,21 +357,30 @@ cc_library(
]),
linkopts = if_android(["-lm"]),
deps = [
":ref_kernels",
"//mace/core",
] + if_quantize_enabled([
"@tflite",
"@gemmlowp",
]) + if_neon_enabled([
":arm_neon_kernels",
]) + if_opencl_enabled([
":opencl_kernels",
]),
)
cc_library(
name = "ops",
srcs = [
"ops_registry.cc",
],
hdrs = [
"ops_registry.h",
],
srcs = glob(
[
"ops_registry.cc",
],
),
hdrs = glob(
[
"ops_registry.h",
],
),
copts = [
"-Werror",
"-Wextra",
......@@ -121,7 +391,6 @@ cc_library(
"-DMACE_ENABLE_NEON",
]) + if_android_armv7([
"-mfpu=neon",
]) + if_android_armv7([
"-mfloat-abi=softfp",
]) + if_opencl_enabled([
"-DMACE_ENABLE_OPENCL",
......@@ -161,6 +430,7 @@ cc_library(
]),
deps = [
"ops",
"testing",
"@gtest",
],
)
......@@ -171,8 +441,8 @@ cc_test(
srcs = glob(
[
"*_test.cc",
"arm/*_test.cc",
"opencl/*_test.cc",
"arm/*_test.cc", # remove it after refactor
"ops_test_util.cc",
],
exclude = [
"fixpoint_test.cc",
......@@ -203,9 +473,14 @@ cc_test(
linkopts = ["-fopenmp"],
linkstatic = 1,
deps = [
"test",
":ops",
":test",
"@gtest//:gtest_main",
],
] + if_neon_enabled([
":arm_neon_kernels_test",
]) + if_opencl_enabled([
":opencl_kernels_test",
]),
)
cc_test(
......@@ -233,7 +508,7 @@ cc_test(
linkopts = ["-fopenmp"],
linkstatic = 1,
deps = [
"test",
":ops",
"//mace/benchmark:statistics",
"//mace/core:test_benchmark_main",
"//third_party/eigen3",
......
......@@ -20,22 +20,13 @@
#include <string>
#include "mace/core/types.h"
#include "mace/ops/common/activation_type.h"
#include "mace/ops/arm/activation_neon.h"
#include "mace/utils/logging.h"
namespace mace {
namespace ops {
enum ActivationType {
NOOP = 0,
RELU = 1,
RELUX = 2,
PRELU = 3,
TANH = 4,
SIGMOID = 5,
LEAKYRELU = 6,
};
inline ActivationType StringToActivationType(const std::string type) {
if (type == "RELU") {
return ActivationType::RELU;
......
# Notes for Contributors and Roadmap
We are going to refactor and optimize the ARM-related kernels one step at a time.
The code will be organized so that the kernels for each data type are kept separate.
This way they are independent of each other, can be linked and shipped as submodules,
and we avoid writing macro boilerplate. We do not use a unified header file for each kernel
because we do not want to force developers to use exactly the same interface for kernels of
different data types at this level, although doing so is recommended when it is convenient.
A reference implementation sits directly in the `ops` directory and serves as the non-NEON
kernel; for simplicity it is not separated by data type.
Although the interface is kept flexible and can be defined according to demand,
input/output parameters should be of `Tensor` type instead of raw pointers, since a
`Tensor` carries more information that a kernel might use.
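To illustrate the convention described above, a hypothetical fp32 kernel following this layout might look like the sketch below (the class name `SomeKernel` is a placeholder and not part of this commit):

```cpp
// Hypothetical sketch only: shows the per-data-type kernel layout and the
// Tensor-based interface described above. Names are placeholders.
#include "mace/core/op_context.h"
#include "mace/core/tensor.h"

namespace mace {
namespace ops {
namespace arm {
namespace fp32 {

class SomeKernel {
 public:
  // Inputs and outputs are Tensor pointers, never raw buffers, so the kernel
  // can read shape/scale/zero-point metadata when it needs to.
  MaceStatus Compute(const OpContext *context,
                     const Tensor *input,
                     Tensor *output);
};

}  // namespace fp32
}  // namespace arm
}  // namespace ops
}  // namespace mace
```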
......@@ -15,7 +15,6 @@
#include <algorithm>
#include "mace/ops/arm/conv_winograd.h"
#include "mace/ops/gemm.h"
namespace mace {
namespace ops {
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_FIXPOINT_GEMM_H_
#define MACE_OPS_ARM_FIXPOINT_GEMM_H_
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
#define vaddvq_u32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
#endif
namespace mace {
namespace ops {
template<typename INPUT_TYPE, typename OUTPUT_TYPE>
void FixPointGemv(const INPUT_TYPE *lhs,
const INPUT_TYPE *rhs,
const int lhs_zero_point,
const int rhs_zero_point,
const index_t lhs_height,
const index_t lhs_width,
OUTPUT_TYPE *result);
template<>
void FixPointGemv<uint8_t, int32_t>(const uint8_t *lhs,
const uint8_t *rhs,
const int lhs_zero_point,
const int rhs_zero_point,
const index_t lhs_height,
const index_t lhs_width,
int32_t *result) {
int32_t zero_point_dot = lhs_zero_point * rhs_zero_point * lhs_width;
uint32_t sum_rhs = 0;
for (index_t i = 0; i < lhs_width; ++i) {
sum_rhs += rhs[i];
}
#pragma omp parallel for
for (index_t h = 0; h < lhs_height; ++h) {
const uint8_t *lhs_ptr = lhs + h * lhs_width;
const uint8_t *rhs_ptr = rhs;
int32_t *ret_ptr = result + h;
uint32_t dot = 0;
uint32_t sum_lhs = 0;
index_t w = 0;
#if defined(MACE_ENABLE_NEON)
uint32x4_t vo0_high_u32, vo0_low_u32, vo1_high_u32, vo1_low_u32;
vo0_high_u32 = vdupq_n_u32(0);
vo0_low_u32 = vdupq_n_u32(0);
vo1_high_u32 = vdupq_n_u32(0);
vo1_low_u32 = vdupq_n_u32(0);
uint32x4_t sum_lhs_low_u32, sum_lhs_high_u32;
sum_lhs_low_u32 = vdupq_n_u32(0);
sum_lhs_high_u32 = vdupq_n_u32(0);
for (; w <= lhs_width - 16; w += 16) {
uint8x8_t vl0_u8, vl1_u8;
uint8x8_t vr0_u8, vr1_u8;
uint16x8_t vl0_u16, vl1_u16;
uint16x8_t vr0_u16, vr1_u16;
vl0_u8 = vld1_u8(lhs_ptr);
vl1_u8 = vld1_u8(lhs_ptr + 8);
vr0_u8 = vld1_u8(rhs_ptr);
vr1_u8 = vld1_u8(rhs_ptr + 8);
vl0_u16 = vmovl_u8(vl0_u8);
vl1_u16 = vmovl_u8(vl1_u8);
vr0_u16 = vmovl_u8(vr0_u8);
vr1_u16 = vmovl_u8(vr1_u8);
vo0_high_u32 = vmlal_u16(vo0_high_u32,
vget_high_u16(vl0_u16),
vget_high_u16(vr0_u16));
vo0_low_u32 = vmlal_u16(vo0_low_u32,
vget_low_u16(vl0_u16),
vget_low_u16(vr0_u16));
vo1_high_u32 = vmlal_u16(vo1_high_u32,
vget_high_u16(vl1_u16),
vget_high_u16(vr1_u16));
vo1_low_u32 = vmlal_u16(vo1_low_u32,
vget_low_u16(vl1_u16),
vget_low_u16(vr1_u16));
// It could be precalculated if lhs is constant, but in this case
// the computation is not the bottleneck
sum_lhs_high_u32 += vaddl_u16(vget_high_u16(vl0_u16),
vget_high_u16(vl1_u16));
sum_lhs_low_u32 += vaddl_u16(vget_low_u16(vl0_u16),
vget_low_u16(vl1_u16));
lhs_ptr += 16;
rhs_ptr += 16;
}
vo0_low_u32 = vaddq_u32(vo0_high_u32, vo0_low_u32);
vo1_low_u32 = vaddq_u32(vo1_high_u32, vo1_low_u32);
vo0_low_u32 = vaddq_u32(vo0_low_u32, vo1_low_u32);
dot += vaddvq_u32(vo0_low_u32);
sum_lhs_low_u32 = vaddq_u32(sum_lhs_high_u32, sum_lhs_low_u32);
sum_lhs = vaddvq_u32(sum_lhs_low_u32);
#endif // MACE_ENABLE_NEON
for (; w < lhs_width; ++w) {
dot += (*lhs_ptr) * (*rhs_ptr);
sum_lhs += (*lhs_ptr);
++lhs_ptr;
++rhs_ptr;
}
int32_t ret = dot - sum_lhs * rhs_zero_point - sum_rhs * lhs_zero_point
+ zero_point_dot;
*ret_ptr = ret;
} // h
}
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_FIXPOINT_GEMM_H_
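The specialization above relies on the zero-point expansion of a quantized dot product, sum((l - lz) * (r - rz)) = sum(l * r) - rz * sum(l) - lz * sum(r) + W * lz * rz, which is why the inner loop only needs the raw uint8 products plus the lhs/rhs sums. A minimal scalar sketch of that identity (hypothetical helper, not part of this commit):

```cpp
// Scalar sketch of the identity used by FixPointGemv<uint8_t, int32_t>:
//   sum((l[i]-lz)*(r[i]-rz)) = sum(l[i]*r[i]) - rz*sum(l) - lz*sum(r) + W*lz*rz
#include <cstdint>

int32_t QuantizedDotScalar(const uint8_t *l, const uint8_t *r,
                           int lz, int rz, int64_t width) {
  int64_t dot = 0, sum_l = 0, sum_r = 0;
  for (int64_t i = 0; i < width; ++i) {
    dot += l[i] * r[i];    // raw products, zero points applied afterwards
    sum_l += l[i];
    sum_r += r[i];
  }
  return static_cast<int32_t>(dot - rz * sum_l - lz * sum_r + width * lz * rz);
}
```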
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/arm/fp32/gemv.h"
#include <arm_neon.h>
#include <algorithm>
#if !defined(__aarch64__)
#define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
#endif
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
MaceStatus Gemv::Compute(const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const Tensor *bias,
const index_t batch,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
Tensor *output) {
MACE_UNUSED(context);
Tensor::MappingGuard lhs_guard(lhs);
Tensor::MappingGuard rhs_guard(rhs);
Tensor::MappingGuard bias_guard(bias);
Tensor::MappingGuard output_guard(output);
const index_t h_block_size = 4;
const index_t h_block_count = RoundUpDiv(lhs_height, h_block_size);
const index_t w_block_size = 8;
const index_t w_block_count = lhs_width / w_block_size;
const index_t w_remain = lhs_width - w_block_size * w_block_count;
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t b = 0; b < batch; ++b) {
for (index_t h_block_idx = 0; h_block_idx < h_block_count; ++h_block_idx) {
// TODO(liyin): these could be hoisted outside the loop,
// but OpenMP limits the parameter count
const float *lhs_data = lhs->data<float>();
const float *rhs_data = rhs->data<float>();
const float *bias_data = nullptr;
if (bias) {
bias_data = bias->data<float>();
}
float *output_data = output->mutable_data<float>();
const float
*lhs_ptr = lhs_data
+ static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+ lhs_width * h_block_idx * h_block_size;
const float *rhs_ptr = rhs_data + b * lhs_width;
float
*ret_ptr = output_data + b * lhs_height + h_block_idx * h_block_size;
const index_t h_block_len =
std::min(h_block_size, lhs_height - h_block_idx * h_block_size);
const index_t h_offset = h_block_idx * h_block_size;
if (h_block_len == 4) {
float32x4_t vo0 = vdupq_n_f32(0);
float32x4_t vo1 = vdupq_n_f32(0);
float32x4_t vo2 = vdupq_n_f32(0);
float32x4_t vo3 = vdupq_n_f32(0);
index_t r_w_block_count = w_block_count;
// just to make the compiler happy
MACE_UNUSED(r_w_block_count);
// Register layout: (4x8) x (8x1)
//
// +----+
// |d16 |
// | . |
// Rhs +----+
// |d17 |
// | . |
// +----+
// |d18 |
// | . |
// +----+
// |d19 |
// | . |
// +----+
//
// | |
//
// Lhs | |
//
// +------+-----+-----+-----+ - - - - +----+
// | d0 . | d1 .| d2 .| d3 .| |vo0 |
// | d4 . | d5 .| d6 .| d7 .| |vo1 |
// | d8 . | d9 .| d10.| d11.| |vo2 |
// | d12. | d13.| d14.| d15.| |vo3 |
// +------+-----+-----+-----+ - - - - +----+
//
// Accumulator
//
#if !defined(__aarch64__)
asm volatile(
"cmp %[r_w_block_count], #0\n"
"beq 0f\n"
"lsl r5, %[lhs_width], #2\n"
"mov r0, %[rhs_ptr]\n"
"mov r1, %[lhs_ptr]\n"
"add r2, r1, r5\n"
"add r3, r2, r5\n"
"add r4, r3, r5\n"
// prologue
"vld1.f32 {d16-d17}, [r0]!\n"
"vld1.f32 {d18-d19}, [r0]!\n"
"vld1.f32 {d0-d1}, [r1]!\n"
"vld1.f32 {d2-d3}, [r1]!\n"
"vld1.f32 {d4-d5}, [r2]!\n"
"vld1.f32 {d6-d7}, [r2]!\n"
"vld1.f32 {d8-d9}, [r3]!\n"
"vld1.f32 {d10-d11}, [r3]!\n"
"vld1.f32 {d12-d13}, [r4]!\n"
"vld1.f32 {d14-d15}, [r4]!\n"
"subs %[r_w_block_count], #1\n"
"beq 1f\n"
"2: \n"
"vmla.f32 %q[vo0], q0, q8\n"
"vmla.f32 %q[vo1], q2, q8\n"
"vmla.f32 %q[vo2], q4, q8\n"
"vmla.f32 %q[vo3], q6, q8\n"
"vmla.f32 %q[vo0], q1, q9\n"
"vmla.f32 %q[vo1], q3, q9\n"
"vmla.f32 %q[vo2], q5, q9\n"
"vmla.f32 %q[vo3], q7, q9\n"
"subs %[r_w_block_count], #1\n"
"vld1.f32 {d0-d1}, [r1]!\n"
"vld1.f32 {d4-d5}, [r2]!\n"
"vld1.f32 {d8-d9}, [r3]!\n"
"vld1.f32 {d12-d13}, [r4]!\n"
"vld1.f32 {d16-d17}, [r0]!\n"
"vld1.f32 {d2-d3}, [r1]!\n"
"vld1.f32 {d6-d7}, [r2]!\n"
"vld1.f32 {d10-d11}, [r3]!\n"
"vld1.f32 {d14-d15}, [r4]!\n"
"vld1.f32 {d18-d19}, [r0]!\n"
"bne 2b\n"
// epilogue
"1:\n"
"vmla.f32 %q[vo0], q0, q8\n"
"vmla.f32 %q[vo1], q2, q8\n"
"vmla.f32 %q[vo2], q4, q8\n"
"vmla.f32 %q[vo3], q6, q8\n"
"vmla.f32 %q[vo0], q1, q9\n"
"vmla.f32 %q[vo1], q3, q9\n"
"vmla.f32 %q[vo2], q5, q9\n"
"vmla.f32 %q[vo3], q7, q9\n"
"0:\n"
: // outputs
[vo0] "+w"(vo0),
[vo1] "+w"(vo1),
[vo2] "+w"(vo2),
[vo3] "+w"(vo3),
[r_w_block_count] "+r"(r_w_block_count)
: // inputs
[lhs_ptr] "r"(lhs_ptr), [rhs_ptr] "r"(rhs_ptr),
[lhs_width] "r"(lhs_width)
: // clobbers
"cc", "memory", "r0", "r1", "r2", "r3", "r4", "r5",
"d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
"d21");
lhs_ptr += w_block_count * w_block_size;
rhs_ptr += w_block_count * w_block_size;
#else
for (index_t w_block_index = 0; w_block_index < w_block_count;
++w_block_index) {
float32x4_t vr0 = vld1q_f32(rhs_ptr);
float32x4_t vr0n = vld1q_f32(rhs_ptr + 4);
float32x4_t vl0 = vld1q_f32(lhs_ptr);
float32x4_t vl0n = vld1q_f32(lhs_ptr + 4);
vo0 = vmlaq_f32(vo0, vl0, vr0);
vo0 = vmlaq_f32(vo0, vl0n, vr0n);
const float *lhs_ptr1 = lhs_ptr + lhs_width;
float32x4_t vl1 = vld1q_f32(lhs_ptr1);
float32x4_t vl1n = vld1q_f32(lhs_ptr1 + 4);
vo1 = vmlaq_f32(vo1, vl1, vr0);
vo1 = vmlaq_f32(vo1, vl1n, vr0n);
const float *lhs_ptr2 = lhs_ptr1 + lhs_width;
float32x4_t vl2 = vld1q_f32(lhs_ptr2);
float32x4_t vl2n = vld1q_f32(lhs_ptr2 + 4);
vo2 = vmlaq_f32(vo2, vl2, vr0);
vo2 = vmlaq_f32(vo2, vl2n, vr0n);
const float *lhs_ptr3 = lhs_ptr2 + lhs_width;
float32x4_t vl3 = vld1q_f32(lhs_ptr3);
float32x4_t vl3n = vld1q_f32(lhs_ptr3 + 4);
vo3 = vmlaq_f32(vo3, vl3, vr0);
vo3 = vmlaq_f32(vo3, vl3n, vr0n);
lhs_ptr += 8;
rhs_ptr += 8;
}
#endif // __aarch64__
float32x4_t vo = {
vaddvq_f32(vo0),
vaddvq_f32(vo1),
vaddvq_f32(vo2),
vaddvq_f32(vo3)
};
for (index_t w = 0; w < w_remain; ++w) {
vo[0] += lhs_ptr[0] * rhs_ptr[0];
vo[1] += lhs_ptr[lhs_width] * rhs_ptr[0];
vo[2] += lhs_ptr[lhs_width * 2] * rhs_ptr[0];
vo[3] += lhs_ptr[lhs_width * 3] * rhs_ptr[0];
++lhs_ptr;
++rhs_ptr;
}
float32x4_t vbias = vdupq_n_f32(0);
if (bias) {
vbias = vld1q_f32(bias_data + h_offset);
}
vo = vaddq_f32(vo, vbias);
vst1q_f32(ret_ptr, vo);
} else { // h_block_len < 4
// TODO(liyin): handle the cases (1, 2, 3) separately here to accelerate
const float *tmp_lhs_ptr = lhs_ptr;
const float *tmp_rhs_ptr = rhs_ptr;
for (index_t h = 0; h < h_block_len; ++h) {
lhs_ptr = tmp_lhs_ptr + h * lhs_width;
rhs_ptr = tmp_rhs_ptr;
float32x4_t vo0 = vdupq_n_f32(0);
for (index_t w = 0; w < w_block_count; ++w) {
float32x4_t vr0 = vld1q_f32(rhs_ptr);
float32x4_t vr0n = vld1q_f32(rhs_ptr + 4);
float32x4_t vl0 = vld1q_f32(lhs_ptr);
float32x4_t vl0n = vld1q_f32(lhs_ptr + 4);
vo0 = vmlaq_f32(vo0, vl0, vr0);
vo0 = vmlaq_f32(vo0, vl0n, vr0n);
lhs_ptr += 8;
rhs_ptr += 8;
} // w
float s0 = vaddvq_f32(vo0) + (bias ? bias_data[h_offset + h] : 0);
for (index_t w = 0; w < w_remain; ++w) {
s0 += lhs_ptr[0] * rhs_ptr[0];
++lhs_ptr;
++rhs_ptr;
} // w
ret_ptr[h] = s0;
} // h
} // if
} // h_block_idx
} // b
return MaceStatus::MACE_SUCCESS;
}
#if defined(vaddvq_f32)
#undef vaddvq_f32
#endif
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
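For reference, per batch the blocked NEON/assembly path above computes output[h] = dot(lhs[h, :], rhs) + bias[h]; a plain scalar sketch is shown below (hypothetical function name; the actual portable reference kernel is `mace/ops/ref/gemv.h`):

```cpp
// Scalar sketch of what the blocked NEON/assembly path computes per batch.
#include <cstdint>

void GemvScalar(const float *lhs, const float *rhs, const float *bias,
                int64_t lhs_height, int64_t lhs_width, float *output) {
  for (int64_t h = 0; h < lhs_height; ++h) {
    float acc = bias ? bias[h] : 0.0f;
    for (int64_t w = 0; w < lhs_width; ++w) {
      acc += lhs[h * lhs_width + w] * rhs[w];
    }
    output[h] = acc;
  }
}
```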
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_FP32_GEMV_H_
#define MACE_OPS_ARM_FP32_GEMV_H_
#include "mace/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
namespace mace {
namespace ops {
namespace arm {
namespace fp32 {
class Gemv {
public:
Gemv() {}
~Gemv() {}
// Always row-major after transpose
MaceStatus Compute(
const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const Tensor *bias,
const index_t batch,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
Tensor *output);
};
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_FP32_GEMV_H_
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/fp32/gemv.h"
#include "mace/ops/ref/gemv.h"
#include "mace/ops/testing/test_utils.h"
namespace mace {
namespace ops {
namespace test {
void TestGemvFloat32(const index_t batch,
const index_t height,
const index_t width,
const bool lhs_batched) {
Tensor lhs(GetCPUAllocator(), DataType::DT_FLOAT);
Tensor rhs(GetCPUAllocator(), DataType::DT_FLOAT);
Tensor bias(GetCPUAllocator(), DataType::DT_FLOAT);
Tensor output(GetCPUAllocator(), DataType::DT_FLOAT);
lhs.Resize({lhs_batched ? batch : 1, height, width});
rhs.Resize({batch, width});
bias.Resize({height});
output.Resize({batch, height});
{
Tensor::MappingGuard lhs_guard(&lhs);
Tensor::MappingGuard rhs_guard(&rhs);
Tensor::MappingGuard bias_guard(&bias);
float *lhs_data = lhs.mutable_data<float>();
float *rhs_data = rhs.mutable_data<float>();
float *bias_data = bias.mutable_data<float>();
GenerateRandomRealTypeData<float>(lhs.shape(), lhs_data);
GenerateRandomRealTypeData<float>(rhs.shape(), rhs_data);
GenerateRandomRealTypeData<float>(bias.shape(), bias_data);
}
::mace::ops::arm::fp32::Gemv gemv;
gemv.Compute(nullptr,
&lhs,
&rhs,
&bias,
batch,
height,
width,
lhs_batched,
&output);
Tensor expected_output(GetCPUAllocator(), DataType::DT_FLOAT);
expected_output.Resize({batch, height});
::mace::ops::ref::Gemv<float> gemv_ref;
gemv_ref.Compute(nullptr,
&lhs,
&rhs,
&bias,
batch,
height,
width,
lhs_batched,
&expected_output);
Tensor::MappingGuard output_guard(&output);
Tensor::MappingGuard expected_guard(&expected_output);
const float *output_data = output.data<float>();
const float *expected_data = expected_output.data<float>();
for (index_t i = 0; i < output.size(); ++i) {
EXPECT_NEAR(expected_data[i], output_data[i], 0.001);
}
}
TEST(ArmGemv, TestGemvFloat32) {
TestGemvFloat32(1, 16, 4, true);
TestGemvFloat32(1, 16, 256, true);
TestGemvFloat32(2, 16, 256, true);
TestGemvFloat32(3, 63, 257, true);
TestGemvFloat32(1, 16, 4, false);
TestGemvFloat32(1, 16, 256, false);
TestGemvFloat32(2, 16, 256, false);
TestGemvFloat32(3, 63, 257, false);
}
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/arm/q8/gemv.h"
#include <arm_neon.h>
#include <algorithm>
#include "mace/utils/utils.h"
#include "mace/utils/quantize.h"
#if !defined(__aarch64__)
#define vmlal_high_s16(c, a, b) vmlal_s16(c, vget_high_s16(a), vget_high_s16(b))
#define vaddvq_s32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
#endif
namespace mace {
namespace ops {
namespace arm {
namespace q8 {
template<typename OUTPUT_TYPE>
MaceStatus Gemv<OUTPUT_TYPE>::Compute(const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const Tensor *bias,
const index_t batch,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
Tensor *output) {
MACE_UNUSED(context);
bool is_output_type_uint8 =
DataTypeToEnum<OUTPUT_TYPE>::value == DataType::DT_UINT8;
Tensor::MappingGuard lhs_guard(lhs);
Tensor::MappingGuard rhs_guard(rhs);
Tensor::MappingGuard bias_guard(bias);
Tensor::MappingGuard output_guard(output);
float output_multiplier_float = 0.0;
int32_t output_multiplier = 0;
int32_t output_shift = 0;
if (is_output_type_uint8) {
MACE_CHECK(output->scale() > 0, "output scale must not be zero");
output_multiplier_float = lhs->scale() * rhs->scale() / output->scale();
GetOutputMultiplierAndShift(lhs->scale(),
rhs->scale(),
output->scale(),
&output_multiplier,
&output_shift);
}
const index_t h_block_size = 4;
const index_t h_block_count = RoundUpDiv(lhs_height, h_block_size);
#pragma omp parallel for collapse(2) schedule(runtime)
for (index_t b = 0; b < batch; ++b) {
for (index_t h_block_idx = 0; h_block_idx < h_block_count; ++h_block_idx) {
// TODO(liyin): these could be hoisted outside the loop,
// but OpenMP limits the parameter count
const index_t w_block_size = 16;
const index_t w_block_count = lhs_width / w_block_size;
const index_t w_remain = lhs_width - w_block_size * w_block_count;
uint8_t lhs_zero_point = static_cast<uint8_t>(lhs->zero_point());
uint8_t rhs_zero_point = static_cast<uint8_t>(rhs->zero_point());
const uint8_t *lhs_data = lhs->data<uint8_t>();
const uint8_t *rhs_data = rhs->data<uint8_t>();
const int32_t *bias_data = nullptr;
if (bias) {
bias_data = bias->data<int32_t>();
}
OUTPUT_TYPE *output_data = output->mutable_data<OUTPUT_TYPE>();
int32x4_t voutput_multiplier = vdupq_n_s32(output_multiplier);
int32x4_t voutput_shift_left = vdupq_n_s32(-output_shift);
uint8x8_t
vlhs_zero_point = vdup_n_u8(lhs_zero_point);
uint8x8_t
vrhs_zero_point = vdup_n_u8(rhs_zero_point);
const uint8_t
*lhs_ptr = lhs_data
+ static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+ lhs_width * h_block_idx * h_block_size;
const uint8_t *rhs_ptr = rhs_data + b * lhs_width;
OUTPUT_TYPE
*ret_ptr = output_data + b * lhs_height + h_block_idx * h_block_size;
const index_t h_block_len =
std::min(h_block_size, lhs_height - h_block_idx * h_block_size);
const index_t h_offset = h_block_idx * h_block_size;
if (h_block_len == 4) {
int32x4_t vo0 = vdupq_n_s32(0);
int32x4_t vo1 = vdupq_n_s32(0);
int32x4_t vo2 = vdupq_n_s32(0);
int32x4_t vo3 = vdupq_n_s32(0);
index_t r_w_block_count = w_block_count;
// just to make the compiler happy
MACE_UNUSED(r_w_block_count);
// Register layout: (4x16) x (16x1)
//
// +----+
// |d16 |
// | . |
// | . |
// | . |
// Rhs +----+
// |d17 |
// | . |
// | . |
// | . |
// +----+
// |d18 |
// | . |
// | . |
// | . |
// +----+
// |d19 |
// | . |
// | . |
// | . |
// +----+
//
// | |
//
// Lhs | |
//
// +--------+--------+--------+--------+ - - - - +----+
// | d0 ... | d1 ... | d2 ... | d3 ... | |vo0 |
// | d4 ... | d5 ... | d6 ... | d7 ... | |vo1 |
// | d8 ... | d9 ... | d10... | d11... | |vo2 |
// | d12... | d13... | d14... | d15... | |vo3 |
// +--------+--------+--------+--------+ - - - - +----+
//
// Accumulator
//
#if !defined(__aarch64__)
asm volatile(
"cmp %[r_w_block_count], #0\n"
"beq 0f\n"
"mov r0, %[rhs_ptr]\n"
"mov r1, %[lhs_ptr]\n"
"add r2, r1, %[lhs_width]\n"
"add r3, r2, %[lhs_width]\n"
"add r4, r3, %[lhs_width]\n"
"vdup.u8 d20, %[rhs_zero_point]\n"
"vdup.u8 d21, %[lhs_zero_point]\n"
// prologue
"vld1.8 d16, [r0]!\n"
"vld1.8 d18, [r0]!\n"
"vld1.8 d0, [r1]!\n"
"vld1.8 d2, [r1]!\n"
"vld1.8 d4, [r2]!\n"
"vld1.8 d6, [r2]!\n"
"vld1.8 d8, [r3]!\n"
"vld1.8 d10, [r3]!\n"
"vld1.8 d12, [r4]!\n"
"vld1.8 d14, [r4]!\n"
"subs %[r_w_block_count], #1\n"
"beq 1f\n"
"2: \n"
"vsubl.u8 q8, d16, d20\n"
"vsubl.u8 q9, d18, d20\n"
"vsubl.u8 q0, d0, d21\n"
"vsubl.u8 q1, d2, d21\n"
"vsubl.u8 q2, d4, d21\n"
"vsubl.u8 q3, d6, d21\n"
"vsubl.u8 q4, d8, d21\n"
"vsubl.u8 q5, d10, d21\n"
"vsubl.u8 q6, d12, d21\n"
"vsubl.u8 q7, d14, d21\n"
"vmlal.s16 %q[vo0], d0, d16\n"
"vmlal.s16 %q[vo1], d4, d16\n"
"vmlal.s16 %q[vo2], d8, d16\n"
"vmlal.s16 %q[vo3], d12, d16\n"
"vld1.8 d0, [r1]!\n"
"vld1.8 d4, [r2]!\n"
"vld1.8 d8, [r3]!\n"
"vld1.8 d12, [r4]!\n"
"vld1.8 d16, [r0]!\n"
"vmlal.s16 %q[vo0], d2, d18\n"
"vmlal.s16 %q[vo1], d6, d18\n"
"vmlal.s16 %q[vo2], d10, d18\n"
"vmlal.s16 %q[vo3], d14, d18\n"
"vld1.8 d2, [r1]!\n"
"vld1.8 d6, [r2]!\n"
"vld1.8 d10, [r3]!\n"
"vld1.8 d14, [r4]!\n"
"vld1.8 d18, [r0]!\n"
"vmlal.s16 %q[vo0], d1, d17\n"
"vmlal.s16 %q[vo1], d5, d17\n"
"vmlal.s16 %q[vo2], d9, d17\n"
"vmlal.s16 %q[vo3], d13, d17\n"
"subs %[r_w_block_count], #1\n"
"vmlal.s16 %q[vo0], d3, d19\n"
"vmlal.s16 %q[vo1], d7, d19\n"
"vmlal.s16 %q[vo2], d11, d19\n"
"vmlal.s16 %q[vo3], d15, d19\n"
"bne 2b\n"
// epilogue
"1:\n"
"vsubl.u8 q8, d16, d20\n"
"vsubl.u8 q9, d18, d20\n"
"vsubl.u8 q0, d0, d21\n"
"vsubl.u8 q1, d2, d21\n"
"vsubl.u8 q2, d4, d21\n"
"vsubl.u8 q3, d6, d21\n"
"vsubl.u8 q4, d8, d21\n"
"vsubl.u8 q5, d10, d21\n"
"vsubl.u8 q6, d12, d21\n"
"vsubl.u8 q7, d14, d21\n"
"vmlal.s16 %q[vo0], d0, d16\n"
"vmlal.s16 %q[vo1], d4, d16\n"
"vmlal.s16 %q[vo2], d8, d16\n"
"vmlal.s16 %q[vo3], d12, d16\n"
"vmlal.s16 %q[vo0], d1, d17\n"
"vmlal.s16 %q[vo1], d5, d17\n"
"vmlal.s16 %q[vo2], d9, d17\n"
"vmlal.s16 %q[vo3], d13, d17\n"
"vmlal.s16 %q[vo0], d2, d18\n"
"vmlal.s16 %q[vo1], d6, d18\n"
"vmlal.s16 %q[vo2], d10, d18\n"
"vmlal.s16 %q[vo3], d14, d18\n"
"vmlal.s16 %q[vo0], d3, d19\n"
"vmlal.s16 %q[vo1], d7, d19\n"
"vmlal.s16 %q[vo2], d11, d19\n"
"vmlal.s16 %q[vo3], d15, d19\n"
"0:\n"
: // outputs
[vo0] "+w"(vo0),
[vo1] "+w"(vo1),
[vo2] "+w"(vo2),
[vo3] "+w"(vo3),
[r_w_block_count] "+r"(r_w_block_count)
: // inputs
[lhs_ptr] "r"(lhs_ptr), [rhs_ptr] "r"(rhs_ptr),
[lhs_width] "r"(lhs_width),
[lhs_zero_point] "r"(lhs_zero_point),
[rhs_zero_point] "r"(rhs_zero_point)
: // clobbers
"cc", "memory", "r0", "r1", "r2", "r3", "r4",
"d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10",
"d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20",
"d21");
lhs_ptr += w_block_count * w_block_size;
rhs_ptr += w_block_count * w_block_size;
#else
for (index_t w_block_index = 0; w_block_index < w_block_count;
++w_block_index) {
uint8x8_t vr0 = vld1_u8(rhs_ptr);
int16x8_t
vxr0 = vreinterpretq_s16_u16(vsubl_u8(vr0, vrhs_zero_point));
uint8x8_t vr0n = vld1_u8(rhs_ptr + 8);
int16x8_t
vxr0n = vreinterpretq_s16_u16(vsubl_u8(vr0n, vrhs_zero_point));
uint8x8_t vl0 = vld1_u8(lhs_ptr);
int16x8_t
vxl0 = vreinterpretq_s16_u16(vsubl_u8(vl0, vlhs_zero_point));
uint8x8_t vl0n = vld1_u8(lhs_ptr + 8);
int16x8_t
vxl0n = vreinterpretq_s16_u16(vsubl_u8(vl0n, vlhs_zero_point));
vo0 = vmlal_s16(vo0, vget_low_s16(vxl0), vget_low_s16(vxr0));
vo0 = vmlal_high_s16(vo0, vxl0, vxr0);
vo0 = vmlal_s16(vo0, vget_low_s16(vxl0n), vget_low_s16(vxr0n));
vo0 = vmlal_high_s16(vo0, vxl0n, vxr0n);
const uint8_t *lhs_ptr1 = lhs_ptr + lhs_width;
uint8x8_t vl1 = vld1_u8(lhs_ptr1);
int16x8_t
vxl1 = vreinterpretq_s16_u16(vsubl_u8(vl1, vlhs_zero_point));
uint8x8_t vl1n = vld1_u8(lhs_ptr1 + 8);
int16x8_t
vxl1n = vreinterpretq_s16_u16(vsubl_u8(vl1n, vlhs_zero_point));
vo1 = vmlal_s16(vo1, vget_low_s16(vxl1), vget_low_s16(vxr0));
vo1 = vmlal_high_s16(vo1, vxl1, vxr0);
vo1 = vmlal_s16(vo1, vget_low_s16(vxl1n), vget_low_s16(vxr0n));
vo1 = vmlal_high_s16(vo1, vxl1n, vxr0n);
const uint8_t *lhs_ptr2 = lhs_ptr1 + lhs_width;
uint8x8_t vl2 = vld1_u8(lhs_ptr2);
int16x8_t
vxl2 = vreinterpretq_s16_u16(vsubl_u8(vl2, vlhs_zero_point));
uint8x8_t vl2n = vld1_u8(lhs_ptr2 + 8);
int16x8_t
vxl2n = vreinterpretq_s16_u16(vsubl_u8(vl2n, vlhs_zero_point));
vo2 = vmlal_s16(vo2, vget_low_s16(vxl2), vget_low_s16(vxr0));
vo2 = vmlal_high_s16(vo2, vxl2, vxr0);
vo2 = vmlal_s16(vo2, vget_low_s16(vxl2n), vget_low_s16(vxr0n));
vo2 = vmlal_high_s16(vo2, vxl2n, vxr0n);
const uint8_t *lhs_ptr3 = lhs_ptr2 + lhs_width;
uint8x8_t vl3 = vld1_u8(lhs_ptr3);
int16x8_t
vxl3 = vreinterpretq_s16_u16(vsubl_u8(vl3, vlhs_zero_point));
uint8x8_t vl3n = vld1_u8(lhs_ptr3 + 8);
int16x8_t
vxl3n = vreinterpretq_s16_u16(vsubl_u8(vl3n, vlhs_zero_point));
vo3 = vmlal_s16(vo3, vget_low_s16(vxl3), vget_low_s16(vxr0));
vo3 = vmlal_high_s16(vo3, vxl3, vxr0);
vo3 = vmlal_s16(vo3, vget_low_s16(vxl3n), vget_low_s16(vxr0n));
vo3 = vmlal_high_s16(vo3, vxl3n, vxr0n);
lhs_ptr += 16;
rhs_ptr += 16;
}
#endif // __aarch64__
int32x4_t vo = {vaddvq_s32(vo0),
vaddvq_s32(vo1),
vaddvq_s32(vo2),
vaddvq_s32(vo3)};
for (index_t w = 0; w < w_remain; ++w) {
vo[0] +=
(lhs_ptr[0] - lhs_zero_point) * (rhs_ptr[0] - rhs_zero_point);
vo[1] += (lhs_ptr[lhs_width] - lhs_zero_point)
* (rhs_ptr[0] - rhs_zero_point);
vo[2] += (lhs_ptr[lhs_width * 2] - lhs_zero_point)
* (rhs_ptr[0] - rhs_zero_point);
vo[3] += (lhs_ptr[lhs_width * 3] - lhs_zero_point)
* (rhs_ptr[0] - rhs_zero_point);
++lhs_ptr;
++rhs_ptr;
}
int32x4_t vbias = vdupq_n_s32(0);
if (bias) {
vbias = vld1q_s32(bias_data + h_offset);
}
vo = vaddq_s32(vo, vbias);
if (is_output_type_uint8) {
int32x4_t vo_mul = vqrdmulhq_s32(vo, voutput_multiplier);
int32x4_t
fixup = vshrq_n_s32(vandq_s32(vo_mul, voutput_shift_left), 31);
int32x4_t fixed_up_x = vqaddq_s32(vo_mul, fixup);
int32x4_t
vo_rescale_int32 = vrshlq_s32(fixed_up_x, voutput_shift_left);
int16x4_t vo_rescale_int16 = vqmovn_s32(vo_rescale_int32);
uint8x8_t vo_rescale_uint8 =
vqmovun_s16(vcombine_s16(vo_rescale_int16, vo_rescale_int16));
ret_ptr[0] = vo_rescale_uint8[0];
ret_ptr[1] = vo_rescale_uint8[1];
ret_ptr[2] = vo_rescale_uint8[2];
ret_ptr[3] = vo_rescale_uint8[3];
} else {
ret_ptr[0] = vo[0];
ret_ptr[1] = vo[1];
ret_ptr[2] = vo[2];
ret_ptr[3] = vo[3];
}
} else { // h_block_len < 4
// TODO(liyin): handle the cases (1, 2, 3) separately here to accelerate
const uint8_t *tmp_lhs_ptr = lhs_ptr;
const uint8_t *tmp_rhs_ptr = rhs_ptr;
for (index_t h = 0; h < h_block_len; ++h) {
lhs_ptr = tmp_lhs_ptr + h * lhs_width;
rhs_ptr = tmp_rhs_ptr;
int32x4_t vo0 = vdupq_n_s32(0);
for (index_t w = 0; w < w_block_count; ++w) {
uint8x8_t vr0 = vld1_u8(rhs_ptr);
int16x8_t
vxr0 = vreinterpretq_s16_u16(vsubl_u8(vr0, vrhs_zero_point));
uint8x8_t vr0n = vld1_u8(rhs_ptr + 8);
int16x8_t
vxr0n = vreinterpretq_s16_u16(vsubl_u8(vr0n, vrhs_zero_point));
uint8x8_t vl0 = vld1_u8(lhs_ptr);
int16x8_t
vxl0 = vreinterpretq_s16_u16(vsubl_u8(vl0, vlhs_zero_point));
uint8x8_t vl0n = vld1_u8(lhs_ptr + 8);
int16x8_t
vxl0n = vreinterpretq_s16_u16(vsubl_u8(vl0n, vlhs_zero_point));
vo0 = vmlal_s16(vo0, vget_low_s16(vxl0), vget_low_s16(vxr0));
vo0 = vmlal_high_s16(vo0, vxl0, vxr0);
vo0 = vmlal_s16(vo0, vget_low_s16(vxl0n), vget_low_s16(vxr0n));
vo0 = vmlal_high_s16(vo0, vxl0n, vxr0n);
lhs_ptr += 16;
rhs_ptr += 16;
} // w
int32_t s0 = vaddvq_s32(vo0) + (bias ? bias_data[h_offset + h] : 0);
for (index_t w = 0; w < w_remain; ++w) {
s0 += (lhs_ptr[0] - lhs_zero_point) * (rhs_ptr[0] - rhs_zero_point);
++lhs_ptr;
++rhs_ptr;
} // w
if (is_output_type_uint8) {
ret_ptr[h] =
Saturate<uint8_t>(std::roundf(s0 * output_multiplier_float));
} else {
ret_ptr[h] = s0;
}
} // h
} // if
} // h_block_idx
} // b
return MaceStatus::MACE_SUCCESS;
}
template
class Gemv<uint8_t>;
template
class Gemv<int32_t>;
} // namespace q8
} // namespace arm
} // namespace ops
} // namespace mace
#if defined(vmlal_high_s16)
#undef vmlal_high_s16
#undef vaddvq_s32
#endif
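When the output type is uint8, each int32 accumulator is rescaled by lhs_scale * rhs_scale / output_scale; the NEON path uses the fixed-point multiplier and shift from GetOutputMultiplierAndShift, while the scalar tail uses output_multiplier_float directly. A rough scalar sketch of that requantization step (hypothetical helper, assuming the same saturating behavior as Saturate<uint8_t>):

```cpp
// Scalar sketch of the requantization applied to each int32 accumulator when
// the output type is uint8 (mirrors the output_multiplier_float tail path).
#include <algorithm>
#include <cmath>
#include <cstdint>

uint8_t RequantizeScalar(int32_t acc, float lhs_scale, float rhs_scale,
                         float output_scale) {
  const float multiplier = lhs_scale * rhs_scale / output_scale;
  const float scaled = std::roundf(acc * multiplier);
  // Saturate to the uint8 range, as Saturate<uint8_t> does in the kernel.
  return static_cast<uint8_t>(std::min(255.0f, std::max(0.0f, scaled)));
}
```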
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_Q8_GEMV_H_
#define MACE_OPS_ARM_Q8_GEMV_H_
#include "mace/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
namespace mace {
namespace ops {
namespace arm {
namespace q8 {
template<typename OUTPUT_TYPE>
class Gemv {
public:
Gemv() {}
~Gemv() {}
// Always row-major after transpose
MaceStatus Compute(
const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const Tensor *bias,
const index_t batch,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
Tensor *output);
};
} // namespace q8
} // namespace arm
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARM_Q8_GEMV_H_
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
#include "mace/ops/arm/q8/gemv.h"
#include "mace/ops/ref/gemv.h"
#include "mace/ops/testing/test_utils.h"
namespace mace {
namespace ops {
namespace test {
void TestGemvInt32(const index_t batch,
const index_t height,
const index_t width,
const bool lhs_batched) {
Tensor lhs(GetCPUAllocator(), DataType::DT_UINT8);
Tensor rhs(GetCPUAllocator(), DataType::DT_UINT8);
Tensor bias(GetCPUAllocator(), DataType::DT_INT32);
Tensor output(GetCPUAllocator(), DataType::DT_INT32);
lhs.SetScale(0.5);
rhs.SetScale(0.3);
lhs.SetZeroPoint(0);
rhs.SetZeroPoint(0);
lhs.Resize({lhs_batched ? batch : 1, height, width});
rhs.Resize({batch, width});
bias.Resize({height});
output.Resize({batch, height});
{
Tensor::MappingGuard lhs_guard(&lhs);
Tensor::MappingGuard rhs_guard(&rhs);
Tensor::MappingGuard bias_guard(&bias);
uint8_t *lhs_data = lhs.mutable_data<uint8_t>();
uint8_t *rhs_data = rhs.mutable_data<uint8_t>();
int32_t *bias_data = bias.mutable_data<int32_t>();
GenerateRandomIntTypeData<uint8_t>(lhs.shape(), lhs_data);
GenerateRandomIntTypeData<uint8_t>(rhs.shape(), rhs_data);
GenerateRandomIntTypeData<int32_t>(bias.shape(), bias_data);
}
mace::ops::arm::q8::Gemv<int32_t> gemv;
gemv.Compute(nullptr,
&lhs,
&rhs,
&bias,
batch,
height,
width,
lhs_batched,
&output);
Tensor expected_output(GetCPUAllocator(), DataType::DT_INT32);
expected_output.Resize({batch, height});
mace::ops::ref::Gemv<int32_t> gemv_ref;
gemv_ref.Compute(nullptr,
&lhs,
&rhs,
&bias,
batch,
height,
width,
lhs_batched,
&expected_output);
Tensor::MappingGuard output_guard(&output);
Tensor::MappingGuard expected_guard(&expected_output);
const int32_t *output_data = output.data<int32_t>();
const int32_t *expected_data = expected_output.data<int32_t>();
for (index_t i = 0; i < output.size(); ++i) {
EXPECT_EQ(expected_data[i], output_data[i]);
}
}
void TestGemvUint8(const index_t batch,
const index_t height,
const index_t width,
const bool lhs_batched) {
Tensor lhs(GetCPUAllocator(), DataType::DT_UINT8);
Tensor rhs(GetCPUAllocator(), DataType::DT_UINT8);
Tensor bias(GetCPUAllocator(), DataType::DT_INT32);
Tensor output(GetCPUAllocator(), DataType::DT_UINT8);
lhs.SetScale(0.5);
rhs.SetScale(0.3);
output.SetScale(0.6);
lhs.SetZeroPoint(23);
rhs.SetZeroPoint(45);
output.SetZeroPoint(57);
lhs.Resize({batch, height, width});
rhs.Resize({batch, width});
bias.Resize({height});
output.Resize({batch, height});
{
Tensor::MappingGuard lhs_guard(&lhs);
Tensor::MappingGuard rhs_guard(&rhs);
Tensor::MappingGuard bias_guard(&bias);
uint8_t *lhs_data = lhs.mutable_data<uint8_t>();
uint8_t *rhs_data = rhs.mutable_data<uint8_t>();
int32_t *bias_data = bias.mutable_data<int32_t>();
GenerateRandomIntTypeData<uint8_t>(lhs.shape(), lhs_data);
GenerateRandomIntTypeData<uint8_t>(rhs.shape(), rhs_data);
GenerateRandomIntTypeData<int32_t>(bias.shape(), bias_data);
}
mace::ops::arm::q8::Gemv<uint8_t> gemv;
gemv.Compute(nullptr,
&lhs,
&rhs,
&bias,
batch,
height,
width,
lhs_batched,
&output);
Tensor expected_output(GetCPUAllocator(), DataType::DT_INT32);
expected_output.SetScale(0.6);
expected_output.SetZeroPoint(57);
expected_output.Resize({batch, height});
mace::ops::ref::Gemv<uint8_t> gemv_ref;
gemv_ref.Compute(nullptr,
&lhs,
&rhs,
&bias,
batch,
height,
width,
lhs_batched,
&expected_output);
Tensor::MappingGuard output_guard(&output);
Tensor::MappingGuard expected_guard(&expected_output);
const uint8_t *output_data = output.data<uint8_t>();
const uint8_t *expected_data = expected_output.data<uint8_t>();
for (index_t i = 0; i < output.size(); ++i) {
EXPECT_EQ(expected_data[i], output_data[i]);
}
}
TEST(ArmGemv, TestGemvInt32) {
TestGemvInt32(1, 16, 4, true);
TestGemvInt32(1, 16, 256, true);
TestGemvInt32(2, 16, 256, true);
TestGemvInt32(3, 63, 257, true);
TestGemvInt32(1, 16, 4, false);
TestGemvInt32(1, 16, 256, false);
TestGemvInt32(2, 16, 256, false);
TestGemvInt32(3, 63, 257, false);
}
TEST(ArmGemv, TestGemvUint8) {
TestGemvUint8(1, 16, 4, true);
TestGemvUint8(1, 16, 256, true);
TestGemvUint8(2, 16, 256, true);
TestGemvUint8(3, 63, 257, true);
TestGemvUint8(1, 16, 4, false);
TestGemvUint8(1, 16, 256, false);
TestGemvUint8(2, 16, 256, false);
TestGemvUint8(3, 63, 257, false);
}
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_COMMON_ACTIVATION_TYPE_H_
#define MACE_OPS_COMMON_ACTIVATION_TYPE_H_
namespace mace {
namespace ops {
enum ActivationType {
NOOP = 0,
RELU = 1,
RELUX = 2,
PRELU = 3,
TANH = 4,
SIGMOID = 5,
LEAKYRELU = 6,
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_COMMON_ACTIVATION_TYPE_H_
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/conv_pool_2d_util.h"
#include <algorithm>
#include <cmath>
......
......@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_CONV_POOL_2D_UTIL_H_
#define MACE_OPS_CONV_POOL_2D_UTIL_H_
#ifndef MACE_OPS_COMMON_CONV_POOL_2D_UTIL_H_
#define MACE_OPS_COMMON_CONV_POOL_2D_UTIL_H_
#include "mace/core/tensor.h"
......@@ -116,4 +116,4 @@ MaceStatus ConstructNHWCInputWithPadding(const Tensor *input,
} // namespace ops
} // namespace mace
#endif // MACE_OPS_CONV_POOL_2D_UTIL_H_
#endif // MACE_OPS_COMMON_CONV_POOL_2D_UTIL_H_
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/common/transpose.h"
#include <algorithm>
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#include "mace/core/types.h"
#include "mace/utils/logging.h"
namespace mace {
namespace ops {
namespace {
void TransposeNHWCToNCHWC3(const float *input,
float *output,
const index_t height,
const index_t width) {
index_t image_size = height * width;
#pragma omp parallel for
for (index_t h = 0; h < height; ++h) {
index_t in_offset = h * width * 3;
index_t out_offset = h * width;
#if defined(MACE_ENABLE_NEON)
index_t w;
for (w = 0; w + 3 < width; w += 4) {
float32x4x3_t vi = vld3q_f32(input + in_offset);
vst1q_f32(output + out_offset, vi.val[0]);
vst1q_f32(output + out_offset + image_size, vi.val[1]);
vst1q_f32(output + out_offset + image_size * 2, vi.val[2]);
in_offset += 12;
out_offset += 4;
}
for (; w < width; ++w) {
for (index_t c = 0; c < 3; ++c) {
output[h * width + image_size * c + w] =
input[h * width * 3 + w * 3 + c];
}
}
#else
for (index_t w = 0; w < width; ++w) {
for (index_t c = 0; c < 3; ++c) {
output[out_offset + c * image_size + w] = input[in_offset + w * 3 + c];
}
}
#endif
}
}
void TransposeNCHWToNHWCC2(const float *input,
float *output,
const index_t height,
const index_t width) {
index_t image_size = height * width;
#pragma omp parallel for
for (index_t h = 0; h < height; ++h) {
index_t in_offset = h * width;
index_t out_offset = h * width * 2;
#if defined(MACE_ENABLE_NEON)
index_t w;
for (w = 0; w + 3 < width; w += 4) {
float32x4_t vi0 = vld1q_f32(input + in_offset);
float32x4_t vi1 = vld1q_f32(input + in_offset + image_size);
float32x4x2_t vi = {vi0, vi1};
vst2q_f32(output + out_offset, vi);
in_offset += 4;
out_offset += 8;
}
for (; w < width; ++w) {
for (index_t c = 0; c < 2; ++c) {
output[h * width * 2 + w * 2 + c] =
input[h * width + image_size * c + w];
}
}
#else
for (index_t w = 0; w < width; ++w) {
for (index_t c = 0; c < 2; ++c) {
output[out_offset + w * 2 + c] = input[in_offset + c * image_size + w];
}
}
#endif
}
}
} // namespace
MaceStatus Transpose(const float *input,
const std::vector<int64_t> &input_shape,
const std::vector<int> &dst_dims,
float *output) {
MACE_CHECK((input_shape.size() == 2 && dst_dims.size() == 2) ||
(input_shape.size() == 4 && dst_dims.size() == 4),
"Only support 2D or 4D transpose");
std::vector<index_t> output_shape;
for (size_t i = 0; i < dst_dims.size(); ++i) {
output_shape.push_back(input_shape[dst_dims[i]]);
}
if (input_shape.size() == 2) {
MACE_CHECK(dst_dims[0] == 1 && dst_dims[1] == 0, "no need to transform");
index_t height = input_shape[0];
index_t width = input_shape[1];
index_t stride_i = height;
index_t stride_j = width;
index_t tile_size = height > 512 || width > 512 ? 64 : 32;
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < height; i += tile_size) {
for (index_t j = 0; j < width; j += tile_size) {
index_t end_i = std::min(i + tile_size, height);
index_t end_j = std::min(j + tile_size, width);
for (index_t tile_i = i; tile_i < end_i; ++tile_i) {
for (index_t tile_j = j; tile_j < end_j; ++tile_j) {
output[tile_j * stride_i + tile_i] =
input[tile_i * stride_j + tile_j];
}
}
}
}
} else if (input_shape.size() == 4) {
std::vector<int> transpose_order_from_NHWC_to_NCHW{0, 3, 1, 2};
std::vector<int> transpose_order_from_NCHW_to_NHWC{0, 2, 3, 1};
index_t batch_size = input_shape[1] * input_shape[2] * input_shape[3];
if (dst_dims == transpose_order_from_NHWC_to_NCHW && input_shape[3] == 3) {
for (index_t b = 0; b < input_shape[0]; ++b) {
TransposeNHWCToNCHWC3(input + b * batch_size,
output + b * batch_size,
input_shape[1],
input_shape[2]);
}
} else if (dst_dims == transpose_order_from_NCHW_to_NHWC
&& input_shape[1] == 2) {
for (index_t b = 0; b < input_shape[0]; ++b) {
TransposeNCHWToNHWCC2(input + b * batch_size,
output + b * batch_size,
input_shape[2],
input_shape[3]);
}
} else if (dst_dims == std::vector<int>{0, 2, 1, 3}) {
index_t height = input_shape[1];
index_t width = input_shape[2];
index_t channel = input_shape[3];
index_t channel_raw_size = channel * sizeof(float);
index_t stride_i = height;
index_t stride_j = width;
index_t tile_size = std::max(static_cast<index_t>(1),
static_cast<index_t>(std::sqrt(
8 * 1024 / channel)));
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < height; i += tile_size) {
for (index_t j = 0; j < width; j += tile_size) {
index_t end_i = std::min(i + tile_size, height);
index_t end_j = std::min(j + tile_size, width);
for (index_t tile_i = i; tile_i < end_i; ++tile_i) {
for (index_t tile_j = j; tile_j < end_j; ++tile_j) {
memcpy(output + (tile_j * stride_i + tile_i) * channel,
input + (tile_i * stride_j + tile_j) * channel,
channel_raw_size);
}
}
}
}
} else {
std::vector<index_t>
in_stride{input_shape[1] * input_shape[2] * input_shape[3],
input_shape[2] * input_shape[3], input_shape[3], 1};
std::vector<index_t>
out_stride{output_shape[1] * output_shape[2] * output_shape[3],
output_shape[2] * output_shape[3], output_shape[3], 1};
std::vector<index_t> idim(4, 0);
std::vector<index_t> odim(4, 0);
for (odim[0] = 0; odim[0] < output_shape[0]; ++odim[0]) {
for (odim[1] = 0; odim[1] < output_shape[1]; ++odim[1]) {
for (odim[2] = 0; odim[2] < output_shape[2]; ++odim[2]) {
for (odim[3] = 0; odim[3] < output_shape[3]; ++odim[3]) {
idim[dst_dims[0]] = odim[0];
idim[dst_dims[1]] = odim[1];
idim[dst_dims[2]] = odim[2];
idim[dst_dims[3]] = odim[3];
output[odim[0] * out_stride[0] + odim[1] * out_stride[1]
+ odim[2] * out_stride[2] + odim[3]] =
input[idim[0] * in_stride[0] + idim[1] * in_stride[1]
+ idim[2] * in_stride[2] + idim[3]];
}
}
}
}
}
} else {
MACE_NOT_IMPLEMENTED;
}
return MaceStatus::MACE_SUCCESS;
}
} // namespace ops
} // namespace mace
......@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_TRANSPOSE_H_
#define MACE_OPS_TRANSPOSE_H_
#ifndef MACE_OPS_COMMON_TRANSPOSE_H_
#define MACE_OPS_COMMON_TRANSPOSE_H_
#include <vector>
......@@ -30,4 +30,4 @@ MaceStatus Transpose(const float *input,
} // namespace ops
} // namespace mace
#endif // MACE_OPS_TRANSPOSE_H_
#endif // MACE_OPS_COMMON_TRANSPOSE_H_
......@@ -30,7 +30,7 @@
#include "mace/ops/arm/conv_2d_neon.h"
#include "mace/ops/arm/conv_winograd.h"
#include "mace/ops/conv_pool_2d_base.h"
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/utils/utils.h"
#ifdef MACE_ENABLE_QUANTIZE
......
......@@ -16,7 +16,7 @@
#include "mace/benchmark/statistics.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
......
......@@ -15,7 +15,7 @@
#include <fstream>
#include <vector>
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
......
......@@ -18,7 +18,7 @@
#include <vector>
#include "mace/core/operator.h"
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/conv_pool_2d_util.h"
namespace mace {
namespace ops {
......
......@@ -22,7 +22,7 @@
#include "mace/core/operator.h"
#include "mace/core/types.h"
#include "mace/ops/activation.h"
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/conv_pool_2d_util.h"
namespace mace {
namespace ops {
......
......@@ -16,7 +16,7 @@
#include "mace/benchmark/statistics.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
......
......@@ -16,7 +16,7 @@
#include <vector>
#include "mace/ops/deconv_2d.h"
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
......
......@@ -16,7 +16,7 @@
#include "mace/benchmark/statistics.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
......
......@@ -20,13 +20,19 @@
#include "mace/core/operator.h"
#include "mace/core/tensor.h"
#include "mace/ops/activation.h"
#include "mace/ops/gemm.h"
#ifdef MACE_ENABLE_NEON
#include "mace/ops/arm/fp32/gemv.h"
#ifdef MACE_ENABLE_QUANTIZE
#include "mace/ops/gemmlowp_util.h"
#include "mace/ops/quantization_util.h"
#include "mace/ops/arm/q8/gemv.h"
#endif // MACE_ENABLE_QUANTIZE
#else
#include "mace/ops/ref/gemv.h"
#endif // MACE_ENABLE_NEON
#ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/buffer_transformer.h"
#include "mace/ops/opencl/image/fully_connected.h"
......@@ -41,10 +47,10 @@ class FullyConnectedOpBase : public Operation {
: Operation(context),
activation_(ops::StringToActivationType(
Operation::GetOptionalArg<std::string>("activation",
"NOOP"))),
"NOOP"))),
relux_max_limit_(Operation::GetOptionalArg<float>("max_limit", 0.0f)),
leakyrelu_coefficient_(Operation::GetOptionalArg<float>(
"leakyrelu_coefficient", 0.0f)) {}
"leakyrelu_coefficient", 0.0f)) {}
protected:
const ActivationType activation_;
const float relux_max_limit_;
......@@ -54,10 +60,10 @@ class FullyConnectedOpBase : public Operation {
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
template <DeviceType D, class T>
template<DeviceType D, class T>
class FullyConnectedOp;
template <>
template<>
class FullyConnectedOp<DeviceType::CPU, float> : public FullyConnectedOpBase {
public:
explicit FullyConnectedOp(OpConstructContext *context)
......@@ -84,38 +90,37 @@ class FullyConnectedOp<DeviceType::CPU, float> : public FullyConnectedOpBase {
}
std::vector<index_t> output_shape = {input->dim(0), weight->dim(0), 1, 1};
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
const index_t N = output->dim(0);
const index_t batch = output->dim(0);
const index_t input_size = weight->dim(1) * weight->dim(2) * weight->dim(3);
const index_t output_size = weight->dim(0);
Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_weight(weight);
gemv_.Compute(context,
weight,
input,
bias,
batch,
output_size,
input_size,
false,
output);
Tensor::MappingGuard guard_output(output);
const float *input_ptr = input->data<float>();
const float *weight_ptr = weight->data<float>();
float *output_ptr = output->mutable_data<float>();
Gemv(weight_ptr, input_ptr, N, input_size, output_size, output_ptr);
if (bias) {
Tensor::MappingGuard guard_bias(bias);
const float *bias_ptr = bias == nullptr ? nullptr : bias->data<float>();
for (int i = 0; i < N; ++i) {
for (int j = 0; j < output_size; ++j) {
output_ptr[j + i * output_size] += bias_ptr[j];
}
}
}
DoActivation(output_ptr, output_ptr, output->size(), activation_,
relux_max_limit_, leakyrelu_coefficient_);
return MaceStatus::MACE_SUCCESS;
}
private:
#ifdef MACE_ENABLE_NEON
arm::fp32::Gemv gemv_;
#else
ref::Gemv<float> gemv_;
#endif // MACE_ENABLE_NEON
};
#ifdef MACE_ENABLE_QUANTIZE
template <>
template<>
class FullyConnectedOp<DeviceType::CPU, uint8_t>
: public FullyConnectedOpBase {
public:
......@@ -145,44 +150,28 @@ class FullyConnectedOp<DeviceType::CPU, uint8_t>
std::vector<index_t> output_shape = {input->dim(0), 1, 1, weight->dim(0)};
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
const int N = static_cast<int>(output->dim(0));
const int batch = static_cast<int>(output->dim(0));
const int input_size =
static_cast<int>(weight->dim(1) * weight->dim(2) * weight->dim(3));
const int output_size = static_cast<int>(weight->dim(0));
Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_weight(weight);
Tensor::MappingGuard guard_output(output);
auto input_ptr = input->data<uint8_t>();
auto weight_ptr = weight->data<uint8_t>();
auto output_ptr = output->mutable_data<uint8_t>();
auto bias_ptr = GetBiasData(bias,
input->scale(),
weight->scale(),
output_size,
&bias_);
gemmlowp::MatrixMap<const uint8_t, gemmlowp::MapOrder::RowMajor>
weight_matrix(weight_ptr, output_size, input_size);
gemmlowp::MatrixMap<const uint8_t, gemmlowp::MapOrder::ColMajor>
input_matrix(input_ptr, input_size, N);
gemmlowp::MatrixMap<uint8_t, gemmlowp::MapOrder::ColMajor>
output_matrix(output_ptr, output_size, N);
const auto &output_pipeline = GemmlowpOutputPipeline::Make(
bias_ptr, output_size, weight->scale(), input->scale(), output->scale(),
output->zero_point());
using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams;
gemmlowp::GemmWithOutputPipeline<uint8_t, uint8_t, BitDepthParams>(
gemm_context, weight_matrix, input_matrix, &output_matrix,
-weight->zero_point(), -input->zero_point(), output_pipeline);
gemv_.Compute(context,
weight,
input,
bias,
batch,
output_size,
input_size,
false,
output);
return MaceStatus::MACE_SUCCESS;
}
private:
std::vector<int32_t> bias_;
#ifdef MACE_ENABLE_NEON
::mace::ops::arm::q8::Gemv<uint8_t> gemv_;
#else
ref::Gemv<uint8_t> gemv_;
#endif // MACE_ENABLE_NEON
};
#endif // MACE_ENABLE_QUANTIZE
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cstring>
#include <vector>
#include "mace/core/tensor.h"
#include "mace/core/runtime/cpu/cpu_runtime.h"
#include "mace/ops/gemm.h"
/**
 * Gemm does fast batched matrix multiplication.
 * It is optimized for arm64-v8a and armeabi-v7a using NEON.
 *
 * We adopt two-level tiling to make better use of the L1 cache and registers.
 * For register tiling, a function such as GemmXYZ computes
 * matrix[X, Y] * matrix[Y, Z] with all operands held in registers.
 * For cache tiling, we compute one block of the multiplication at a time,
 * with the two input blocks and the output block fitting in the L1 cache.
 */
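// Illustrative reference only (not part of the optimized path): every
// register kernel below, e.g. Gemm884 with height = 8, K = 8, width = 4,
// accumulates
//   for (i = 0; i < height; ++i)
//     for (j = 0; j < width; ++j)
//       for (k = 0; k < K; ++k)
//         C[i * stride_c + j] += A[i * stride_a + k] * B[k * stride_b + j];
// but keeps the A, B and C tiles in NEON registers instead of memory.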
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
#define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
#endif
namespace mace {
namespace ops {
namespace {
inline void GemmBlock(const float *A,
const float *B,
const index_t height,
const index_t K,
const index_t width,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *C) {
for (int i = 0; i < height; ++i) {
for (int j = 0; j < width; ++j) {
for (int k = 0; k < K; ++k) {
C[i * stride_c + j] += A[i * stride_a + k] * B[k * stride_b + j];
}
}
}
}
#if defined(MACE_ENABLE_NEON)
#if defined(__aarch64__)
#define MACE_GEMM_PART_CAL_8(RC, RA, RAN) \
c##RC = vfmaq_laneq_f32(c##RC, b0, a##RA, 0); \
c##RC = vfmaq_laneq_f32(c##RC, b1, a##RA, 1); \
c##RC = vfmaq_laneq_f32(c##RC, b2, a##RA, 2); \
c##RC = vfmaq_laneq_f32(c##RC, b3, a##RA, 3); \
c##RC = vfmaq_laneq_f32(c##RC, b4, a##RAN, 0); \
c##RC = vfmaq_laneq_f32(c##RC, b5, a##RAN, 1); \
c##RC = vfmaq_laneq_f32(c##RC, b6, a##RAN, 2); \
c##RC = vfmaq_laneq_f32(c##RC, b7, a##RAN, 3);
#else
#define MACE_GEMM_PART_CAL_8(RC, RA, RAN) \
c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RA), 0); \
c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RA), 1); \
c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RA), 0); \
c##RC = vmlaq_lane_f32(c##RC, b3, vget_high_f32(a##RA), 1); \
c##RC = vmlaq_lane_f32(c##RC, b4, vget_low_f32(a##RAN), 0); \
c##RC = vmlaq_lane_f32(c##RC, b5, vget_low_f32(a##RAN), 1); \
c##RC = vmlaq_lane_f32(c##RC, b6, vget_high_f32(a##RAN), 0); \
c##RC = vmlaq_lane_f32(c##RC, b7, vget_high_f32(a##RAN), 1);
#endif
#endif
#if defined(MACE_ENABLE_NEON)
#if defined(__aarch64__)
#define MACE_GEMM_PART_CAL_4(RC) \
c##RC = vfmaq_laneq_f32(c##RC, b0, a##RC, 0); \
c##RC = vfmaq_laneq_f32(c##RC, b1, a##RC, 1); \
c##RC = vfmaq_laneq_f32(c##RC, b2, a##RC, 2); \
c##RC = vfmaq_laneq_f32(c##RC, b3, a##RC, 3);
#else
#define MACE_GEMM_PART_CAL_4(RC) \
c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RC), 0); \
c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RC), 1); \
c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RC), 0); \
c##RC = vmlaq_lane_f32(c##RC, b3, vget_high_f32(a##RC), 1);
#endif
#endif
inline void Gemm144(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
MACE_UNUSED(stride_a);
MACE_UNUSED(stride_c);
float32x4_t a0;
float32x4_t b0, b1, b2, b3;
float32x4_t c0;
a0 = vld1q_f32(a_ptr);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
c0 = vld1q_f32(c_ptr);
MACE_GEMM_PART_CAL_4(0);
vst1q_f32(c_ptr, c0);
#else
GemmBlock(a_ptr, b_ptr, 1, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm244(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1;
float32x4_t b0, b1, b2, b3;
float32x4_t c0, c1;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 1 * stride_a);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
MACE_GEMM_PART_CAL_4(0);
MACE_GEMM_PART_CAL_4(1);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
#else
GemmBlock(a_ptr, b_ptr, 2, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm344(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2;
float32x4_t b0, b1, b2, b3;
float32x4_t c0, c1, c2;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 1 * stride_a);
a2 = vld1q_f32(a_ptr + 2 * stride_a);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
MACE_GEMM_PART_CAL_4(0);
MACE_GEMM_PART_CAL_4(1);
MACE_GEMM_PART_CAL_4(2);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
#else
GemmBlock(a_ptr, b_ptr, 3, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm444(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3;
float32x4_t b0, b1, b2, b3;
float32x4_t c0, c1, c2, c3;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 1 * stride_a);
a2 = vld1q_f32(a_ptr + 2 * stride_a);
a3 = vld1q_f32(a_ptr + 3 * stride_a);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c);
MACE_GEMM_PART_CAL_4(0);
MACE_GEMM_PART_CAL_4(1);
MACE_GEMM_PART_CAL_4(2);
MACE_GEMM_PART_CAL_4(3);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
vst1q_f32(c_ptr + 3 * stride_c, c3);
#else
GemmBlock(a_ptr, b_ptr, 4, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm544(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4;
float32x4_t b0, b1, b2, b3;
float32x4_t c0, c1, c2, c3, c4;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 1 * stride_a);
a2 = vld1q_f32(a_ptr + 2 * stride_a);
a3 = vld1q_f32(a_ptr + 3 * stride_a);
a4 = vld1q_f32(a_ptr + 4 * stride_a);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c);
c4 = vld1q_f32(c_ptr + 4 * stride_c);
MACE_GEMM_PART_CAL_4(0);
MACE_GEMM_PART_CAL_4(1);
MACE_GEMM_PART_CAL_4(2);
MACE_GEMM_PART_CAL_4(3);
MACE_GEMM_PART_CAL_4(4);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
vst1q_f32(c_ptr + 3 * stride_c, c3);
vst1q_f32(c_ptr + 4 * stride_c, c4);
#else
GemmBlock(a_ptr, b_ptr, 5, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm644(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5;
float32x4_t b0, b1, b2, b3;
float32x4_t c0, c1, c2, c3, c4, c5;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 1 * stride_a);
a2 = vld1q_f32(a_ptr + 2 * stride_a);
a3 = vld1q_f32(a_ptr + 3 * stride_a);
a4 = vld1q_f32(a_ptr + 4 * stride_a);
a5 = vld1q_f32(a_ptr + 5 * stride_a);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c);
c4 = vld1q_f32(c_ptr + 4 * stride_c);
c5 = vld1q_f32(c_ptr + 5 * stride_c);
MACE_GEMM_PART_CAL_4(0);
MACE_GEMM_PART_CAL_4(1);
MACE_GEMM_PART_CAL_4(2);
MACE_GEMM_PART_CAL_4(3);
MACE_GEMM_PART_CAL_4(4);
MACE_GEMM_PART_CAL_4(5);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
vst1q_f32(c_ptr + 3 * stride_c, c3);
vst1q_f32(c_ptr + 4 * stride_c, c4);
vst1q_f32(c_ptr + 5 * stride_c, c5);
#else
GemmBlock(a_ptr, b_ptr, 6, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm884(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14,
a15;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
float32x4_t c0, c1, c2, c3, c4, c5, c6, c7;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_a);
a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_a);
a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_a);
a7 = vld1q_f32(a_ptr + 3 * stride_a + 4);
a8 = vld1q_f32(a_ptr + 4 * stride_a);
a9 = vld1q_f32(a_ptr + 4 * stride_a + 4);
a10 = vld1q_f32(a_ptr + 5 * stride_a);
a11 = vld1q_f32(a_ptr + 5 * stride_a + 4);
a12 = vld1q_f32(a_ptr + 6 * stride_a);
a13 = vld1q_f32(a_ptr + 6 * stride_a + 4);
a14 = vld1q_f32(a_ptr + 7 * stride_a);
a15 = vld1q_f32(a_ptr + 7 * stride_a + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
b4 = vld1q_f32(b_ptr + 4 * stride_b);
b5 = vld1q_f32(b_ptr + 5 * stride_b);
b6 = vld1q_f32(b_ptr + 6 * stride_b);
b7 = vld1q_f32(b_ptr + 7 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c);
c4 = vld1q_f32(c_ptr + 4 * stride_c);
c5 = vld1q_f32(c_ptr + 5 * stride_c);
c6 = vld1q_f32(c_ptr + 6 * stride_c);
c7 = vld1q_f32(c_ptr + 7 * stride_c);
MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL_8(1, 2, 3);
MACE_GEMM_PART_CAL_8(2, 4, 5);
MACE_GEMM_PART_CAL_8(3, 6, 7);
MACE_GEMM_PART_CAL_8(4, 8, 9);
MACE_GEMM_PART_CAL_8(5, 10, 11);
MACE_GEMM_PART_CAL_8(6, 12, 13);
MACE_GEMM_PART_CAL_8(7, 14, 15);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
vst1q_f32(c_ptr + 3 * stride_c, c3);
vst1q_f32(c_ptr + 4 * stride_c, c4);
vst1q_f32(c_ptr + 5 * stride_c, c5);
vst1q_f32(c_ptr + 6 * stride_c, c6);
vst1q_f32(c_ptr + 7 * stride_c, c7);
#else
GemmBlock(a_ptr, b_ptr, 8, 8, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm184(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
MACE_UNUSED(stride_a);
MACE_UNUSED(stride_c);
float32x4_t a0, a1;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
float32x4_t c0;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
b4 = vld1q_f32(b_ptr + 4 * stride_b);
b5 = vld1q_f32(b_ptr + 5 * stride_b);
b6 = vld1q_f32(b_ptr + 6 * stride_b);
b7 = vld1q_f32(b_ptr + 7 * stride_b);
c0 = vld1q_f32(c_ptr);
MACE_GEMM_PART_CAL_8(0, 0, 1);
vst1q_f32(c_ptr, c0);
#else
GemmBlock(a_ptr, b_ptr, 1, 8, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm284(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
float32x4_t c0, c1;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_a);
a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
b4 = vld1q_f32(b_ptr + 4 * stride_b);
b5 = vld1q_f32(b_ptr + 5 * stride_b);
b6 = vld1q_f32(b_ptr + 6 * stride_b);
b7 = vld1q_f32(b_ptr + 7 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL_8(1, 2, 3);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
#else
GemmBlock(a_ptr, b_ptr, 2, 8, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm384(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
float32x4_t c0, c1, c2;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_a);
a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_a);
a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
b4 = vld1q_f32(b_ptr + 4 * stride_b);
b5 = vld1q_f32(b_ptr + 5 * stride_b);
b6 = vld1q_f32(b_ptr + 6 * stride_b);
b7 = vld1q_f32(b_ptr + 7 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL_8(1, 2, 3);
MACE_GEMM_PART_CAL_8(2, 4, 5);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
#else
GemmBlock(a_ptr, b_ptr, 3, 8, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm484(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
float32x4_t c0, c1, c2, c3;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_a);
a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_a);
a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_a);
a7 = vld1q_f32(a_ptr + 3 * stride_a + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
b4 = vld1q_f32(b_ptr + 4 * stride_b);
b5 = vld1q_f32(b_ptr + 5 * stride_b);
b6 = vld1q_f32(b_ptr + 6 * stride_b);
b7 = vld1q_f32(b_ptr + 7 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c);
MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL_8(1, 2, 3);
MACE_GEMM_PART_CAL_8(2, 4, 5);
MACE_GEMM_PART_CAL_8(3, 6, 7);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
vst1q_f32(c_ptr + 3 * stride_c, c3);
#else
GemmBlock(a_ptr, b_ptr, 4, 8, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm584(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
float32x4_t c0, c1, c2, c3, c4;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_a);
a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_a);
a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_a);
a7 = vld1q_f32(a_ptr + 3 * stride_a + 4);
a8 = vld1q_f32(a_ptr + 4 * stride_a);
a9 = vld1q_f32(a_ptr + 4 * stride_a + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
b4 = vld1q_f32(b_ptr + 4 * stride_b);
b5 = vld1q_f32(b_ptr + 5 * stride_b);
b6 = vld1q_f32(b_ptr + 6 * stride_b);
b7 = vld1q_f32(b_ptr + 7 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c);
c4 = vld1q_f32(c_ptr + 4 * stride_c);
MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL_8(1, 2, 3);
MACE_GEMM_PART_CAL_8(2, 4, 5);
MACE_GEMM_PART_CAL_8(3, 6, 7);
MACE_GEMM_PART_CAL_8(4, 8, 9);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
vst1q_f32(c_ptr + 3 * stride_c, c3);
vst1q_f32(c_ptr + 4 * stride_c, c4);
#else
GemmBlock(a_ptr, b_ptr, 5, 8, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm684(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
float32x4_t c0, c1, c2, c3, c4, c5;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_a);
a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_a);
a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_a);
a7 = vld1q_f32(a_ptr + 3 * stride_a + 4);
a8 = vld1q_f32(a_ptr + 4 * stride_a);
a9 = vld1q_f32(a_ptr + 4 * stride_a + 4);
a10 = vld1q_f32(a_ptr + 5 * stride_a);
a11 = vld1q_f32(a_ptr + 5 * stride_a + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
b4 = vld1q_f32(b_ptr + 4 * stride_b);
b5 = vld1q_f32(b_ptr + 5 * stride_b);
b6 = vld1q_f32(b_ptr + 6 * stride_b);
b7 = vld1q_f32(b_ptr + 7 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c);
c4 = vld1q_f32(c_ptr + 4 * stride_c);
c5 = vld1q_f32(c_ptr + 5 * stride_c);
MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL_8(1, 2, 3);
MACE_GEMM_PART_CAL_8(2, 4, 5);
MACE_GEMM_PART_CAL_8(3, 6, 7);
MACE_GEMM_PART_CAL_8(4, 8, 9);
MACE_GEMM_PART_CAL_8(5, 10, 11);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
vst1q_f32(c_ptr + 3 * stride_c, c3);
vst1q_f32(c_ptr + 4 * stride_c, c4);
vst1q_f32(c_ptr + 5 * stride_c, c5);
#else
GemmBlock(a_ptr, b_ptr, 6, 8, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm784(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13;
float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
float32x4_t c0, c1, c2, c3, c4, c5, c6;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_a);
a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_a);
a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_a);
a7 = vld1q_f32(a_ptr + 3 * stride_a + 4);
a8 = vld1q_f32(a_ptr + 4 * stride_a);
a9 = vld1q_f32(a_ptr + 4 * stride_a + 4);
a10 = vld1q_f32(a_ptr + 5 * stride_a);
a11 = vld1q_f32(a_ptr + 5 * stride_a + 4);
a12 = vld1q_f32(a_ptr + 6 * stride_a);
a13 = vld1q_f32(a_ptr + 6 * stride_a + 4);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
b4 = vld1q_f32(b_ptr + 4 * stride_b);
b5 = vld1q_f32(b_ptr + 5 * stride_b);
b6 = vld1q_f32(b_ptr + 6 * stride_b);
b7 = vld1q_f32(b_ptr + 7 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c);
c4 = vld1q_f32(c_ptr + 4 * stride_c);
c5 = vld1q_f32(c_ptr + 5 * stride_c);
c6 = vld1q_f32(c_ptr + 6 * stride_c);
MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL_8(1, 2, 3);
MACE_GEMM_PART_CAL_8(2, 4, 5);
MACE_GEMM_PART_CAL_8(3, 6, 7);
MACE_GEMM_PART_CAL_8(4, 8, 9);
MACE_GEMM_PART_CAL_8(5, 10, 11);
MACE_GEMM_PART_CAL_8(6, 12, 13);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
vst1q_f32(c_ptr + 3 * stride_c, c3);
vst1q_f32(c_ptr + 4 * stride_c, c4);
vst1q_f32(c_ptr + 5 * stride_c, c5);
vst1q_f32(c_ptr + 6 * stride_c, c6);
#else
GemmBlock(a_ptr, b_ptr, 7, 8, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void GemmTile(const float *A,
const float *B,
const index_t height,
const index_t K,
const index_t width,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *C) {
#if defined(MACE_ENABLE_NEON)
index_t h = 0;
index_t w = 0;
index_t k = 0;
#if defined(__aarch64__)
int reg_height_tile = 8;
int reg_K_tile = 8;
#else
int reg_height_tile = 6;
int reg_K_tile = 4;
#endif
for (h = 0; h < height - reg_height_tile + 1; h += reg_height_tile) {
for (k = 0; k < K - reg_K_tile + 1; k += reg_K_tile) {
const float *a_ptr = A + (h * stride_a + k);
#if defined(__aarch64__) && defined(__clang__)
int nw = width >> 2;
if (nw > 0) {
// load A
float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13,
a14, a15;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 4);
a2 = vld1q_f32(a_ptr + 1 * stride_a);
a3 = vld1q_f32(a_ptr + 1 * stride_a + 4);
a4 = vld1q_f32(a_ptr + 2 * stride_a);
a5 = vld1q_f32(a_ptr + 2 * stride_a + 4);
a6 = vld1q_f32(a_ptr + 3 * stride_a);
a7 = vld1q_f32(a_ptr + 3 * stride_a + 4);
a8 = vld1q_f32(a_ptr + 4 * stride_a);
a9 = vld1q_f32(a_ptr + 4 * stride_a + 4);
a10 = vld1q_f32(a_ptr + 5 * stride_a);
a11 = vld1q_f32(a_ptr + 5 * stride_a + 4);
a12 = vld1q_f32(a_ptr + 6 * stride_a);
a13 = vld1q_f32(a_ptr + 6 * stride_a + 4);
a14 = vld1q_f32(a_ptr + 7 * stride_a);
a15 = vld1q_f32(a_ptr + 7 * stride_a + 4);
const float *b_ptr0 = B + k * stride_b;
const float *b_ptr1 = B + (k + 1) * stride_b;
const float *b_ptr2 = B + (k + 2) * stride_b;
const float *b_ptr3 = B + (k + 3) * stride_b;
const float *b_ptr4 = B + (k + 4) * stride_b;
const float *b_ptr5 = B + (k + 5) * stride_b;
const float *b_ptr6 = B + (k + 6) * stride_b;
const float *b_ptr7 = B + (k + 7) * stride_b;
float *c_ptr0 = C + h * stride_c;
float *c_ptr1 = C + (h + 1) * stride_c;
float *c_ptr2 = C + (h + 2) * stride_c;
float *c_ptr3 = C + (h + 3) * stride_c;
float *c_ptr4 = C + (h + 4) * stride_c;
float *c_ptr5 = C + (h + 5) * stride_c;
float *c_ptr6 = C + (h + 6) * stride_c;
float *c_ptr7 = C + (h + 7) * stride_c;
asm volatile(
"0: \n"
"prfm pldl1keep, [%9, #128] \n"
"ld1 {v16.4s}, [%9], #16 \n"
"prfm pldl1keep, [%1, #128] \n"
"ld1 {v18.4s}, [%1] \n"
"prfm pldl1keep, [%2, #128] \n"
"ld1 {v19.4s}, [%2] \n"
"prfm pldl1keep, [%3, #128] \n"
"ld1 {v20.4s}, [%3] \n"
"prfm pldl1keep, [%4, #128] \n"
"ld1 {v21.4s}, [%4] \n"
"prfm pldl1keep, [%5, #128] \n"
"ld1 {v22.4s}, [%5] \n"
"prfm pldl1keep, [%6, #128] \n"
"ld1 {v23.4s}, [%6] \n"
"prfm pldl1keep, [%7, #128] \n"
"ld1 {v24.4s}, [%7] \n"
"prfm pldl1keep, [%8, #128] \n"
"ld1 {v25.4s}, [%8] \n"
"prfm pldl1keep, [%10, #128] \n"
"ld1 {v17.4s}, [%10], #16 \n"
"fmla v18.4s, v16.4s, %34.s[0] \n"
"fmla v19.4s, v16.4s, %35.s[0] \n"
"fmla v20.4s, v16.4s, %36.s[0] \n"
"fmla v21.4s, v16.4s, %37.s[0] \n"
"fmla v22.4s, v16.4s, %38.s[0] \n"
"fmla v23.4s, v16.4s, %39.s[0] \n"
"fmla v24.4s, v16.4s, %40.s[0] \n"
"fmla v25.4s, v16.4s, %41.s[0] \n"
"fmla v18.4s, v17.4s, %34.s[1] \n"
"fmla v19.4s, v17.4s, %35.s[1] \n"
"fmla v20.4s, v17.4s, %36.s[1] \n"
"fmla v21.4s, v17.4s, %37.s[1] \n"
"prfm pldl1keep, [%11, #128] \n"
"ld1 {v16.4s}, [%11], #16 \n"
"fmla v22.4s, v17.4s, %38.s[1] \n"
"fmla v23.4s, v17.4s, %39.s[1] \n"
"fmla v24.4s, v17.4s, %40.s[1] \n"
"fmla v25.4s, v17.4s, %41.s[1] \n"
"fmla v18.4s, v16.4s, %34.s[2] \n"
"fmla v19.4s, v16.4s, %35.s[2] \n"
"fmla v20.4s, v16.4s, %36.s[2] \n"
"fmla v21.4s, v16.4s, %37.s[2] \n"
"prfm pldl1keep, [%12, #128] \n"
"ld1 {v17.4s}, [%12], #16 \n"
"fmla v22.4s, v16.4s, %38.s[2] \n"
"fmla v23.4s, v16.4s, %39.s[2] \n"
"fmla v24.4s, v16.4s, %40.s[2] \n"
"fmla v25.4s, v16.4s, %41.s[2] \n"
"fmla v18.4s, v17.4s, %34.s[3] \n"
"fmla v19.4s, v17.4s, %35.s[3] \n"
"fmla v20.4s, v17.4s, %36.s[3] \n"
"fmla v21.4s, v17.4s, %37.s[3] \n"
"prfm pldl1keep, [%13, #128] \n"
"ld1 {v16.4s}, [%13], #16 \n"
"fmla v22.4s, v17.4s, %38.s[3] \n"
"fmla v23.4s, v17.4s, %39.s[3] \n"
"fmla v24.4s, v17.4s, %40.s[3] \n"
"fmla v25.4s, v17.4s, %41.s[3] \n"
"fmla v18.4s, v16.4s, %42.s[0] \n"
"fmla v19.4s, v16.4s, %43.s[0] \n"
"fmla v20.4s, v16.4s, %44.s[0] \n"
"fmla v21.4s, v16.4s, %45.s[0] \n"
"prfm pldl1keep, [%14, #128] \n"
"ld1 {v17.4s}, [%14], #16 \n"
"fmla v22.4s, v16.4s, %46.s[0] \n"
"fmla v23.4s, v16.4s, %47.s[0] \n"
"fmla v24.4s, v16.4s, %48.s[0] \n"
"fmla v25.4s, v16.4s, %49.s[0] \n"
"fmla v18.4s, v17.4s, %42.s[1] \n"
"fmla v19.4s, v17.4s, %43.s[1] \n"
"fmla v20.4s, v17.4s, %44.s[1] \n"
"fmla v21.4s, v17.4s, %45.s[1] \n"
"prfm pldl1keep, [%15, #128] \n"
"ld1 {v16.4s}, [%15], #16 \n"
"fmla v22.4s, v17.4s, %46.s[1] \n"
"fmla v23.4s, v17.4s, %47.s[1] \n"
"fmla v24.4s, v17.4s, %48.s[1] \n"
"fmla v25.4s, v17.4s, %49.s[1] \n"
"fmla v18.4s, v16.4s, %42.s[2] \n"
"fmla v19.4s, v16.4s, %43.s[2] \n"
"fmla v20.4s, v16.4s, %44.s[2] \n"
"fmla v21.4s, v16.4s, %45.s[2] \n"
"prfm pldl1keep, [%16, #128] \n"
"ld1 {v17.4s}, [%16], #16 \n"
"fmla v22.4s, v16.4s, %46.s[2] \n"
"fmla v23.4s, v16.4s, %47.s[2] \n"
"fmla v24.4s, v16.4s, %48.s[2] \n"
"fmla v25.4s, v16.4s, %49.s[2] \n"
"fmla v18.4s, v17.4s, %42.s[3] \n"
"fmla v19.4s, v17.4s, %43.s[3] \n"
"fmla v20.4s, v17.4s, %44.s[3] \n"
"fmla v21.4s, v17.4s, %45.s[3] \n"
"st1 {v18.4s}, [%1], #16 \n"
"st1 {v19.4s}, [%2], #16 \n"
"st1 {v20.4s}, [%3], #16 \n"
"st1 {v21.4s}, [%4], #16 \n"
"fmla v22.4s, v17.4s, %46.s[3] \n"
"fmla v23.4s, v17.4s, %47.s[3] \n"
"fmla v24.4s, v17.4s, %48.s[3] \n"
"fmla v25.4s, v17.4s, %49.s[3] \n"
"subs %w0, %w0, #1 \n"
"st1 {v22.4s}, [%5], #16 \n"
"st1 {v23.4s}, [%6], #16 \n"
"st1 {v24.4s}, [%7], #16 \n"
"st1 {v25.4s}, [%8], #16 \n"
"bne 0b \n"
: "=r"(nw), // 0
"=r"(c_ptr0), // 1
"=r"(c_ptr1), // 2
"=r"(c_ptr2), // 3
"=r"(c_ptr3), // 4
"=r"(c_ptr4), // 5
"=r"(c_ptr5), // 6
"=r"(c_ptr6), // 7
"=r"(c_ptr7), // 8
"=r"(b_ptr0), // 9
"=r"(b_ptr1), // 10
"=r"(b_ptr2), // 11
"=r"(b_ptr3), // 12
"=r"(b_ptr4), // 13
"=r"(b_ptr5), // 14
"=r"(b_ptr6), // 15
"=r"(b_ptr7) // 16
: "0"(nw), // 17
"1"(c_ptr0), // 18
"2"(c_ptr1), // 19
"3"(c_ptr2), // 20
"4"(c_ptr3), // 21
"5"(c_ptr4), // 22
"6"(c_ptr5), // 23
"7"(c_ptr6), // 24
"8"(c_ptr7), // 25
"9"(b_ptr0), // 26
"10"(b_ptr1), // 27
"11"(b_ptr2), // 28
"12"(b_ptr3), // 29
"13"(b_ptr4), // 30
"14"(b_ptr5), // 31
"15"(b_ptr6), // 32
"16"(b_ptr7), // 33
"w"(a0), // 34
"w"(a2), // 35
"w"(a4), // 36
"w"(a6), // 37
"w"(a8), // 38
"w"(a10), // 39
"w"(a12), // 40
"w"(a14), // 41
"w"(a1), // 42
"w"(a3), // 43
"w"(a5), // 44
"w"(a7), // 45
"w"(a9), // 46
"w"(a11), // 47
"w"(a13), // 48
"w"(a15) // 49
: "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22",
"v23", "v24", "v25");
w = (width >> 2) << 2;
}
#elif defined(__aarch64__) // gcc
for (w = 0; w + 3 < width; w += 4) {
const float *b_ptr = B + (k * stride_b + w);
float *c_ptr = C + (h * stride_c + w);
Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
}
#else // armv7
int nw = width >> 2;
if (nw > 0) {
float32x4_t a0, a1, a2, a3, a4, a5;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 1 * stride_a);
a2 = vld1q_f32(a_ptr + 2 * stride_a);
a3 = vld1q_f32(a_ptr + 3 * stride_a);
a4 = vld1q_f32(a_ptr + 4 * stride_a);
a5 = vld1q_f32(a_ptr + 5 * stride_a);
const float *b_ptr0 = B + k * stride_b;
const float *b_ptr1 = B + (k + 1) * stride_b;
const float *b_ptr2 = B + (k + 2) * stride_b;
const float *b_ptr3 = B + (k + 3) * stride_b;
float *c_ptr0 = C + h * stride_c;
float *c_ptr1 = C + (h + 1) * stride_c;
float *c_ptr2 = C + (h + 2) * stride_c;
float *c_ptr3 = C + (h + 3) * stride_c;
float *c_ptr4 = C + (h + 4) * stride_c;
float *c_ptr5 = C + (h + 5) * stride_c;
asm volatile(
"0: \n"
"pld [%7, #128] \n"
"vld1.f32 {d12-d13}, [%7]! \n"
"pld [%1, #128] \n"
"vld1.f32 {d16-d17}, [%1] \n"
"pld [%2, #128] \n"
"vld1.f32 {d18-d19}, [%2] \n"
"pld [%3, #128] \n"
"vld1.f32 {d20-d21}, [%3] \n"
"pld [%4, #128] \n"
"vld1.f32 {d22-d23}, [%4] \n"
"pld [%5, #128] \n"
"vld1.f32 {d24-d25}, [%5] \n"
"pld [%6, #128] \n"
"vld1.f32 {d26-d27}, [%6] \n"
"pld [%8, #128] \n"
"vld1.f32 {d14-d15}, [%8]! \n"
"vmla.f32 q8, q6, %e22[0] \n"
"vmla.f32 q9, q6, %e23[0] \n"
"vmla.f32 q10, q6, %e24[0] \n"
"vmla.f32 q11, q6, %e25[0] \n"
"vmla.f32 q12, q6, %e26[0] \n"
"vmla.f32 q13, q6, %e27[0] \n"
"pld [%9, #128] \n"
"vld1.f32 {d12-d13}, [%9]! \n"
"vmla.f32 q8, q7, %e22[1] \n"
"vmla.f32 q9, q7, %e23[1] \n"
"vmla.f32 q10, q7, %e24[1] \n"
"vmla.f32 q11, q7, %e25[1] \n"
"vmla.f32 q12, q7, %e26[1] \n"
"vmla.f32 q13, q7, %e27[1] \n"
"pld [%10, #128] \n"
"vld1.f32 {d14-d15}, [%10]! \n"
"vmla.f32 q8, q6, %f22[0] \n"
"vmla.f32 q9, q6, %f23[0] \n"
"vmla.f32 q10, q6, %f24[0] \n"
"vmla.f32 q11, q6, %f25[0] \n"
"vmla.f32 q12, q6, %f26[0] \n"
"vmla.f32 q13, q6, %f27[0] \n"
"vmla.f32 q8, q7, %f22[1] \n"
"vmla.f32 q9, q7, %f23[1] \n"
"vmla.f32 q10, q7, %f24[1] \n"
"vmla.f32 q11, q7, %f25[1] \n"
"vmla.f32 q12, q7, %f26[1] \n"
"vmla.f32 q13, q7, %f27[1] \n"
"vst1.f32 {d16-d17}, [%1]! \n"
"vst1.f32 {d18-d19}, [%2]! \n"
"vst1.f32 {d20-d21}, [%3]! \n"
"vst1.f32 {d22-d23}, [%4]! \n"
"vst1.f32 {d24-d25}, [%5]! \n"
"vst1.f32 {d26-d27}, [%6]! \n"
"subs %0, #1 \n"
"bne 0b \n"
: "=r"(nw), // 0
"=r"(c_ptr0), // 1
"=r"(c_ptr1), // 2
"=r"(c_ptr2), // 3
"=r"(c_ptr3), // 4
"=r"(c_ptr4), // 5
"=r"(c_ptr5), // 6
"=r"(b_ptr0), // 7
"=r"(b_ptr1), // 8
"=r"(b_ptr2), // 9
"=r"(b_ptr3) // 10
: "0"(nw), // 11
"1"(c_ptr0), // 12
"2"(c_ptr1), // 13
"3"(c_ptr2), // 14
"4"(c_ptr3), // 15
"5"(c_ptr4), // 16
"6"(c_ptr5), // 17
"7"(b_ptr0), // 18
"8"(b_ptr1), // 19
"9"(b_ptr2), // 20
"10"(b_ptr3), // 21
"w"(a0), // 22
"w"(a1), // 23
"w"(a2), // 24
"w"(a3), // 25
"w"(a4), // 26
"w"(a5) // 27
: "cc", "memory", "q6", "q7", "q8", "q9", "q10", "q11", "q12",
"q13", "q14", "q15");
w = (width >> 2) << 2;
}
#endif
if (w < width) {
const float *b_ptr = B + (k * stride_b + w);
float *c_ptr = C + (h * stride_c + w);
GemmBlock(a_ptr, b_ptr, reg_height_tile, reg_K_tile, width - w,
stride_a, stride_b, stride_c, c_ptr);
}
}
if (k < K) {
const float *a_ptr = A + (h * stride_a + k);
const float *b_ptr = B + k * stride_b;
float *c_ptr = C + h * stride_c;
GemmBlock(a_ptr, b_ptr, reg_height_tile, K - k, width, stride_a, stride_b,
stride_c, c_ptr);
}
}
if (h < height) {
index_t remain_h = height - h;
auto gemm_fn = Gemm184;
switch (remain_h) {
case 1:
#if defined(__aarch64__)
gemm_fn = Gemm184;
#else
gemm_fn = Gemm144;
#endif
break;
case 2:
#if defined(__aarch64__)
gemm_fn = Gemm284;
#else
gemm_fn = Gemm244;
#endif
break;
case 3:
#if defined(__aarch64__)
gemm_fn = Gemm384;
#else
gemm_fn = Gemm344;
#endif
break;
case 4:
#if defined(__aarch64__)
gemm_fn = Gemm484;
#else
gemm_fn = Gemm444;
#endif
break;
case 5:
#if defined(__aarch64__)
gemm_fn = Gemm584;
#else
gemm_fn = Gemm544;
#endif
break;
case 6:
#if defined(__aarch64__)
gemm_fn = Gemm684;
#else
LOG(FATAL) << "remain_h should < 6";
#endif
break;
case 7:
#if defined(__aarch64__)
gemm_fn = Gemm784;
#else
LOG(FATAL) << "remain_h should < 6";
#endif
break;
default:
LOG(FATAL) << "remain_h should < 8";
}
for (k = 0; k < K - reg_K_tile; k += reg_K_tile) {
const float *a_ptr = A + (h * stride_a + k);
index_t w;
for (w = 0; w + 3 < width; w += 4) {
const float *b_ptr = B + (k * stride_b + w);
float *c_ptr = C + (h * stride_c + w);
gemm_fn(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
}
if (w < width) {
const float *b_ptr = B + (k * stride_b + w);
float *c_ptr = C + (h * stride_c + w);
GemmBlock(a_ptr, b_ptr, remain_h, reg_K_tile, width - w, stride_a,
stride_b, stride_c, c_ptr);
}
}
if (k < K) {
const float *a_ptr = A + (h * stride_a + k);
const float *b_ptr = B + k * stride_b;
float *c_ptr = C + h * stride_c;
GemmBlock(a_ptr, b_ptr, remain_h, K - k, width, stride_a, stride_b,
stride_c, c_ptr);
}
}
#else
GemmBlock(A, B, height, K, width, stride_a, stride_b, stride_c, C);
#endif // MACE_ENABLE_NEON
}
} // namespace
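// Descriptive note: Transpose copies src[height, width] (row stride stride_w)
// into a dense dst[width, height], working on square tiles so that both the
// source rows and the destination columns stay cache friendly.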
void Transpose(const float *src,
index_t height,
index_t width,
index_t stride_w,
float *dst) {
index_t tile_size = height > 512 || width > 512 ? 64 : 32;
for (index_t i = 0; i < height; i += tile_size) {
for (index_t j = 0; j < width; j += tile_size) {
index_t end_i = std::min(i + tile_size, height);
index_t end_j = std::min(j + tile_size, width);
for (index_t tile_i = i; tile_i < end_i; ++tile_i) {
for (index_t tile_j = j; tile_j < end_j; ++tile_j) {
dst[tile_j * height + tile_i] = src[tile_i * stride_w + tile_j];
}
}
}
}
}
// A: height x K, B: K x width, C: height x width
void Gemm(const float *A,
const float *B,
const index_t batch,
const index_t height,
const index_t K,
const index_t width,
float *C,
const bool transpose_a,
const bool transpose_b) {
if (width == 1 && !transpose_a) {
for (index_t b = 0; b < batch; ++b) {
Gemv(A + b * height * K, B + b * K, 1, K, height, C + b * height);
}
return;
}
memset(C, 0, sizeof(float) * batch * height * width);
std::vector<index_t> block_size_dims {height, width, K};
index_t thread_count = MaceOpenMPThreadCount;
MACE_CHECK(thread_count >= 1, "thread should be ge 1");
// TODO(liyin): apply gcd ?
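// Block-partition heuristic: split the height evenly across threads when
// possible; with 4 threads and even height/width, use a 2x2 height/width
// split; otherwise try an even width split, then an uneven height split,
// and finally fall back to single-row blocks spread over the width.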
if (height % thread_count == 0) {
block_size_dims[0] = height / thread_count;
} else if (thread_count == 4 && (height & 1) == 0 && (width & 1) == 0) {
block_size_dims[0] = height >> 1;
block_size_dims[1] = width >> 1;
} else if (width % thread_count == 0) {
block_size_dims[1] = width / thread_count;
} else {
if (height >= thread_count) {
block_size_dims[0] = height / thread_count;
} else {
thread_count = std::min(thread_count, height * width);
index_t thread_h = height;
index_t thread_w = RoundUpDiv(thread_count, thread_h);
block_size_dims[0] = 1;
block_size_dims[1] = std::max(static_cast<index_t>(1), width / thread_w);
}
}
const index_t block_tile[3] = {height / block_size_dims[0],
width / block_size_dims[1],
K / block_size_dims[2]};
block_size_dims[0] = height / block_tile[0];
block_size_dims[1] = width / block_tile[1];
block_size_dims[2] = K / block_tile[2];
const index_t remain[3] = {height % block_tile[0],
width % block_tile[1],
K % block_tile[2]};
#pragma omp parallel for collapse(3)
for (index_t n = 0; n < batch; ++n) {
for (index_t bh = 0; bh < block_tile[0]; ++bh) {
for (index_t bw = 0; bw < block_tile[1]; ++bw) {
const index_t remain_height = remain[0];
const index_t remain_width = remain[1];
const index_t remain_k = remain[2];
const index_t block_size_height = block_size_dims[0];
const index_t block_size_width = block_size_dims[1];
const index_t block_size_k = block_size_dims[2];
const index_t this_block_size_height =
block_size_height + (bh < remain_height ? 1 : 0);
const index_t this_block_size_width =
block_size_width + (bw < remain_width ? 1 : 0);
const float *a_base = A + n * height * K;
const float *b_base = B + n * K * width;
float *c_base = C + n * height * width;
const index_t ih_begin =
bh * block_size_height + (bh < remain_height ? bh : remain_height);
const index_t
ih_end = std::min(height, ih_begin + this_block_size_height);
const index_t iw_begin =
bw * block_size_width + (bw < remain_width ? bw : remain_width);
const index_t
iw_end = std::min(width, iw_begin + this_block_size_width);
for (index_t bk = 0; bk < block_tile[2]; ++bk) {
const index_t
this_block_size_k = block_size_k + (bk < remain_k ? 1 : 0);
const index_t
ik_begin = bk * block_size_k + (bk < remain_k ? bk : remain_k);
const index_t ik_end = std::min(K, ik_begin + this_block_size_k);
Tensor trans_a(GetCPUAllocator(), DataType::DT_FLOAT);
Tensor trans_b(GetCPUAllocator(), DataType::DT_FLOAT);
const float *real_a = nullptr;
const float *real_b = nullptr;
float *real_c = c_base + (ih_begin * width + iw_begin);
index_t stride_a;
index_t stride_b;
index_t stride_c = width;
if (transpose_a) {
trans_a.Resize({this_block_size_height, this_block_size_k});
float *trans_a_data = trans_a.mutable_data<float>();
// A[K, H] -> A[H, K]
Transpose(a_base + (ik_begin * height + ih_begin),
ik_end - ik_begin, ih_end - ih_begin, height,
trans_a_data);
real_a = trans_a_data;
stride_a = ik_end - ik_begin;
} else {
real_a = a_base + (ih_begin * K + ik_begin);
stride_a = K;
}
if (transpose_b) {
trans_b.Resize({this_block_size_k, this_block_size_width});
float *trans_b_data = trans_b.mutable_data<float>();
// B[W, K] -> B[K, W]
Transpose(b_base + (iw_begin * K + ik_begin), iw_end - iw_begin,
ik_end - ik_begin, K, trans_b_data);
real_b = trans_b_data;
stride_b = iw_end - iw_begin;
} else {
real_b = b_base + (ik_begin * width + iw_begin);
stride_b = width;
}
// inside block:
// calculate C[bh, bw] += A[bh, bk] * B[bk, bw] for one k
GemmTile(real_a, real_b, ih_end - ih_begin, ik_end - ik_begin,
iw_end - iw_begin, stride_a, stride_b, stride_c, real_c);
} // bk
} // bw
} // bh
} // n
}
// A: height x K, B: K x width, C: height x width
void GemmRef(const float *A,
const float *B,
const index_t batch,
const index_t height,
const index_t K,
const index_t width,
float *C,
const bool transpose_a,
const bool transpose_b) {
memset(C, 0, sizeof(float) * batch * height * width);
Tensor trans_a(GetCPUAllocator(), DataType::DT_FLOAT);
Tensor trans_b(GetCPUAllocator(), DataType::DT_FLOAT);
float *trans_a_data = nullptr;
float *trans_b_data = nullptr;
if (transpose_a) {
trans_a.Resize({height, K});
trans_a_data = trans_a.mutable_data<float>();
}
if (transpose_b) {
trans_b.Resize({K, width});
trans_b_data = trans_b.mutable_data<float>();
}
for (index_t b = 0; b < batch; ++b) {
const float *real_a = nullptr;
const float *real_b = nullptr;
float *real_c = C + b * height * width;
if (transpose_a) {
// A[K, H] -> A[H, K]
Transpose(A + b * height * K, K, height, height, trans_a_data);
real_a = trans_a_data;
} else {
real_a = A + b * height * K;
}
if (transpose_b) {
// B[W, K] -> B[K, W]
Transpose(B + b * width * K, width, K, K, trans_b_data);
real_b = trans_b_data;
} else {
real_b = B + b * width * K;
}
for (index_t i = 0; i < height; ++i) {
for (index_t j = 0; j < width; ++j) {
for (index_t k = 0; k < K; ++k) {
real_c[i * width + j] += real_a[i * K + k] * real_b[k * width + j];
}
}
}
}
}
void GemvRef(const float *m_ptr,
const float *v_ptr,
const index_t batch,
const index_t width,
const index_t height,
float *out_ptr) {
memset(out_ptr, 0, batch * height * sizeof(float));
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
out_ptr[b * height + h] += v_ptr[b * width + w] * m_ptr[h * width + w];
}
}
}
}
void Gemv(const float *m_ptr,
const float *v_ptr,
const index_t batch,
const index_t width,
const index_t height,
float *out_ptr) {
#if defined(MACE_ENABLE_NEON)
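// Each output element is the dot product of one matrix row with the batch's
// input vector; the inner loop is unrolled over 16, 8 and 4 floats with
// four NEON accumulators, followed by a horizontal add and a scalar tail.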
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch; ++b) {
for (index_t h = 0; h < height; ++h) {
const float *m_ptr0 = m_ptr + h * width;
const float *v_ptr0 = v_ptr + b * width;
float *out_ptr0 = out_ptr + b * height + h;
float32x4_t vm0, vm1, vm2, vm3;
float32x4_t vv0, vv1, vv2, vv3;
float32x4_t vsum0 = vdupq_n_f32(0.f);
float32x4_t vsum1 = vdupq_n_f32(0.f);
float32x4_t vsum2 = vdupq_n_f32(0.f);
float32x4_t vsum3 = vdupq_n_f32(0.f);
index_t w;
for (w = 0; w + 15 < width; w += 16) {
vm0 = vld1q_f32(m_ptr0);
vv0 = vld1q_f32(v_ptr0);
vm1 = vld1q_f32(m_ptr0 + 4);
vv1 = vld1q_f32(v_ptr0 + 4);
vm2 = vld1q_f32(m_ptr0 + 8);
vv2 = vld1q_f32(v_ptr0 + 8);
vm3 = vld1q_f32(m_ptr0 + 12);
vv3 = vld1q_f32(v_ptr0 + 12);
vsum0 = vmlaq_f32(vsum0, vm0, vv0);
vsum1 = vmlaq_f32(vsum1, vm1, vv1);
vsum2 = vmlaq_f32(vsum2, vm2, vv2);
vsum3 = vmlaq_f32(vsum3, vm3, vv3);
m_ptr0 += 16;
v_ptr0 += 16;
}
for (; w + 7 < width; w += 8) {
vm0 = vld1q_f32(m_ptr0);
vv0 = vld1q_f32(v_ptr0);
vm1 = vld1q_f32(m_ptr0 + 4);
vv1 = vld1q_f32(v_ptr0 + 4);
vsum0 = vmlaq_f32(vsum0, vm0, vv0);
vsum1 = vmlaq_f32(vsum1, vm1, vv1);
m_ptr0 += 8;
v_ptr0 += 8;
}
for (; w + 3 < width; w += 4) {
vm0 = vld1q_f32(m_ptr0);
vv0 = vld1q_f32(v_ptr0);
vsum0 = vmlaq_f32(vsum0, vm0, vv0);
m_ptr0 += 4;
v_ptr0 += 4;
}
vsum0 += vsum1;
vsum2 += vsum3;
vsum0 += vsum2;
float sum0 = vaddvq_f32(vsum0);
// handle remaining w
for (; w < width; ++w) {
sum0 += m_ptr0[0] * v_ptr0[0];
m_ptr0++;
v_ptr0++;
}
*out_ptr0++ = sum0;
} // h
} // b
#else
GemvRef(m_ptr, v_ptr, batch, width, height, out_ptr);
#endif
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_GEMM_H_
#define MACE_OPS_GEMM_H_
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#include <arm_neon.h>
#endif
#include "mace/core/types.h"
// Gemm does fast batched matrix-matrix multiplication.
// Gemv does fast batched matrix-vector multiplication.
namespace mace {
namespace ops {
// Gemm computes A[batch, height, K] dot B[batch, K, width] within each batch
// and writes the result to C[batch, height, width].
// height, K and width are the matrix dimensions after any transpose.
void Gemm(const float *A,
const float *B,
const index_t batch,
const index_t height,
const index_t K,
const index_t width,
float *C,
const bool transpose_a = false,
const bool transpose_b = false);
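// Usage sketch (illustrative only, assuming dense row-major buffers):
//   std::vector<float> A(batch * height * K);
//   std::vector<float> B(batch * K * width);
//   std::vector<float> C(batch * height * width);
//   Gemm(A.data(), B.data(), batch, height, K, width, C.data());
//   // C[b] = A[b] * B[b] for every b in [0, batch).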
void GemmRef(const float *A,
const float *B,
const index_t batch,
const index_t height,
const index_t K,
const index_t width,
float *C,
const bool transpose_a = false,
const bool transpose_b = false);
// Gemv computes M[height, width] dot V[batch, width] for each batch of V,
// and writes the result to out[batch, height].
void Gemv(const float *m_ptr,
const float *v_ptr,
const index_t batch,
const index_t width,
const index_t height,
float *out_ptr);
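// Usage sketch (illustrative only): out[b][h] = sum_w M[h][w] * V[b][w].
//   std::vector<float> M(height * width);
//   std::vector<float> V(batch * width);
//   std::vector<float> out(batch * height);
//   Gemv(M.data(), V.data(), batch, width, height, out.data());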
void GemvRef(const float *m_ptr,
const float *v_ptr,
const index_t batch,
const index_t width,
const index_t height,
float *out_ptr);
void Transpose(const float *src,
index_t height,
index_t width,
index_t stride_w,
float *dst);
} // namespace ops
} // namespace mace
#endif // MACE_OPS_GEMM_H_
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <vector>
#include <memory>
#include <random>
#include "mace/core/types.h"
#include "mace/ops/gemm.h"
#include "mace/ops/sgemm.h"
namespace mace {
namespace {
void GemmTest(index_t batch,
index_t N,
index_t K,
index_t M,
bool transpose_a,
bool transpose_b) {
std::unique_ptr<float[]> A(new float[batch * N * K]);
std::unique_ptr<float[]> B(new float[batch * K * M]);
std::unique_ptr<float[]> C(new float[batch * N * M]);
std::unique_ptr<float[]> C_ref(new float[batch * N * M]);
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1);
std::generate(A.get(), A.get() + batch * N * K,
[&gen, &nd] { return nd(gen); });
std::generate(B.get(), B.get() + batch * K * M,
[&gen, &nd] { return nd(gen); });
ops::Gemm(A.get(), B.get(), batch, N, K, M, C.get(), transpose_a,
transpose_b);
ops::GemmRef(A.get(), B.get(), batch, N, K, M, C_ref.get(), transpose_a,
transpose_b);
for (int i = 0; i < batch * N * M; ++i) {
EXPECT_NEAR(C_ref[i], C[i], 0.1);
}
}
void GemvTest(index_t batch, index_t N, index_t M) {
std::unique_ptr<float[]> A(new float[N * M]);
std::unique_ptr<float[]> B(new float[batch * M]);
std::unique_ptr<float[]> C(new float[batch * N]);
std::unique_ptr<float[]> C_ref(new float[batch * N]);
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1);
std::generate(A.get(), A.get() + N * M, [&gen, &nd] { return nd(gen); });
std::generate(B.get(), B.get() + batch * M, [&gen, &nd] { return nd(gen); });
ops::Gemv(A.get(), B.get(), batch, M, N, C.get());
ops::GemvRef(A.get(), B.get(), batch, M, N, C_ref.get());
for (int i = 0; i < batch * N; ++i) {
EXPECT_NEAR(C_ref[i], C[i], 0.1);
}
}
void SGemmTest(index_t batch,
index_t N,
index_t K,
index_t M,
bool transpose_a,
bool transpose_b) {
std::unique_ptr<float[]> A(new float[batch * N * K]);
std::unique_ptr<float[]> B(new float[batch * K * M]);
std::unique_ptr<float[]> C(new float[batch * N * M]);
std::unique_ptr<float[]> C_ref(new float[batch * N * M]);
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1);
std::generate(A.get(), A.get() + batch * N * K,
[&gen, &nd] { return nd(gen); });
std::generate(B.get(), B.get() + batch * K * M,
[&gen, &nd] { return nd(gen); });
ops::GemmRef(A.get(), B.get(), batch, N, K, M, C_ref.get(), transpose_a,
transpose_b);
ops::MatrixMap<const float> matrix_a;
ops::MatrixMap<const float> matrix_b;
if (!transpose_a) {
matrix_a =
ops::MatrixMap<const float>(batch,
N,
K,
ops::RowMajor,
A.get());
} else {
matrix_a =
ops::MatrixMap<const float>(batch,
K,
N,
ops::RowMajor,
A.get());
matrix_a = matrix_a.transpose();
}
if (!transpose_b) {
matrix_b =
ops::MatrixMap<const float>(batch,
K,
M,
ops::RowMajor,
B.get());
} else {
matrix_b =
ops::MatrixMap<const float>(batch,
M,
K,
ops::RowMajor,
B.get());
matrix_b = matrix_b.transpose();
}
ops::MatrixMap<float> matrix_c(batch, N, M, ops::RowMajor, C.get());
ops::SGemm sgemm;
sgemm(matrix_a, matrix_b, &matrix_c);
for (int i = 0; i < batch * N * M; ++i) {
EXPECT_NEAR(C_ref[i], C[i], 0.1);
}
}
} // namespace
TEST(GEMMTest, AlignedWithoutBatch) {
GemmTest(1, 1, 64, 128, false, false);
GemmTest(1, 2, 64, 128, false, true);
GemmTest(1, 3, 64, 128, true, false);
GemmTest(1, 4, 64, 128, true, true);
GemmTest(1, 5, 64, 128, false, false);
GemmTest(1, 6, 64, 128, false, true);
GemmTest(1, 7, 64, 128, true, false);
GemmTest(1, 17, 64, 128, true, true);
GemmTest(1, 256, 128, 4096, false, false);
GemmTest(1, 256, 128, 4104, false, false);
}
TEST(GEMMTest, UnalignedWithoutBatch) {
GemmTest(1, 1, 63, 127, false, false);
GemmTest(1, 2, 63, 127, false, true);
GemmTest(1, 3, 63, 127, true, false);
GemmTest(1, 4, 63, 127, true, true);
GemmTest(1, 5, 63, 127, false, false);
GemmTest(1, 6, 63, 127, false, true);
GemmTest(1, 7, 63, 127, true, false);
GemmTest(1, 17, 63, 127, true, true);
}
TEST(GEMMTest, UnalignedWithBatch) {
GemmTest(3, 1, 63, 127, false, false);
GemmTest(3, 2, 63, 127, false, true);
GemmTest(3, 3, 63, 127, true, false);
GemmTest(3, 4, 63, 127, true, true);
GemmTest(3, 5, 63, 127, false, false);
GemmTest(3, 6, 63, 127, false, true);
GemmTest(3, 7, 63, 127, true, false);
GemmTest(3, 17, 63, 127, true, true);
}
TEST(GEMMTest, gemv) {
GemvTest(1, 17, 63);
GemvTest(3, 17, 63);
}
namespace {
void TestSGemmTranspose(index_t batch, index_t N, index_t K, index_t M) {
SGemmTest(batch, N, K, M, false, false);
SGemmTest(batch, N, K, M, true, false);
SGemmTest(batch, N, K, M, false, true);
SGemmTest(batch, N, K, M, true, true);
}
}  // namespace
TEST(SGEMMTest, UnalignedWithoutBatch) {
std::vector<index_t> tests{1, 5, 14, 31, 47};
for (index_t N : tests) {
for (index_t K : tests) {
for (index_t M : tests) {
TestSGemmTranspose(1, N, K, M);
TestSGemmTranspose(16, N, K, M);
}
}
}
}
} // namespace mace
......@@ -14,7 +14,7 @@
#include "mace/core/operator.h"
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/conv_pool_2d_util.h"
namespace mace {
namespace ops {
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#include "mace/ops/ops_test_util.h"
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/conv_pool_2d_util.h"
namespace mace {
namespace ops {
......
......@@ -21,13 +21,23 @@
#include "mace/core/operator.h"
#include "mace/core/tensor.h"
#include "mace/ops/gemm.h"
#include "mace/ops/sgemm.h"
#include "mace/utils/utils.h"
#ifdef MACE_ENABLE_NEON
#include "mace/ops/arm/fp32/gemv.h"
#ifdef MACE_ENABLE_QUANTIZE
#include "mace/ops/arm/q8/gemv.h"
#endif // MACE_ENABLE_QUANTIZE
#else
#include "mace/ops/ref/gemv.h"
#endif // MACE_ENABLE_NEON
#ifdef MACE_ENABLE_QUANTIZE
#include "mace/ops/gemmlowp_util.h"
#include "mace/ops/arm/fixpoint_gemm.h"
#endif // MACE_ENABLE_QUANTIZE
#ifdef MACE_ENABLE_OPENCL
......@@ -106,7 +116,6 @@ class MatMulOp<CPU, float> : public MatMulOpBase {
}
batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1,
std::multiplies<index_t>());
std::vector<index_t> c_shape = A->shape();
c_shape[rank - 2] = height;
c_shape[rank - 1] = width;
......@@ -141,11 +150,17 @@ class MatMulOp<CPU, float> : public MatMulOpBase {
B->is_weight(),
c_ptr_base,
context->device()->scratch_buffer());
return MaceStatus::MACE_SUCCESS;
}
private:
SGemm sgemm_;
#ifdef MACE_ENABLE_NEON
arm::fp32::Gemv gemv_;
#else
ref::Gemv<float> gemv_;
#endif // MACE_ENABLE_NEON
};
#ifdef MACE_ENABLE_QUANTIZE
......@@ -163,38 +178,54 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
const index_t K,
const index_t width,
Tensor *C) {
Tensor::MappingGuard guarda(A);
Tensor::MappingGuard guardb(B);
Tensor::MappingGuard guardc(C);
auto a_ptr_base = A->data<uint8_t>();
auto b_ptr_base = B->data<uint8_t>();
auto c_ptr_base = C->mutable_data<uint8_t>();
index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1,
std::multiplies<index_t>());
auto gemm_context = context->device()->cpu_runtime()->GetGemmlowpContext();
MACE_CHECK_NOTNULL(gemm_context);
index_t a_size = height * K;
index_t b_size = K * width;
index_t c_size = height * width;
const auto &output_pipeline = GemmlowpOutputPipeline::MakeNoBias(
A->scale(), B->scale(), C->scale(), C->zero_point());
for (index_t i = 0; i < batch; ++i) {
gemmlowp::MatrixMap<const uint8_t, AOrder>
a_matrix(a_ptr_base + i * a_size, height, K);
gemmlowp::MatrixMap<const uint8_t, BOrder>
b_matrix(b_ptr_base + i * b_size, K, width);
gemmlowp::MatrixMap <uint8_t, gemmlowp::MapOrder::RowMajor>
c_matrix(c_ptr_base + i * c_size, height, width);
using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams;
gemmlowp::GemmWithOutputPipeline<uint8_t, uint8_t, BitDepthParams>(
gemm_context, a_matrix, b_matrix, &c_matrix, -A->zero_point(),
-B->zero_point(), output_pipeline);
#if defined(MACE_ENABLE_NEON)
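// Fast paths: a matrix times a single column is a GEMV on A; a single row
// times a matrix is a GEVM, computed as a GEMV on B with the operands
// swapped.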
if (width == 1 && AOrder == gemmlowp::MapOrder::RowMajor) {
gemv_kernel_.Compute(context, A, B, nullptr, batch, height, K, true, C);
} else if (height == 1 && BOrder == gemmlowp::MapOrder::ColMajor) {
gemv_kernel_.Compute(context, B, A, nullptr, batch, width, K, true, C);
} else {
#endif // MACE_ENABLE_NEON
Tensor::MappingGuard guarda(A);
Tensor::MappingGuard guardb(B);
Tensor::MappingGuard guardc(C);
auto a_ptr_base = A->data<uint8_t>();
auto b_ptr_base = B->data<uint8_t>();
auto c_ptr_base = C->mutable_data<uint8_t>();
auto gemm_context =
context->device()->cpu_runtime()->GetGemmlowpContext();
MACE_CHECK_NOTNULL(gemm_context);
index_t a_size = height * K;
index_t b_size = K * width;
index_t c_size = height * width;
const auto &output_pipeline = GemmlowpOutputPipeline::MakeNoBias(
A->scale(), B->scale(), C->scale(), C->zero_point());
for (index_t i = 0; i < batch; ++i) {
gemmlowp::MatrixMap<const uint8_t, AOrder>
a_matrix(a_ptr_base + i * a_size, height, K);
gemmlowp::MatrixMap<const uint8_t, BOrder>
b_matrix(b_ptr_base + i * b_size, K, width);
gemmlowp::MatrixMap <uint8_t, gemmlowp::MapOrder::RowMajor>
c_matrix(c_ptr_base + i * c_size, height, width);
using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams;
gemmlowp::GemmWithOutputPipeline<uint8_t, uint8_t, BitDepthParams>(
gemm_context, a_matrix, b_matrix, &c_matrix, -A->zero_point(),
-B->zero_point(), output_pipeline);
}
}
#if defined(MACE_ENABLE_NEON)
}
private:
arm::q8::Gemv<uint8_t> gemv_kernel_;
#endif // MACE_ENABLE_NEON
};
template<gemmlowp::MapOrder AOrder, gemmlowp::MapOrder BOrder>
......@@ -207,38 +238,24 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
const index_t K,
const index_t width,
Tensor *C) {
Tensor::MappingGuard guarda(A);
Tensor::MappingGuard guardb(B);
Tensor::MappingGuard guardc(C);
auto a_ptr_base = A->data<uint8_t>();
auto b_ptr_base = B->data<uint8_t>();
auto c_ptr_base = C->mutable_data<int32_t>();
C->SetScale(A->scale() * B->scale());
C->SetZeroPoint(0);
index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1,
std::multiplies<index_t>());
#if defined(MACE_ENABLE_NEON)
if (width == 1 && AOrder == gemmlowp::MapOrder::RowMajor) {
// gemv
for (index_t i = 0; i < batch; ++i) {
FixPointGemv(a_ptr_base + i * height * K,
b_ptr_base + i * K,
A->zero_point(),
B->zero_point(),
height,
K,
c_ptr_base + i * height);
}
gemv_kernel_.Compute(context, A, B, nullptr, batch, height, K, true, C);
} else if (height == 1 && BOrder == gemmlowp::MapOrder::ColMajor) {
// gevm
for (index_t i = 0; i < batch; ++i) {
FixPointGemv(b_ptr_base + i * K * width,
a_ptr_base + i * K,
B->zero_point(),
A->zero_point(),
width,
K,
c_ptr_base + i * width);
}
gemv_kernel_.Compute(context, B, A, nullptr, batch, width, K, true, C);
} else {
#endif // MACE_ENABLE_NEON
Tensor::MappingGuard guarda(A);
Tensor::MappingGuard guardb(B);
Tensor::MappingGuard guardc(C);
auto a_ptr_base = A->data<uint8_t>();
auto b_ptr_base = B->data<uint8_t>();
auto c_ptr_base = C->mutable_data<int32_t>();
auto gemm_context =
context->device()->cpu_runtime()->GetGemmlowpContext();
MACE_CHECK_NOTNULL(gemm_context);
......@@ -264,9 +281,12 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
}
}
C->SetScale(A->scale() * B->scale());
C->SetZeroPoint(0);
#if defined(MACE_ENABLE_NEON)
}
private:
arm::q8::Gemv<int32_t> gemv_kernel_;
#endif // MACE_ENABLE_NEON
};
template <>
......
......@@ -21,7 +21,6 @@
#include "public/gemmlowp.h"
#include "mace/benchmark/statistics.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/gemm.h"
#include "mace/ops/sgemm.h"
#include "mace/ops/ops_test_util.h"
......@@ -96,19 +95,6 @@ namespace test {
namespace {
// Matmul with (m, k) x (k, n)
void MatmulBenchmark_Mace(int iters, int m, int k, int n) {
mace::testing::StopTiming();
std::vector<float> lhs(m * k);
std::vector<float> rhs(k * n);
std::vector<float> result(m * n);
// warm up
Gemm(lhs.data(), rhs.data(), 1, m, k, n, result.data());
mace::testing::StartTiming();
while (iters--) {
Gemm(lhs.data(), rhs.data(), 1, m, k, n, result.data());
}
}
void MatmulBenchmark_Mace_SGemm(int iters, int m, int k, int n) {
mace::testing::StopTiming();
std::vector<float> lhs(m * k);
......@@ -234,7 +220,6 @@ void MatmulBenchmark_gemmlowp_int32(int iters, int rows, int depth, int cols) {
MACE_BENCHMARK(MACE_BM_MATMUL_##M##_##K##_##N##_##FUNC)
#define MACE_BM_MATMUL(M, K, N) \
MACE_BM_MATMUL_FUNC(M, K, N, Mace, float); \
MACE_BM_MATMUL_FUNC(M, K, N, Mace_SGemm, float); \
MACE_BM_MATMUL_FUNC(M, K, N, Eigen, float); \
MACE_BM_MATMUL_FUNC(M, K, N, gemmlowp_uint8, uint8_t); \
......@@ -307,6 +292,7 @@ void MatMulBenchmark(
.Input("A")
.Input("B")
.Output("Output")
.OutputType({DT_INT32})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
......@@ -408,7 +394,7 @@ void MatMulTransposeBenchmark(
MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, float, CPU); \
MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, uint8_t, CPU);
MACE_BM_MATMUL_OP(1, 128, 128, 49);
MACE_BM_MATMUL_OP(1, 30000, 256, 1);
MACE_BM_MATMUL_OP(2, 128, 128, 49);
MACE_BM_MATMUL_OP(3, 128, 128, 49);
MACE_BM_MATMUL_OP(4, 128, 128, 49);
......
......@@ -23,7 +23,7 @@ namespace test {
class MatMulOpTest : public OpsTestBase {};
namespace {
template <DeviceType D>
template<DeviceType D>
void Simple(const std::vector<index_t> &A_shape,
const std::vector<float> &A_value,
const std::vector<index_t> &B_shape,
......@@ -55,12 +55,12 @@ TEST_F(MatMulOpTest, SimpleCPU) {
Simple<DeviceType::CPU>({1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 3, 2},
{1, 2, 3, 4, 5, 6}, {1, 2, 2}, {22, 28, 49, 64});
Simple<DeviceType::CPU>(
{1, 5, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
{1, 5, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
{1, 5, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
{1, 5, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
{1, 5, 5}, {215, 230, 245, 260, 275, 490, 530, 570, 610,
650, 765, 830, 895, 960, 1025, 1040, 1130, 1220,
{1, 5, 5}, {215, 230, 245, 260, 275, 490, 530, 570, 610,
650, 765, 830, 895, 960, 1025, 1040, 1130, 1220,
1310, 1400, 1315, 1430, 1545, 1660, 1775});
}
......@@ -289,9 +289,6 @@ TEST_F(MatMulOpTest, QuantOutputInt32) {
QuantOutputInt32({2}, 253, 300, 1, false, false);
}
// TODO(liyin): test transpose after implementing gpu runtime
// now transpose test is in kernels_test
} // namespace test
} // namespace ops
} // namespace mace
......@@ -14,7 +14,7 @@
#include "mace/core/op_context.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/ops/activation.h"
#include "mace/ops/common/activation_type.h"
#include "mace/ops/opencl/helper.h"
namespace mace {
......
......@@ -14,7 +14,7 @@
#include "mace/core/op_context.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/ops/activation.h"
#include "mace/ops/common/activation_type.h"
#include "mace/ops/opencl/helper.h"
namespace mace {
......
......@@ -23,7 +23,7 @@
#include "mace/ops/opencl/image/buffer_to_image.h"
#include "mace/ops/opencl/image/image_to_buffer.h"
#include "mace/ops/opencl/buffer/buffer_transform.h"
#include "mace/ops/transpose.h"
#include "mace/ops/common/transpose.h"
namespace mace {
namespace ops {
......
......@@ -18,7 +18,7 @@
#include <vector>
#include "mace/ops/activation.h"
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/conv_pool_2d_util.h"
namespace mace {
class OpContext;
......
......@@ -17,8 +17,8 @@
#include <vector>
#include "mace/ops/activation.h"
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/activation_type.h"
#include "mace/ops/common/conv_pool_2d_util.h"
namespace mace {
......
......@@ -18,7 +18,7 @@
#include <string>
#include <vector>
#include "mace/ops/activation.h"
#include "mace/ops/common/activation_type.h"
namespace mace {
......
......@@ -23,7 +23,7 @@
#include "mace/core/op_context.h"
#include "mace/core/tensor.h"
#include "mace/ops/activation.h"
#include "mace/ops/common/activation_type.h"
#include "mace/ops/opencl/helper.h"
namespace mace {
......
......@@ -14,7 +14,7 @@
#include "mace/core/op_context.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/ops/activation.h"
#include "mace/ops/common/activation_type.h"
#include "mace/ops/opencl/helper.h"
namespace mace {
......
......@@ -14,7 +14,7 @@
#include "mace/core/op_context.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/ops/activation.h"
#include "mace/ops/common/activation_type.h"
#include "mace/ops/opencl/helper.h"
#include "mace/utils/utils.h"
......
......@@ -15,7 +15,7 @@
#include "mace/core/op_context.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/ops/opencl/helper.h"
#include "mace/ops/activation.h"
#include "mace/ops/common/activation_type.h"
#include "mace/utils/utils.h"
namespace mace {
......
......@@ -14,8 +14,8 @@
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/op_context.h"
#include "mace/ops/activation.h"
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/activation_type.h"
#include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/opencl/helper.h"
#include "mace/utils/utils.h"
......
......@@ -18,7 +18,7 @@
#include <vector>
#include "mace/ops/pooling.h"
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/conv_pool_2d_util.h"
namespace mace {
......
......@@ -36,6 +36,7 @@
#include "mace/public/mace.h"
#include "mace/utils/utils.h"
#include "mace/utils/quantize.h"
#include "mace/ops/testing/test_utils.h"
namespace mace {
namespace ops {
......@@ -422,230 +423,6 @@ class OpsTestBase : public ::testing::Test {
}
};
template <typename T>
void GenerateRandomRealTypeData(const std::vector<index_t> &shape,
std::vector<T> *res,
bool positive = true) {
MACE_CHECK_NOTNULL(res);
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1);
index_t size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<index_t>());
res->resize(size);
if (DataTypeToEnum<T>::value == DT_HALF) {
std::generate(res->begin(), res->end(), [&gen, &nd, positive] {
return half_float::half_cast<half>(positive ? std::abs(nd(gen))
: nd(gen));
});
} else {
std::generate(res->begin(), res->end(), [&gen, &nd, positive] {
return positive ? std::abs(nd(gen)) : nd(gen);
});
}
}
template <typename T>
void GenerateRandomIntTypeData(const std::vector<index_t> &shape,
std::vector<T> *res,
const T a = 0,
const T b = std::numeric_limits<T>::max()) {
MACE_CHECK_NOTNULL(res);
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> nd(a, b);
index_t size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<index_t>());
res->resize(size);
std::generate(res->begin(), res->end(), [&gen, &nd] { return nd(gen); });
}
template <typename T>
std::vector<T> VectorStaticCast(const std::vector<float> &&src) {
std::vector<T> dest;
dest.reserve(src.size());
for (float f : src) {
dest.push_back(static_cast<T>(f));
}
return std::move(dest);
}
inline bool IsSameSize(const Tensor &x, const Tensor &y) {
if (x.dim_size() != y.dim_size()) return false;
for (int d = 0; d < x.dim_size(); ++d) {
if (x.dim(d) != y.dim(d)) return false;
}
return true;
}
inline std::string ShapeToString(const Tensor &x) {
std::stringstream stream;
for (int i = 0; i < x.dim_size(); i++) {
if (i > 0) stream << ",";
int64_t dim = x.dim(i);
if (dim < 0) {
stream << "?";
} else {
stream << dim;
}
}
stream << "]";
return std::string(stream.str());
}
template <typename T>
struct is_floating_point_type {
static const bool value = std::is_same<T, float>::value ||
std::is_same<T, double>::value ||
std::is_same<T, half>::value;
};
template <typename T>
inline void ExpectEqual(const T &a, const T &b) {
EXPECT_EQ(a, b);
}
template <>
inline void ExpectEqual<float>(const float &a, const float &b) {
EXPECT_FLOAT_EQ(a, b);
}
template <>
inline void ExpectEqual<double>(const double &a, const double &b) {
EXPECT_DOUBLE_EQ(a, b);
}
inline void AssertSameDims(const Tensor &x, const Tensor &y) {
ASSERT_TRUE(IsSameSize(x, y)) << "x.shape [" << ShapeToString(x) << "] vs "
<< "y.shape [ " << ShapeToString(y) << "]";
}
template <typename EXP_TYPE,
typename RES_TYPE,
bool is_fp = is_floating_point_type<EXP_TYPE>::value>
struct Expector;
// Partial specialization for float and double.
template <typename EXP_TYPE, typename RES_TYPE>
struct Expector<EXP_TYPE, RES_TYPE, true> {
static void Equal(const EXP_TYPE &a, const RES_TYPE &b) { ExpectEqual(a, b); }
static void Equal(const Tensor &x, const Tensor &y) {
ASSERT_EQ(x.dtype(), DataTypeToEnum<EXP_TYPE>::v());
ASSERT_EQ(y.dtype(), DataTypeToEnum<RES_TYPE>::v());
AssertSameDims(x, y);
Tensor::MappingGuard x_mapper(&x);
Tensor::MappingGuard y_mapper(&y);
auto a = x.data<EXP_TYPE>();
auto b = y.data<RES_TYPE>();
for (int i = 0; i < x.size(); ++i) {
ExpectEqual(a[i], b[i]);
}
}
static void Near(const Tensor &x,
const Tensor &y,
const double rel_err,
const double abs_err) {
ASSERT_EQ(x.dtype(), DataTypeToEnum<EXP_TYPE>::v());
ASSERT_EQ(y.dtype(), DataTypeToEnum<RES_TYPE>::v());
AssertSameDims(x, y);
Tensor::MappingGuard x_mapper(&x);
Tensor::MappingGuard y_mapper(&y);
auto a = x.data<EXP_TYPE>();
auto b = y.data<RES_TYPE>();
if (x.dim_size() == 4) {
for (int n = 0; n < x.dim(0); ++n) {
for (int h = 0; h < x.dim(1); ++h) {
for (int w = 0; w < x.dim(2); ++w) {
for (int c = 0; c < x.dim(3); ++c) {
const double error = abs_err + rel_err * std::abs(*a);
EXPECT_NEAR(*a, *b, error) << "with index = [" << n << ", " << h
<< ", " << w << ", " << c << "]";
a++;
b++;
}
}
}
}
} else {
for (int i = 0; i < x.size(); ++i) {
const double error = abs_err + rel_err * std::abs(a[i]);
EXPECT_NEAR(a[i], b[i], error) << "a = " << a << " b = " << b
<< " index = " << i;
}
}
}
};
template <typename EXP_TYPE, typename RES_TYPE>
struct Expector<EXP_TYPE, RES_TYPE, false> {
static void Equal(const EXP_TYPE &a, const RES_TYPE &b) { ExpectEqual(a, b); }
static void Equal(const Tensor &x, const Tensor &y) {
ASSERT_EQ(x.dtype(), DataTypeToEnum<EXP_TYPE>::v());
ASSERT_EQ(y.dtype(), DataTypeToEnum<RES_TYPE>::v());
AssertSameDims(x, y);
Tensor::MappingGuard x_mapper(&x);
Tensor::MappingGuard y_mapper(&y);
auto a = x.data<EXP_TYPE>();
auto b = y.data<RES_TYPE>();
for (int i = 0; i < x.size(); ++i) {
ExpectEqual(a[i], b[i]);
}
}
static void Near(const Tensor &x,
const Tensor &y,
const double rel_err,
const double abs_err) {
MACE_UNUSED(rel_err);
MACE_UNUSED(abs_err);
Equal(x, y);
}
};
template <typename T>
void ExpectTensorNear(const Tensor &x,
const Tensor &y,
const double rel_err = 1e-5,
const double abs_err = 1e-8) {
Expector<T, T>::Near(x, y, rel_err, abs_err);
}
template <typename EXP_TYPE, typename RES_TYPE>
void ExpectTensorNear(const Tensor &x,
const Tensor &y,
const double rel_err = 1e-5,
const double abs_err = 1e-8) {
Expector<EXP_TYPE, RES_TYPE>::Near(x, y, rel_err, abs_err);
}
template <typename T>
void ExpectTensorSimilar(const Tensor &x,
const Tensor &y,
const double abs_err = 1e-5) {
AssertSameDims(x, y);
Tensor::MappingGuard x_mapper(&x);
Tensor::MappingGuard y_mapper(&y);
auto x_data = x.data<T>();
auto y_data = y.data<T>();
double dot_product = 0.0, x_norm = 0.0, y_norm = 0.0;
for (index_t i = 0; i < x.size(); i++) {
dot_product += x_data[i] * y_data[i];
x_norm += x_data[i] * x_data[i];
y_norm += y_data[i] * y_data[i];
}
double similarity = dot_product / (sqrt(x_norm) * sqrt(y_norm));
EXPECT_NEAR(1.0, similarity, abs_err);
}
} // namespace test
} // namespace ops
} // namespace mace
......
......@@ -27,7 +27,7 @@
#include "mace/core/operator.h"
#include "mace/core/tensor.h"
#include "mace/ops/conv_pool_2d_base.h"
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/conv_pool_2d_util.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/pooling.h"
#include "mace/ops/opencl/buffer/pooling.h"
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/pooling.h"
#include "mace/ops/ops_test_util.h"
......
......@@ -15,7 +15,7 @@
#include <vector>
#include "mace/ops/pooling.h"
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
......
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ref/gemv.h"
#if defined(MACE_ENABLE_QUANTIZE)
#include "mace/utils/quantize.h"
#endif // MACE_ENABLE_QUANTIZE
namespace mace {
namespace ops {
namespace ref {
MaceStatus Gemv<float>::Compute(const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const Tensor *bias,
const index_t batch,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
Tensor *output) {
MACE_UNUSED(context);
Tensor::MappingGuard lhs_guard(lhs);
Tensor::MappingGuard rhs_guard(rhs);
Tensor::MappingGuard bias_guard(bias);
Tensor::MappingGuard output_guard(output);
const float *lhs_data = lhs->data<float>();
const float *rhs_data = rhs->data<float>();
const float *bias_data = nullptr;
if (bias) {
bias_data = bias->data<float>();
}
float *output_data = output->mutable_data<float>();
for (index_t b = 0; b < batch; ++b) {
for (index_t h = 0; h < lhs_height; ++h) {
float sum = bias ? bias_data[h] : 0;
for (index_t w = 0; w < lhs_width; ++w) {
sum += lhs_data[
static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+ h * lhs_width + w]
* rhs_data[b * lhs_width + w];
} // w
output_data[b * lhs_height + h] = sum;
} // h
} // b
return MaceStatus::MACE_SUCCESS;
}
#if defined(MACE_ENABLE_QUANTIZE)
MaceStatus Gemv<uint8_t>::Compute(const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const Tensor *bias,
const index_t batch,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
Tensor *output) {
MACE_UNUSED(context);
Tensor::MappingGuard lhs_guard(lhs);
Tensor::MappingGuard rhs_guard(rhs);
Tensor::MappingGuard bias_guard(bias);
Tensor::MappingGuard output_guard(output);
const uint8_t *lhs_data = lhs->data<uint8_t>();
const uint8_t *rhs_data = rhs->data<uint8_t>();
const int32_t *bias_data = nullptr;
if (bias) {
bias_data = bias->data<int32_t>();
}
uint8_t *output_data = output->mutable_data<uint8_t>();
MACE_CHECK(output->scale() > 0, "output scale must not be zero");
const float output_multiplier_float =
lhs->scale() * rhs->scale() / output->scale();
int32_t lhs_zero = lhs->zero_point();
int32_t rhs_zero = rhs->zero_point();
for (index_t b = 0; b < batch; ++b) {
for (index_t h = 0; h < lhs_height; ++h) {
int32_t sum = bias ? bias_data[h] : 0;
for (index_t w = 0; w < lhs_width; ++w) {
sum += (lhs_data[
static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+ h * lhs_width + w] - lhs_zero)
* (rhs_data[b * lhs_width + w] - rhs_zero);
} // w
output_data[b * lhs_height + h] =
Saturate<uint8_t>(std::roundf(sum * output_multiplier_float));
} // h
} // b
return MaceStatus::MACE_SUCCESS;
}
MaceStatus Gemv<int32_t>::Compute(const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const Tensor *bias,
const index_t batch,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
Tensor *output) {
MACE_UNUSED(context);
Tensor::MappingGuard lhs_guard(lhs);
Tensor::MappingGuard rhs_guard(rhs);
Tensor::MappingGuard bias_guard(bias);
Tensor::MappingGuard output_guard(output);
const uint8_t *lhs_data = lhs->data<uint8_t>();
const uint8_t *rhs_data = rhs->data<uint8_t>();
const int32_t *bias_data = nullptr;
if (bias) {
bias_data = bias->data<int32_t>();
}
int32_t *output_data = output->mutable_data<int32_t>();
int32_t lhs_zero = lhs->zero_point();
int32_t rhs_zero = rhs->zero_point();
for (index_t b = 0; b < batch; ++b) {
for (index_t h = 0; h < lhs_height; ++h) {
int32_t sum = bias ? bias_data[h] : 0;
for (index_t w = 0; w < lhs_width; ++w) {
sum += (lhs_data[
static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+ h * lhs_width + w] - lhs_zero)
* (rhs_data[b * lhs_width + w] - rhs_zero);
} // w
output_data[b * lhs_height + h] = sum;
} // h
} // b
return MaceStatus::MACE_SUCCESS;
}
#endif // MACE_ENABLE_QUANTIZE
} // namespace ref
} // namespace ops
} // namespace mace
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_REF_GEMV_H_
#define MACE_OPS_REF_GEMV_H_
#include "mace/public/mace.h"
#include "mace/core/tensor.h"
#include "mace/core/op_context.h"
namespace mace {
namespace ops {
namespace ref {
template<typename OUTPUT_TYPE>
class Gemv {
public:
Gemv() {}
~Gemv() {}
// Always row-major after transpose
MaceStatus Compute(
const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const Tensor *bias,
const index_t batch,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
Tensor *output);
};
template<>
class Gemv<float> {
public:
Gemv() {}
~Gemv() {}
// Always row-major after transpose
MaceStatus Compute(
const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const Tensor *bias,
const index_t batch,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
Tensor *output);
};
#if defined(MACE_ENABLE_QUANTIZE)
template<>
class Gemv<uint8_t> {
public:
Gemv() {}
~Gemv() {}
// Always row-major after transpose
MaceStatus Compute(
const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const Tensor *bias,
const index_t batch,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
Tensor *output);
};
template<>
class Gemv<int32_t> {
public:
Gemv() {}
~Gemv() {}
// Always row-major after transpose
MaceStatus Compute(
const OpContext *context,
const Tensor *lhs,
const Tensor *rhs,
const Tensor *bias,
const index_t batch,
const index_t lhs_height,
const index_t lhs_width,
const bool lhs_batched,
Tensor *output);
};
#endif // MACE_ENABLE_QUANTIZE
} // namespace ref
} // namespace ops
} // namespace mace
#endif // MACE_OPS_REF_GEMV_H_
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_TESTING_TEST_UTILS_H_
#define MACE_OPS_TESTING_TEST_UTILS_H_
#include <limits>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <type_traits>
#include <utility>
#include <functional>
#include <vector>
#include "mace/core/tensor.h"
namespace mace {
namespace ops {
namespace test {
template<typename T>
void GenerateRandomRealTypeData(const std::vector<index_t> &shape,
T *res,
bool positive = true) {
MACE_CHECK_NOTNULL(res);
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1);
index_t size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<index_t>());
if (DataTypeToEnum<T>::value == DT_HALF) {
std::generate(res, res + size, [&gen, &nd, positive] {
return half_float::half_cast<half>(positive ? std::abs(nd(gen))
: nd(gen));
});
} else {
std::generate(res, res + size, [&gen, &nd, positive] {
return positive ? std::abs(nd(gen)) : nd(gen);
});
}
}
template<typename T>
void GenerateRandomRealTypeData(const std::vector<index_t> &shape,
std::vector<T> *res,
bool positive = true) {
MACE_CHECK_NOTNULL(res);
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1);
index_t size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<index_t>());
res->resize(size);
if (DataTypeToEnum<T>::value == DT_HALF) {
std::generate(res->begin(), res->end(), [&gen, &nd, positive] {
return half_float::half_cast<half>(positive ? std::abs(nd(gen))
: nd(gen));
});
} else {
std::generate(res->begin(), res->end(), [&gen, &nd, positive] {
return positive ? std::abs(nd(gen)) : nd(gen);
});
}
}
template<typename T>
void GenerateRandomIntTypeData(const std::vector<index_t> &shape,
T *res,
const T a = 0,
const T b = std::numeric_limits<T>::max()) {
MACE_CHECK_NOTNULL(res);
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> nd(a, b);
index_t size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<index_t>());
std::generate(res, res + size, [&gen, &nd] { return nd(gen); });
}
template<typename T>
void GenerateRandomIntTypeData(const std::vector<index_t> &shape,
std::vector<T> *res,
const T a = 0,
const T b = std::numeric_limits<T>::max()) {
MACE_CHECK_NOTNULL(res);
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> nd(a, b);
index_t size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<index_t>());
res->resize(size);
std::generate(res->begin(), res->end(), [&gen, &nd] { return nd(gen); });
}
template<typename T>
std::vector<T> VectorStaticCast(const std::vector<float> &&src) {
std::vector<T> dest;
dest.reserve(src.size());
for (float f : src) {
dest.push_back(static_cast<T>(f));
}
return std::move(dest);
}
inline bool IsSameSize(const Tensor &x, const Tensor &y) {
if (x.dim_size() != y.dim_size()) return false;
for (int d = 0; d < x.dim_size(); ++d) {
if (x.dim(d) != y.dim(d)) return false;
}
return true;
}
inline std::string ShapeToString(const Tensor &x) {
std::stringstream stream;
for (int i = 0; i < x.dim_size(); i++) {
if (i > 0) stream << ",";
int64_t dim = x.dim(i);
if (dim < 0) {
stream << "?";
} else {
stream << dim;
}
}
stream << "]";
return std::string(stream.str());
}
template<typename T>
struct is_floating_point_type {
static const bool value = std::is_same<T, float>::value ||
std::is_same<T, double>::value ||
std::is_same<T, half>::value;
};
template<typename T>
inline void ExpectEqual(const T &a, const T &b) {
EXPECT_EQ(a, b);
}
template<>
inline void ExpectEqual<float>(const float &a, const float &b) {
EXPECT_FLOAT_EQ(a, b);
}
template<>
inline void ExpectEqual<double>(const double &a, const double &b) {
EXPECT_DOUBLE_EQ(a, b);
}
inline void AssertSameDims(const Tensor &x, const Tensor &y) {
ASSERT_TRUE(IsSameSize(x, y)) << "x.shape [" << ShapeToString(x) << "] vs "
<< "y.shape [ " << ShapeToString(y) << "]";
}
template<typename EXP_TYPE,
typename RES_TYPE,
bool is_fp = is_floating_point_type<EXP_TYPE>::value>
struct Expector;
// Partial specialization for float and double.
template<typename EXP_TYPE, typename RES_TYPE>
struct Expector<EXP_TYPE, RES_TYPE, true> {
static void Equal(const EXP_TYPE &a, const RES_TYPE &b) { ExpectEqual(a, b); }
static void Equal(const Tensor &x, const Tensor &y) {
ASSERT_EQ(x.dtype(), DataTypeToEnum<EXP_TYPE>::v());
ASSERT_EQ(y.dtype(), DataTypeToEnum<RES_TYPE>::v());
AssertSameDims(x, y);
Tensor::MappingGuard x_mapper(&x);
Tensor::MappingGuard y_mapper(&y);
auto a = x.data<EXP_TYPE>();
auto b = y.data<RES_TYPE>();
for (int i = 0; i < x.size(); ++i) {
ExpectEqual(a[i], b[i]);
}
}
static void Near(const Tensor &x,
const Tensor &y,
const double rel_err,
const double abs_err) {
ASSERT_EQ(x.dtype(), DataTypeToEnum<EXP_TYPE>::v());
ASSERT_EQ(y.dtype(), DataTypeToEnum<RES_TYPE>::v());
AssertSameDims(x, y);
Tensor::MappingGuard x_mapper(&x);
Tensor::MappingGuard y_mapper(&y);
auto a = x.data<EXP_TYPE>();
auto b = y.data<RES_TYPE>();
if (x.dim_size() == 4) {
for (int n = 0; n < x.dim(0); ++n) {
for (int h = 0; h < x.dim(1); ++h) {
for (int w = 0; w < x.dim(2); ++w) {
for (int c = 0; c < x.dim(3); ++c) {
const double error = abs_err + rel_err * std::abs(*a);
EXPECT_NEAR(*a, *b, error) << "with index = [" << n << ", " << h
<< ", " << w << ", " << c << "]";
a++;
b++;
}
}
}
}
} else {
for (int i = 0; i < x.size(); ++i) {
const double error = abs_err + rel_err * std::abs(a[i]);
EXPECT_NEAR(a[i], b[i], error) << "a = " << a << " b = " << b
<< " index = " << i;
}
}
}
};
template<typename EXP_TYPE, typename RES_TYPE>
struct Expector<EXP_TYPE, RES_TYPE, false> {
static void Equal(const EXP_TYPE &a, const RES_TYPE &b) { ExpectEqual(a, b); }
static void Equal(const Tensor &x, const Tensor &y) {
ASSERT_EQ(x.dtype(), DataTypeToEnum<EXP_TYPE>::v());
ASSERT_EQ(y.dtype(), DataTypeToEnum<RES_TYPE>::v());
AssertSameDims(x, y);
Tensor::MappingGuard x_mapper(&x);
Tensor::MappingGuard y_mapper(&y);
auto a = x.data<EXP_TYPE>();
auto b = y.data<RES_TYPE>();
for (int i = 0; i < x.size(); ++i) {
ExpectEqual(a[i], b[i]);
}
}
static void Near(const Tensor &x,
const Tensor &y,
const double rel_err,
const double abs_err) {
MACE_UNUSED(rel_err);
MACE_UNUSED(abs_err);
Equal(x, y);
}
};
template<typename T>
void ExpectTensorNear(const Tensor &x,
const Tensor &y,
const double rel_err = 1e-5,
const double abs_err = 1e-8) {
Expector<T, T>::Near(x, y, rel_err, abs_err);
}
template<typename EXP_TYPE, typename RES_TYPE>
void ExpectTensorNear(const Tensor &x,
const Tensor &y,
const double rel_err = 1e-5,
const double abs_err = 1e-8) {
Expector<EXP_TYPE, RES_TYPE>::Near(x, y, rel_err, abs_err);
}
template<typename T>
void ExpectTensorSimilar(const Tensor &x,
const Tensor &y,
const double abs_err = 1e-5) {
AssertSameDims(x, y);
Tensor::MappingGuard x_mapper(&x);
Tensor::MappingGuard y_mapper(&y);
auto x_data = x.data<T>();
auto y_data = y.data<T>();
double dot_product = 0.0, x_norm = 0.0, y_norm = 0.0;
for (index_t i = 0; i < x.size(); i++) {
dot_product += x_data[i] * y_data[i];
x_norm += x_data[i] * x_data[i];
y_norm += y_data[i] * y_data[i];
}
double similarity = dot_product / (sqrt(x_norm) * sqrt(y_norm));
EXPECT_NEAR(1.0, similarity, abs_err);
}
} // namespace test
} // namespace ops
} // namespace mace
#endif // MACE_OPS_TESTING_TEST_UTILS_H_
......@@ -21,199 +21,11 @@
#include <vector>
#include "mace/core/operator.h"
#include "mace/ops/transpose.h"
#include "mace/ops/common/transpose.h"
namespace mace {
namespace ops {
namespace {
void TransposeNHWCToNCHWC3(const float *input,
float *output,
const index_t height,
const index_t width) {
index_t image_size = height * width;
#pragma omp parallel for
for (index_t h = 0; h < height; ++h) {
index_t in_offset = h * width * 3;
index_t out_offset = h * width;
#if defined(MACE_ENABLE_NEON)
index_t w;
for (w = 0; w + 3 < width; w += 4) {
float32x4x3_t vi = vld3q_f32(input + in_offset);
vst1q_f32(output + out_offset, vi.val[0]);
vst1q_f32(output + out_offset + image_size, vi.val[1]);
vst1q_f32(output + out_offset + image_size * 2, vi.val[2]);
in_offset += 12;
out_offset += 4;
}
for (; w < width; ++w) {
for (index_t c = 0; c < 3; ++c) {
output[h * width + image_size * c + w] =
input[h * width * 3 + w * 3 + c];
}
}
#else
for (index_t w = 0; w < width; ++w) {
for (index_t c = 0; c < 3; ++c) {
output[out_offset + c * image_size + w] = input[in_offset + w * 3 + c];
}
}
#endif
}
}
void TransposeNCHWToNHWCC2(const float *input,
float *output,
const index_t height,
const index_t width) {
index_t image_size = height * width;
#pragma omp parallel for
for (index_t h = 0; h < height; ++h) {
index_t in_offset = h * width;
index_t out_offset = h * width * 2;
#if defined(MACE_ENABLE_NEON)
index_t w;
for (w = 0; w + 3 < width; w += 4) {
float32x4_t vi0 = vld1q_f32(input + in_offset);
float32x4_t vi1 = vld1q_f32(input + in_offset + image_size);
float32x4x2_t vi = {vi0, vi1};
vst2q_f32(output + out_offset, vi);
in_offset += 4;
out_offset += 8;
}
for (; w < width; ++w) {
for (index_t c = 0; c < 2; ++c) {
output[h * width * 2 + w * 2 + c] =
input[h * width + image_size * c + w];
}
}
#else
for (index_t w = 0; w < width; ++w) {
for (index_t c = 0; c < 2; ++c) {
output[out_offset + w * 2 + c] = input[in_offset + c * image_size + w];
}
}
#endif
}
}
} // namespace
MaceStatus Transpose(const float *input,
const std::vector<int64_t> &input_shape,
const std::vector<int> &dst_dims,
float *output) {
MACE_CHECK((input_shape.size() == 2 && dst_dims.size() == 2) ||
(input_shape.size() == 4 && dst_dims.size() == 4),
"Only support 2D or 4D transpose");
std::vector<index_t> output_shape;
for (size_t i = 0; i < dst_dims.size(); ++i) {
output_shape.push_back(input_shape[dst_dims[i]]);
}
if (input_shape.size() == 2) {
MACE_CHECK(dst_dims[0] == 1 && dst_dims[1] == 0, "no need transform");
index_t height = input_shape[0];
index_t width = input_shape[1];
index_t stride_i = height;
index_t stride_j = width;
index_t tile_size = height > 512 || width > 512 ? 64 : 32;
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < height; i += tile_size) {
for (index_t j = 0; j < width; j += tile_size) {
index_t end_i = std::min(i + tile_size, height);
index_t end_j = std::min(j + tile_size, width);
for (index_t tile_i = i; tile_i < end_i; ++tile_i) {
for (index_t tile_j = j; tile_j < end_j; ++tile_j) {
output[tile_j * stride_i + tile_i] =
input[tile_i * stride_j + tile_j];
}
}
}
}
} else if (input_shape.size() == 4) {
std::vector<int> transpose_order_from_NHWC_to_NCHW{0, 3, 1, 2};
std::vector<int> transpose_order_from_NCHW_to_NHWC{0, 2, 3, 1};
index_t batch_size = input_shape[1] * input_shape[2] * input_shape[3];
if (dst_dims == transpose_order_from_NHWC_to_NCHW && input_shape[3] == 3) {
for (index_t b = 0; b < input_shape[0]; ++b) {
TransposeNHWCToNCHWC3(input + b * batch_size,
output + b * batch_size,
input_shape[1],
input_shape[2]);
}
} else if (dst_dims == transpose_order_from_NCHW_to_NHWC
&& input_shape[1] == 2) {
for (index_t b = 0; b < input_shape[0]; ++b) {
TransposeNCHWToNHWCC2(input + b * batch_size,
output + b * batch_size,
input_shape[2],
input_shape[3]);
}
} else if (dst_dims == std::vector<int>{0, 2, 1, 3}) {
index_t height = input_shape[1];
index_t width = input_shape[2];
index_t channel = input_shape[3];
index_t channel_raw_size = channel * sizeof(float);
index_t stride_i = height;
index_t stride_j = width;
index_t tile_size = std::max(static_cast<index_t>(1),
static_cast<index_t>(std::sqrt(
8 * 1024 / channel)));
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < height; i += tile_size) {
for (index_t j = 0; j < width; j += tile_size) {
index_t end_i = std::min(i + tile_size, height);
index_t end_j = std::min(j + tile_size, width);
for (index_t tile_i = i; tile_i < end_i; ++tile_i) {
for (index_t tile_j = j; tile_j < end_j; ++tile_j) {
memcpy(output + (tile_j * stride_i + tile_i) * channel,
input + (tile_i * stride_j + tile_j) * channel,
channel_raw_size);
}
}
}
}
} else {
std::vector<index_t>
in_stride{input_shape[1] * input_shape[2] * input_shape[3],
input_shape[2] * input_shape[3], input_shape[3], 1};
std::vector<index_t>
out_stride{output_shape[1] * output_shape[2] * output_shape[3],
output_shape[2] * output_shape[3], output_shape[3], 1};
std::vector<index_t> idim(4, 0);
std::vector<index_t> odim(4, 0);
for (odim[0] = 0; odim[0] < output_shape[0]; ++odim[0]) {
for (odim[1] = 0; odim[1] < output_shape[1]; ++odim[1]) {
for (odim[2] = 0; odim[2] < output_shape[2]; ++odim[2]) {
for (odim[3] = 0; odim[3] < output_shape[3]; ++odim[3]) {
idim[dst_dims[0]] = odim[0];
idim[dst_dims[1]] = odim[1];
idim[dst_dims[2]] = odim[2];
idim[dst_dims[3]] = odim[3];
output[odim[0] * out_stride[0] + odim[1] * out_stride[1]
+ odim[2] * out_stride[2] + odim[3]] =
input[idim[0] * in_stride[0] + idim[1] * in_stride[1]
+ idim[2] * in_stride[2] + idim[3]];
}
}
}
}
}
} else {
MACE_NOT_IMPLEMENTED;
}
return MaceStatus::MACE_SUCCESS;
}
template <DeviceType D, typename T>
class TransposeOp;
......
......@@ -31,6 +31,13 @@ class QuantizeStat(object):
if not enhance or samples <= 1:
res[tensor_name] = (tensor_min, tensor_max)
else:
"""
Enhancement mode:
This policy eliminates the outlier samples that cause a
long-tailed statistical range. We try to shrink the range as
much as possible while retaining as many samples as possible;
d(range)/d(sample_quantile) is used to measure this trade-off.
"""
tensor_mins = np.sort(tensor_ranges[tensor_name][0])
tensor_maxs = np.sort(tensor_ranges[tensor_name][1])[::-1]
cur_min_idx = 0
......
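The trimming policy described in the docstring above is easiest to see in a small standalone sketch. This is only an illustration of the idea, not the MACE implementation (the real loop over cur_min_idx is truncated in this hunk); the function name trim_long_tail, the slope_threshold knob, and the exact stopping rule are hypothetical choices for this example.

import numpy as np


def trim_long_tail(sample_mins, sample_maxs, slope_threshold=10.0):
    """Return a (min, max) range with long-tail outlier samples dropped."""
    # Per-sample minima ascending, maxima descending, so index i means
    # "the i most extreme samples on that side have been discarded".
    mins = np.sort(np.asarray(sample_mins, dtype=np.float64))
    maxs = np.sort(np.asarray(sample_maxs, dtype=np.float64))[::-1]
    n = len(mins)
    quantile_step = 1.0 / n      # dropping one sample moves the quantile by 1/n
    lo, hi = 0, 0                # indices of the current min / max candidates
    while lo + 1 < n and hi + 1 < n:
        gain_min = mins[lo + 1] - mins[lo]   # range saved by dropping lowest min
        gain_max = maxs[hi] - maxs[hi + 1]   # range saved by dropping highest max
        # Stop once d(range)/d(sample_quantile) flattens out: giving up more
        # samples no longer buys a meaningfully smaller range.
        if max(gain_min, gain_max) / quantile_step < slope_threshold:
            break
        if gain_min >= gain_max:
            lo += 1
        else:
            hi += 1
    return mins[lo], maxs[hi]


# One far outlier on each side gets trimmed away:
print(trim_long_tail([-1.0, -1.1, -0.9, -25.0], [1.0, 1.2, 0.95, 30.0]))
# trimmed to about (-1.1, 1.2)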
......@@ -21,7 +21,7 @@
#include <string>
#include <vector>
#include "mace/ops/conv_pool_2d_util.h"
#include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/ops_test_util.h"
#include "mace/public/mace.h"
......