diff --git a/mace/libmace/BUILD b/mace/libmace/BUILD index 48312e8cc4beaa82cf5c04b2cef51e935401f738..1cecc7f60f86ca15904d40eb57188a2e42a83006 100644 --- a/mace/libmace/BUILD +++ b/mace/libmace/BUILD @@ -11,6 +11,7 @@ load( "//mace:mace.bzl", "if_android", "if_neon_enabled", + "if_neon_enabled_str", "if_openmp_enabled", "if_android_armv7", "if_hexagon_enabled", @@ -41,8 +42,8 @@ cc_library( "-DMACE_ENABLE_HEXAGON", ]), deps = [ - "//mace/public", "//mace/ops", + "//mace/public", ], alwayslink = 1, ) @@ -59,8 +60,8 @@ cc_binary( linkshared = 1, linkstatic = 0, deps = [ - "//mace/libmace:mace_version_script.lds", "//mace/libmace", + "//mace/libmace:mace_version_script.lds", ], ) @@ -81,6 +82,8 @@ genrule( srcs = [ "//mace/codegen:generated_version", "//mace/core", + "//mace/ops:common", + "//mace/ops:ref_kernels", "//mace/ops:internal_ops", "//mace/ops", "//mace/libmace", @@ -88,13 +91,20 @@ genrule( "//mace/proto:mace_cc", "@com_google_protobuf//:protobuf_lite", ] + if_opencl_enabled([ + "//mace/ops:opencl_kernels", "//mace/codegen:generated_opencl", + ]) + if_neon_enabled([ + "//mace/ops:arm_neon_kernels", ]), outs = ["libmace.a"], cmd = "tmp_mri_file=$$(mktemp mace-static-lib-mri.XXXXXXXXXX);" + "mri_stream=$$(python $(location //mace/python/tools:archive_static_lib) " + "$(locations //mace/codegen:generated_version) " + "$(locations //mace/core:core) " + + "$(locations //mace/ops:common) " + + "$(locations //mace/ops:ref_kernels) " + + if_neon_enabled_str("$(locations //mace/ops:arm_neon_kernels) ") + + if_opencl_enabled_str("$(locations //mace/ops:opencl_kernels) ") + "$(locations //mace/ops:internal_ops) " + "$(locations //mace/ops:ops) " + "$(locations //mace/libmace:libmace) " + diff --git a/mace/libmace/capability.cc b/mace/libmace/capability.cc index e7c613e7b06539a53fcf367ffc9de8b727e11d69..2989cbc16f8432842858af66e7682678d7a09f2f 100644 --- a/mace/libmace/capability.cc +++ b/mace/libmace/capability.cc @@ -17,7 +17,7 @@ #include #include -#include "mace/ops/conv_pool_2d_util.h" +#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/public/mace.h" namespace mace { diff --git a/mace/libmace/mace.cc b/mace/libmace/mace.cc index ff812c6bb893af994299a8a99820afce8c21ff58..12bdf2be5f511e5ba3d79d230f1270e8db9ed158 100644 --- a/mace/libmace/mace.cc +++ b/mace/libmace/mace.cc @@ -26,7 +26,7 @@ #include "mace/core/memory_optimizer.h" #include "mace/core/net.h" #include "mace/ops/ops_registry.h" -#include "mace/ops/transpose.h" +#include "mace/ops/common/transpose.h" #include "mace/public/mace.h" #ifdef MACE_ENABLE_OPENCL diff --git a/mace/mace.bzl b/mace/mace.bzl index 0215a08627bfa2472b596dc42801f8452de9be82..2afe4560e323d2ad1cbe731832c5a918b09b177b 100644 --- a/mace/mace.bzl +++ b/mace/mace.bzl @@ -42,6 +42,12 @@ def if_neon_enabled(a): "//conditions:default": [], }) +def if_neon_enabled_str(a): + return select({ + "//mace:neon_enabled": a, + "//conditions:default": "", + }) + def if_hexagon_enabled(a): return select({ "//mace:hexagon_enabled": a, diff --git a/mace/ops/BUILD b/mace/ops/BUILD index cc26ed7503ee8be4c76c84bade5e76e34c4378fc..7f03ce12221a7e074e59a34cdb38f918b86ff51a 100644 --- a/mace/ops/BUILD +++ b/mace/ops/BUILD @@ -18,61 +18,323 @@ load( ) cc_library( - name = "internal_ops", + name = "common", srcs = glob( [ - "*.cc", - "arm/*.cc", + "common/*.cc", + ], + ), + hdrs = glob( + [ + "common/*.h", + ], + ), + copts = [ + "-Werror", + "-Wextra", + "-Wno-missing-field-initializers", + ] + if_openmp_enabled([ + "-fopenmp", + ]) + if_neon_enabled([ + "-DMACE_ENABLE_NEON", + 
]) + if_android_armv7([ + "-mfpu=neon", + "-mfloat-abi=softfp", + ]) + if_opencl_enabled([ + "-DMACE_ENABLE_OPENCL", + ]) + if_quantize_enabled([ + "-DMACE_ENABLE_QUANTIZE", + ]) + if_hexagon_enabled([ + "-DMACE_ENABLE_HEXAGON", + ]), + deps = [ + "//mace/core", + ], +) + +cc_library( + name = "testing", + srcs = glob( + [ + "testing/*.cc", + ], + ), + hdrs = glob( + [ + "testing/*.h", + ], + ), + copts = [ + "-Werror", + "-Wextra", + "-Wno-missing-field-initializers", + ] + if_openmp_enabled([ + "-fopenmp", + ]) + if_neon_enabled([ + "-DMACE_ENABLE_NEON", + ]) + if_android_armv7([ + "-mfpu=neon", + "-mfloat-abi=softfp", + ]) + if_opencl_enabled([ + "-DMACE_ENABLE_OPENCL", + ]) + if_quantize_enabled([ + "-DMACE_ENABLE_QUANTIZE", + ]) + if_hexagon_enabled([ + "-DMACE_ENABLE_HEXAGON", + ]), + deps = [ + "//mace/core", + "@gtest//:gtest", + ], +) + +cc_library( + name = "ref_kernels", + srcs = glob( + [ + "ref/*.cc", + ], + ), + hdrs = glob( + [ + "ref/*.h", + ], + ), + copts = [ + "-Werror", + "-Wextra", + "-Wno-missing-field-initializers", + ] + if_openmp_enabled([ + "-fopenmp", + ]) + if_neon_enabled([ + "-DMACE_ENABLE_NEON", + ]) + if_android_armv7([ + "-mfpu=neon", + "-mfloat-abi=softfp", + ]) + if_opencl_enabled([ + "-DMACE_ENABLE_OPENCL", + ]) + if_quantize_enabled([ + "-DMACE_ENABLE_QUANTIZE", + ]) + if_hexagon_enabled([ + "-DMACE_ENABLE_HEXAGON", + ]), + deps = [ + ":common", + "//mace/core", + ], +) + +# After refactor, all arm neon kernels go here. +# Could be shipped to other product use. +cc_library( + name = "arm_neon_kernels", + srcs = glob( + [ + "arm/fp32/*.cc", ], exclude = [ - "*_test.cc", - "*_benchmark.cc", - "arm/*_test.cc", - "ops_registry.cc", - "ops_test_util.cc", - "buffer_transform.cc", - "lstm_cell.cc", - "quantize.cc", - "quantization_util.cc", + "arm/fp32/*_test.cc", + ], + ) + if_quantize_enabled(glob( + [ + "arm/q8/*.cc", ], - ) + if_opencl_enabled(glob( + exclude = [ + "arm/q8/*_test.cc", + ], + )), + hdrs = glob( + [ + "arm/fp32/*.h", + ], + ) + if_quantize_enabled(glob( + [ + "arm/q8/*.h", + ], + )), + copts = [ + "-Werror", + "-Wextra", + "-Wno-missing-field-initializers", + ] + if_openmp_enabled([ + "-fopenmp", + ]) + if_neon_enabled([ + "-DMACE_ENABLE_NEON", + ]) + if_android_armv7([ + "-mfpu=neon", + "-mfloat-abi=softfp", + ]) + if_opencl_enabled([ + "-DMACE_ENABLE_OPENCL", + ]) + if_quantize_enabled([ + "-DMACE_ENABLE_QUANTIZE", + ]) + if_hexagon_enabled([ + "-DMACE_ENABLE_HEXAGON", + ]), + deps = [ + ":common", + "//mace/core", + ], +) + +# After refactor, all GPU OpenCL kernels go here. +# Could be shipped to other product use. 
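+# For example (illustrative only; the consumer target below is hypothetical),
+# another BUILD file could depend on the kernels directly:
+#
+#   cc_library(
+#       name = "my_kernels_consumer",
+#       deps = ["//mace/ops:opencl_kernels"],
+#   )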
+cc_library( + name = "opencl_kernels", + srcs = glob( [ "opencl/*.cc", - "opencl/image/*.cc", - "opencl/buffer/*.cc", + "opencl/**/*.cc", "buffer_transform.cc", "lstm_cell.cc", ], exclude = [ "opencl/*_test.cc", ], - )) + if_quantize_enabled(glob( + ), + hdrs = glob( + [ + "opencl/*.h", + "opencl/**/*.h", + ], + ), + copts = [ + "-Werror", + "-Wextra", + "-Wno-missing-field-initializers", + ] + if_openmp_enabled([ + "-fopenmp", + ]) + if_neon_enabled([ + "-DMACE_ENABLE_NEON", + ]) + if_android_armv7([ + "-mfpu=neon", + "-mfloat-abi=softfp", + ]) + if_opencl_enabled([ + "-DMACE_ENABLE_OPENCL", + ]) + if_quantize_enabled([ + "-DMACE_ENABLE_QUANTIZE", + ]) + if_hexagon_enabled([ + "-DMACE_ENABLE_HEXAGON", + ]), + deps = [ + ":common", + "//mace/core", + ], +) + +cc_library( + name = "arm_neon_kernels_test", + srcs = glob( + [ + "arm/fp32/*_test.cc", + ], + ) + if_quantize_enabled(glob( + [ + "arm/q8/*_test.cc", + ], + )), + copts = [ + "-Werror", + "-Wextra", + "-Wno-missing-field-initializers", + ] + if_openmp_enabled([ + "-fopenmp", + ]) + if_neon_enabled([ + "-DMACE_ENABLE_NEON", + ]) + if_android_armv7([ + "-mfpu=neon", + "-mfloat-abi=softfp", + ]) + if_opencl_enabled([ + "-DMACE_ENABLE_OPENCL", + ]) + if_quantize_enabled([ + "-DMACE_ENABLE_QUANTIZE", + ]) + if_hexagon_enabled([ + "-DMACE_ENABLE_HEXAGON", + ]), + deps = [ + ":arm_neon_kernels", + ":ref_kernels", + ":testing", + "@gtest//:gtest", + ], + alwayslink = 1, +) + +cc_library( + name = "opencl_kernels_test", + srcs = glob( + [ + "opencl/*_test.cc", + "opencl/**/*_test.cc", + ], + ), + copts = [ + "-Werror", + "-Wextra", + "-Wno-missing-field-initializers", + ] + if_openmp_enabled([ + "-fopenmp", + ]) + if_neon_enabled([ + "-DMACE_ENABLE_NEON", + ]) + if_android_armv7([ + "-mfpu=neon", + "-mfloat-abi=softfp", + ]) + if_opencl_enabled([ + "-DMACE_ENABLE_OPENCL", + ]) + if_quantize_enabled([ + "-DMACE_ENABLE_QUANTIZE", + ]) + if_hexagon_enabled([ + "-DMACE_ENABLE_HEXAGON", + ]), + deps = [ + ":opencl_kernels", + ":ref_kernels", + ":testing", + "@gtest//:gtest", + ], + alwayslink = 1, +) + +cc_library( + name = "internal_ops", + srcs = glob( [ + "*.cc", + "arm/*.cc", # remove it after refactor + ], + exclude = [ + "*_test.cc", + "*_benchmark.cc", + "ops_registry.cc", + "ops_test_util.cc", + "lstm_cell.cc", # TODO: move it into opencl + "buffer_transform.cc", # TODO: move it into opencl "quantize.cc", "quantization_util.cc", + "arm/*_test.cc", # remove it after refactor ], - )), + ) + if_quantize_enabled( + glob( + [ + "quantize.cc", + "quantization_util.cc", + ], + ), + ), hdrs = glob( [ "*.h", - "arm/*.h", + "arm/*.h", # remove it after refactor ], exclude = [ "ops_registry.h", "ops_test_util.h", "fixpoint.h", "gemmlowp_util.h", - "arm/fixpoint_*.h", "quantization_util.h", ], - ) + if_opencl_enabled(glob([ - "opencl/*.h", - "opencl/image/*.h", - "opencl/buffer/*.h", - ])) + if_quantize_enabled(glob([ + ) + if_quantize_enabled(glob([ "fixpoint.h", "gemmlowp_util.h", - "arm/fixpoint_*.h", "quantization_util.h", ])), copts = [ @@ -85,7 +347,6 @@ cc_library( "-DMACE_ENABLE_NEON", ]) + if_android_armv7([ "-mfpu=neon", - ]) + if_android_armv7([ "-mfloat-abi=softfp", ]) + if_opencl_enabled([ "-DMACE_ENABLE_OPENCL", @@ -96,21 +357,30 @@ cc_library( ]), linkopts = if_android(["-lm"]), deps = [ + ":ref_kernels", "//mace/core", ] + if_quantize_enabled([ "@tflite", "@gemmlowp", + ]) + if_neon_enabled([ + ":arm_neon_kernels", + ]) + if_opencl_enabled([ + ":opencl_kernels", ]), ) cc_library( name = "ops", - srcs = [ - "ops_registry.cc", - ], - 
hdrs = [
-        "ops_registry.h",
-    ],
+    srcs = glob(
+        [
+            "ops_registry.cc",
+        ],
+    ),
+    hdrs = glob(
+        [
+            "ops_registry.h",
+        ],
+    ),
     copts = [
         "-Werror",
         "-Wextra",
@@ -121,7 +391,6 @@ cc_library(
         "-DMACE_ENABLE_NEON",
     ]) + if_android_armv7([
         "-mfpu=neon",
-    ]) + if_android_armv7([
         "-mfloat-abi=softfp",
     ]) + if_opencl_enabled([
         "-DMACE_ENABLE_OPENCL",
@@ -161,6 +430,7 @@ cc_library(
     ]),
     deps = [
         "ops",
+        "testing",
         "@gtest",
     ],
 )
@@ -171,8 +441,8 @@ cc_test(
     srcs = glob(
         [
             "*_test.cc",
-            "arm/*_test.cc",
-            "opencl/*_test.cc",
+            "arm/*_test.cc",  # remove it after refactor
+            "ops_test_util.cc",
         ],
         exclude = [
             "fixpoint_test.cc",
@@ -203,9 +473,14 @@ cc_test(
     linkopts = ["-fopenmp"],
     linkstatic = 1,
     deps = [
-        "test",
+        ":ops",
+        ":test",
         "@gtest//:gtest_main",
-    ],
+    ] + if_neon_enabled([
+        ":arm_neon_kernels_test",
+    ]) + if_opencl_enabled([
+        ":opencl_kernels_test",
+    ]),
 )

 cc_test(
@@ -233,7 +508,7 @@ cc_test(
     linkopts = ["-fopenmp"],
     linkstatic = 1,
     deps = [
-        "test",
+        ":ops",
         "//mace/benchmark:statistics",
         "//mace/core:test_benchmark_main",
         "//third_party/eigen3",
diff --git a/mace/ops/activation.h b/mace/ops/activation.h
index d9fc5de7a5bc39edf64da6dd5e4c09f383a0d657..9981652c78d4290289fc2ce8392adc6550fe267c 100644
--- a/mace/ops/activation.h
+++ b/mace/ops/activation.h
@@ -20,22 +20,13 @@
 #include

 #include "mace/core/types.h"
+#include "mace/ops/common/activation_type.h"
 #include "mace/ops/arm/activation_neon.h"
 #include "mace/utils/logging.h"

 namespace mace {
 namespace ops {

-enum ActivationType {
-  NOOP = 0,
-  RELU = 1,
-  RELUX = 2,
-  PRELU = 3,
-  TANH = 4,
-  SIGMOID = 5,
-  LEAKYRELU = 6,
-};
-
 inline ActivationType StringToActivationType(const std::string type) {
   if (type == "RELU") {
     return ActivationType::RELU;
diff --git a/mace/ops/arm/README b/mace/ops/arm/README
new file mode 100644
index 0000000000000000000000000000000000000000..3ac8072e9d82065ff04c5c252a05421188181452
--- /dev/null
+++ b/mace/ops/arm/README
@@ -0,0 +1,15 @@
+# Notes for Contributors and Roadmap
+
+We are going to refactor and optimize the ARM-related kernels one step at a time.
+
+The code is organized so that the kernels for each data type are kept separate.
+This way they are independent of one another, can be linked and shipped as a
+submodule, and we avoid writing macro boilerplate. We do not use a unified header
+file for each kernel because we do not want to force developers to use exactly the
+same interface for kernels of different data types at this level, although doing so
+is recommended when convenient. A reference version lives directly in the `ops`
+directory and serves as the non-NEON kernels; for simplicity it is not separated
+by data type.
+
+Although the interface is kept flexible and can be defined according to demand,
+input/output parameters should be of `Tensor` type instead of raw pointers, since a
+`Tensor` carries more information that a kernel might use.
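+
+For illustration, a per-data-type kernel under this layout is a plain class in its
+own namespace. The sketch below is simplified from `arm/fp32/gemv.h` in this change
+(the parameter list is shortened here; see that header for the full interface):
+
+    namespace mace { namespace ops { namespace arm { namespace fp32 {
+
+    class Gemv {
+     public:
+      // Inputs and outputs are Tensors, not raw pointers.
+      MaceStatus Compute(const OpContext *context,
+                         const Tensor *lhs,
+                         const Tensor *rhs,
+                         const Tensor *bias,
+                         Tensor *output);
+    };
+
+    } } } }  // namespaces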
+ diff --git a/mace/ops/arm/conv_winograd.cc b/mace/ops/arm/conv_winograd.cc index 5a5c3f9acfc49c9b7160d056f67560f43b7a03b3..11d4fbf0d52eac3d8c7abab87a5f5b95693c5df5 100644 --- a/mace/ops/arm/conv_winograd.cc +++ b/mace/ops/arm/conv_winograd.cc @@ -15,7 +15,6 @@ #include #include "mace/ops/arm/conv_winograd.h" -#include "mace/ops/gemm.h" namespace mace { namespace ops { diff --git a/mace/ops/arm/fixpoint_gemm.h b/mace/ops/arm/fixpoint_gemm.h deleted file mode 100644 index 28b8eafaa12138fc08d7021100c706035400f739..0000000000000000000000000000000000000000 --- a/mace/ops/arm/fixpoint_gemm.h +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright 2018 The MACE Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef MACE_OPS_ARM_FIXPOINT_GEMM_H_ -#define MACE_OPS_ARM_FIXPOINT_GEMM_H_ - -#if defined(MACE_ENABLE_NEON) -#include -#endif - -#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__) -#define vaddvq_u32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3]) -#endif - -namespace mace { -namespace ops { - -template -void FixPointGemv(const INPUT_TYPE *lhs, - const INPUT_TYPE *rhs, - const int lhs_zero_point, - const int rhs_zero_point, - const index_t lhs_height, - const index_t lhs_width, - OUTPUT_TYPE *result); - -template<> -void FixPointGemv(const uint8_t *lhs, - const uint8_t *rhs, - const int lhs_zero_point, - const int rhs_zero_point, - const index_t lhs_height, - const index_t lhs_width, - int32_t *result) { - int32_t zero_point_dot = lhs_zero_point * rhs_zero_point * lhs_width; - - uint32_t sum_rhs = 0; - for (index_t i = 0; i < lhs_width; ++i) { - sum_rhs += rhs[i]; - } - -#pragma omp parallel for - for (index_t h = 0; h < lhs_height; ++h) { - const uint8_t *lhs_ptr = lhs + h * lhs_width; - const uint8_t *rhs_ptr = rhs; - int32_t *ret_ptr = result + h; - - uint32_t dot = 0; - uint32_t sum_lhs = 0; - index_t w = 0; - -#if defined(MACE_ENABLE_NEON) - uint32x4_t vo0_high_u32, vo0_low_u32, vo1_high_u32, vo1_low_u32; - vo0_high_u32 = vdupq_n_u32(0); - vo0_low_u32 = vdupq_n_u32(0); - vo1_high_u32 = vdupq_n_u32(0); - vo1_low_u32 = vdupq_n_u32(0); - - uint32x4_t sum_lhs_low_u32, sum_lhs_high_u32; - sum_lhs_low_u32 = vdupq_n_u32(0); - sum_lhs_high_u32 = vdupq_n_u32(0); - - for (; w <= lhs_width - 16; w += 16) { - uint8x8_t vl0_u8, vl1_u8; - uint8x8_t vr0_u8, vr1_u8; - uint16x8_t vl0_u16, vl1_u16; - uint16x8_t vr0_u16, vr1_u16; - - vl0_u8 = vld1_u8(lhs_ptr); - vl1_u8 = vld1_u8(lhs_ptr + 8); - - vr0_u8 = vld1_u8(rhs_ptr); - vr1_u8 = vld1_u8(rhs_ptr + 8); - - vl0_u16 = vmovl_u8(vl0_u8); - vl1_u16 = vmovl_u8(vl1_u8); - - vr0_u16 = vmovl_u8(vr0_u8); - vr1_u16 = vmovl_u8(vr1_u8); - - vo0_high_u32 = vmlal_u16(vo0_high_u32, - vget_high_u16(vl0_u16), - vget_high_u16(vr0_u16)); - vo0_low_u32 = vmlal_u16(vo0_low_u32, - vget_low_u16(vl0_u16), - vget_low_u16(vr0_u16)); - vo1_high_u32 = vmlal_u16(vo1_high_u32, - vget_high_u16(vl1_u16), - vget_high_u16(vr1_u16)); - vo1_low_u32 = vmlal_u16(vo1_low_u32, - vget_low_u16(vl1_u16), - vget_low_u16(vr1_u16)); - - // It can be precuculated if lhs 
is const, but for this case - // computation is not bottleneck - sum_lhs_high_u32 += vaddl_u16(vget_high_u16(vl0_u16), - vget_high_u16(vl1_u16)); - sum_lhs_low_u32 += vaddl_u16(vget_low_u16(vl0_u16), - vget_low_u16(vl1_u16)); - - lhs_ptr += 16; - rhs_ptr += 16; - } - vo0_low_u32 = vaddq_u32(vo0_high_u32, vo0_low_u32); - vo1_low_u32 = vaddq_u32(vo1_high_u32, vo1_low_u32); - vo0_low_u32 = vaddq_u32(vo0_low_u32, vo1_low_u32); - dot += vaddvq_u32(vo0_low_u32); - - sum_lhs_low_u32 = vaddq_u32(sum_lhs_high_u32, sum_lhs_low_u32); - sum_lhs = vaddvq_u32(sum_lhs_low_u32); -#endif // MACE_ENABLE_NEON - - for (; w < lhs_width; ++w) { - dot += (*lhs_ptr) * (*rhs_ptr); - sum_lhs += (*lhs_ptr); - ++lhs_ptr; - ++rhs_ptr; - } - - int32_t ret = dot - sum_lhs * rhs_zero_point - sum_rhs * lhs_zero_point - + zero_point_dot; - - *ret_ptr = ret; - } // h -} - -} // namespace ops -} // namespace mace - -#endif // MACE_OPS_ARM_FIXPOINT_GEMM_H_ diff --git a/mace/ops/arm/fp32/gemv.cc b/mace/ops/arm/fp32/gemv.cc new file mode 100644 index 0000000000000000000000000000000000000000..39b25bf584e6a580aa84f5d1adeadcab8402d267 --- /dev/null +++ b/mace/ops/arm/fp32/gemv.cc @@ -0,0 +1,302 @@ +// Copyright 2019 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+
+#include "mace/ops/arm/fp32/gemv.h"
+
+#include <arm_neon.h>
+#include <algorithm>
+
+#if !defined(__aarch64__)
+#define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
+#endif
+
+namespace mace {
+namespace ops {
+namespace arm {
+namespace fp32 {
+
+MaceStatus Gemv::Compute(const OpContext *context,
+                         const Tensor *lhs,
+                         const Tensor *rhs,
+                         const Tensor *bias,
+                         const index_t batch,
+                         const index_t lhs_height,
+                         const index_t lhs_width,
+                         const bool lhs_batched,
+                         Tensor *output) {
+  MACE_UNUSED(context);
+
+  Tensor::MappingGuard lhs_guard(lhs);
+  Tensor::MappingGuard rhs_guard(rhs);
+  Tensor::MappingGuard bias_guard(bias);
+  Tensor::MappingGuard output_guard(output);
+
+  const index_t h_block_size = 4;
+  const index_t h_block_count = RoundUpDiv(lhs_height, h_block_size);
+  const index_t w_block_size = 8;
+  const index_t w_block_count = lhs_width / w_block_size;
+  const index_t w_remain = lhs_width - w_block_size * w_block_count;
+
+#pragma omp parallel for collapse(2) schedule(runtime)
+  for (index_t b = 0; b < batch; ++b) {
+    for (index_t h_block_idx = 0; h_block_idx < h_block_count; ++h_block_idx) {
+      // TODO(liyin): it can be put outside the loop,
+      // but openmp limits param count
+      const float *lhs_data = lhs->data<float>();
+      const float *rhs_data = rhs->data<float>();
+      const float *bias_data = nullptr;
+      if (bias) {
+        bias_data = bias->data<float>();
+      }
+      float *output_data = output->mutable_data<float>();
+
+      const float
+          *lhs_ptr = lhs_data
+          + static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+          + lhs_width * h_block_idx * h_block_size;
+      const float *rhs_ptr = rhs_data + b * lhs_width;
+      float
+          *ret_ptr = output_data + b * lhs_height + h_block_idx * h_block_size;
+
+      const index_t h_block_len =
+          std::min(h_block_size, lhs_height - h_block_idx * h_block_size);
+      const index_t h_offset = h_block_idx * h_block_size;
+
+      if (h_block_len == 4) {
+        float32x4_t vo0 = vdupq_n_f32(0);
+        float32x4_t vo1 = vdupq_n_f32(0);
+        float32x4_t vo2 = vdupq_n_f32(0);
+        float32x4_t vo3 = vdupq_n_f32(0);
+
+        index_t r_w_block_count = w_block_count;
+        // just make compiler happy
+        MACE_UNUSED(r_w_block_count);
+
+        // Register layout: (4x8) x (8,1)
+        //
+        //                               +----+
+        //                               |d16 |
+        //                               | .  |
+        //                       Rhs     +----+
+        //                               |d17 |
+        //                               | .  |
+        //                               +----+
+        //                               |d18 |
+        //                               | .  |
+        //                               +----+
+        //                               |d19 |
+        //                               | .  |
+        //                               +----+
+        //
+        //                               |    |
+        //
+        //          Lhs                  |    |
+        //
+        //  +------+------+----+-----+ - - - - +----+
+        //  | d0 . | d1 .| d2 .| d3 .|         |vo0 |
+        //  | d4 . | d5 .| d6 .| d7 .|         |vo1 |
+        //  | d8 . | d9 .| d10.| d11.|         |vo2 |
+        //  | d12.
| d13.| d14.| d15.| |vo3 | + // +------+-----+-----+-----+ - - - - +----+ + // + // Accumulator + // + +#if not defined(__aarch64__) + asm volatile( + "cmp %[r_w_block_count], #0\n" + "beq 0f\n" + + "lsl r5, %[lhs_width], #2\n" + + "mov r0, %[rhs_ptr]\n" + "mov r1, %[lhs_ptr]\n" + "add r2, r1, r5\n" + "add r3, r2, r5\n" + "add r4, r3, r5\n" + + // prelogue + "vld1.f32 {d16-d17}, [r0]!\n" + "vld1.f32 {d18-d19}, [r0]!\n" + + "vld1.f32 {d0-d1}, [r1]!\n" + "vld1.f32 {d2-d3}, [r1]!\n" + "vld1.f32 {d4-d5}, [r2]!\n" + "vld1.f32 {d6-d7}, [r2]!\n" + "vld1.f32 {d8-d9}, [r3]!\n" + "vld1.f32 {d10-d11}, [r3]!\n" + "vld1.f32 {d12-d13}, [r4]!\n" + "vld1.f32 {d14-d15}, [r4]!\n" + + "subs %[r_w_block_count], #1\n" + "beq 1f\n" + + "2: \n" + "vmla.f32 %q[vo0], q0, q8\n" + "vmla.f32 %q[vo1], q2, q8\n" + "vmla.f32 %q[vo2], q4, q8\n" + "vmla.f32 %q[vo3], q6, q8\n" + + + "vmla.f32 %q[vo0], q1, q9\n" + "vmla.f32 %q[vo1], q3, q9\n" + "vmla.f32 %q[vo2], q5, q9\n" + "vmla.f32 %q[vo3], q7, q9\n" + + "subs %[r_w_block_count], #1\n" + + + "vld1.f32 {d0-d1}, [r1]!\n" + "vld1.f32 {d4-d5}, [r2]!\n" + "vld1.f32 {d8-d9}, [r3]!\n" + "vld1.f32 {d12-d13}, [r4]!\n" + "vld1.f32 {d16-d17}, [r0]!\n" + + "vld1.f32 {d2-d3}, [r1]!\n" + "vld1.f32 {d6-d7}, [r2]!\n" + "vld1.f32 {d10-d11}, [r3]!\n" + "vld1.f32 {d14-d15}, [r4]!\n" + "vld1.f32 {d18-d19}, [r0]!\n" + + "bne 2b\n" + + // prologue + "1:\n" + "vmla.f32 %q[vo0], q0, q8\n" + "vmla.f32 %q[vo1], q2, q8\n" + "vmla.f32 %q[vo2], q4, q8\n" + "vmla.f32 %q[vo3], q6, q8\n" + + "vmla.f32 %q[vo0], q1, q9\n" + "vmla.f32 %q[vo1], q3, q9\n" + "vmla.f32 %q[vo2], q5, q9\n" + "vmla.f32 %q[vo3], q7, q9\n" + + "0:\n" + : // outputs + [vo0] "+w"(vo0), + [vo1] "+w"(vo1), + [vo2] "+w"(vo2), + [vo3] "+w"(vo3), + [r_w_block_count] "+r"(r_w_block_count) + : // inputs + [lhs_ptr] "r"(lhs_ptr), [rhs_ptr] "r"(rhs_ptr), + [lhs_width] "r"(lhs_width) + : // clobbers + "cc", "memory", "r0", "r1", "r2", "r3", "r4", "r5", + "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", + "d21"); + + lhs_ptr += w_block_count * w_block_size; + rhs_ptr += w_block_count * w_block_size; +#else + for (index_t w_block_index = 0; w_block_index < w_block_count; + ++w_block_index) { + float32x4_t vr0 = vld1q_f32(rhs_ptr); + float32x4_t vr0n = vld1q_f32(rhs_ptr + 4); + + float32x4_t vl0 = vld1q_f32(lhs_ptr); + float32x4_t vl0n = vld1q_f32(lhs_ptr + 4); + vo0 = vmlaq_f32(vo0, vl0, vr0); + vo0 = vmlaq_f32(vo0, vl0n, vr0n); + + const float *lhs_ptr1 = lhs_ptr + lhs_width; + float32x4_t vl1 = vld1q_f32(lhs_ptr1); + float32x4_t vl1n = vld1q_f32(lhs_ptr1 + 4); + vo1 = vmlaq_f32(vo1, vl1, vr0); + vo1 = vmlaq_f32(vo1, vl1n, vr0n); + + const float *lhs_ptr2 = lhs_ptr1 + lhs_width; + float32x4_t vl2 = vld1q_f32(lhs_ptr2); + float32x4_t vl2n = vld1q_f32(lhs_ptr2 + 4); + vo2 = vmlaq_f32(vo2, vl2, vr0); + vo2 = vmlaq_f32(vo2, vl2n, vr0n); + + const float *lhs_ptr3 = lhs_ptr2 + lhs_width; + float32x4_t vl3 = vld1q_f32(lhs_ptr3); + float32x4_t vl3n = vld1q_f32(lhs_ptr3 + 4); + vo3 = vmlaq_f32(vo3, vl3, vr0); + vo3 = vmlaq_f32(vo3, vl3n, vr0n); + + lhs_ptr += 8; + rhs_ptr += 8; + } +#endif // __aarch64__ + float32x4_t vo = { + vaddvq_f32(vo0), + vaddvq_f32(vo1), + vaddvq_f32(vo2), + vaddvq_f32(vo3) + }; + for (index_t w = 0; w < w_remain; ++w) { + vo[0] += lhs_ptr[0] * rhs_ptr[0]; + vo[1] += lhs_ptr[lhs_width] * rhs_ptr[0]; + vo[2] += lhs_ptr[lhs_width * 2] * rhs_ptr[0]; + vo[3] += lhs_ptr[lhs_width * 3] * rhs_ptr[0]; + ++lhs_ptr; + ++rhs_ptr; + } + + float32x4_t 
vbias = vdupq_n_f32(0); + if (bias) { + vbias = vld1q_f32(bias_data + h_offset); + } + vo = vaddq_f32(vo, vbias); + vst1q_f32(ret_ptr, vo); + } else { // h_block_len < 4 + // TODO(liyin): handle here case by case (1,2,3) to accelerate + const float *tmp_lhs_ptr = lhs_ptr; + const float *tmp_rhs_ptr = rhs_ptr; + for (index_t h = 0; h < h_block_len; ++h) { + lhs_ptr = tmp_lhs_ptr + h * lhs_width; + rhs_ptr = tmp_rhs_ptr; + float32x4_t vo0 = vdupq_n_f32(0); + for (index_t w = 0; w < w_block_count; ++w) { + float32x4_t vr0 = vld1q_f32(rhs_ptr); + float32x4_t vr0n = vld1q_f32(rhs_ptr + 4); + + float32x4_t vl0 = vld1q_f32(lhs_ptr); + float32x4_t vl0n = vld1q_f32(lhs_ptr + 4); + vo0 = vmlaq_f32(vo0, vl0, vr0); + vo0 = vmlaq_f32(vo0, vl0n, vr0n); + + lhs_ptr += 8; + rhs_ptr += 8; + } // w + float s0 = vaddvq_f32(vo0) + (bias ? bias_data[h_offset + h] : 0); + for (index_t w = 0; w < w_remain; ++w) { + s0 += lhs_ptr[0] * rhs_ptr[0]; + ++lhs_ptr; + ++rhs_ptr; + } // w + + ret_ptr[h] = s0; + } // h + } // if + } // h_block_idx + } // b + + return MaceStatus::MACE_SUCCESS; +} + +#if defined(vaddvq_f32) +#undef vaddvq_f32 +#endif + +} // namespace fp32 +} // namespace arm +} // namespace ops +} // namespace mace diff --git a/mace/ops/arm/fp32/gemv.h b/mace/ops/arm/fp32/gemv.h new file mode 100644 index 0000000000000000000000000000000000000000..5b1551fa0be1c8e0a22deb5c20be5772f13be785 --- /dev/null +++ b/mace/ops/arm/fp32/gemv.h @@ -0,0 +1,49 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_OPS_ARM_FP32_GEMV_H_ +#define MACE_OPS_ARM_FP32_GEMV_H_ + +#include "mace/public/mace.h" +#include "mace/core/tensor.h" +#include "mace/core/op_context.h" + +namespace mace { +namespace ops { +namespace arm { +namespace fp32 { + +class Gemv { + public: + Gemv() {} + ~Gemv() {} + // Always row-major after transpose + MaceStatus Compute( + const OpContext *context, + const Tensor *lhs, + const Tensor *rhs, + const Tensor *bias, + const index_t batch, + const index_t lhs_height, + const index_t lhs_width, + const bool lhs_batched, + Tensor *output); +}; + +} // namespace fp32 +} // namespace arm +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_ARM_FP32_GEMV_H_ diff --git a/mace/ops/arm/fp32/gemv_test.cc b/mace/ops/arm/fp32/gemv_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c2d13b3896d0f01f73554d8b2d5b0a989971bb09 --- /dev/null +++ b/mace/ops/arm/fp32/gemv_test.cc @@ -0,0 +1,99 @@ +// Copyright 2019 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include <gtest/gtest.h>
+
+#include "mace/core/tensor.h"
+#include "mace/core/op_context.h"
+#include "mace/ops/arm/fp32/gemv.h"
+#include "mace/ops/ref/gemv.h"
+#include "mace/ops/testing/test_utils.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+void TestGemvFloat32(const index_t batch,
+                     const index_t height,
+                     const index_t width,
+                     const bool lhs_batched) {
+  Tensor lhs(GetCPUAllocator(), DataType::DT_FLOAT);
+  Tensor rhs(GetCPUAllocator(), DataType::DT_FLOAT);
+  Tensor bias(GetCPUAllocator(), DataType::DT_FLOAT);
+  Tensor output(GetCPUAllocator(), DataType::DT_FLOAT);
+  lhs.Resize({lhs_batched ? batch : 1, height, width});
+  rhs.Resize({batch, width});
+  bias.Resize({height});
+  output.Resize({batch, height});
+  {
+    Tensor::MappingGuard lhs_guard(&lhs);
+    Tensor::MappingGuard rhs_guard(&rhs);
+    Tensor::MappingGuard bias_guard(&bias);
+    float *lhs_data = lhs.mutable_data<float>();
+    float *rhs_data = rhs.mutable_data<float>();
+    float *bias_data = bias.mutable_data<float>();
+    GenerateRandomRealTypeData(lhs.shape(), lhs_data);
+    GenerateRandomRealTypeData(rhs.shape(), rhs_data);
+    GenerateRandomRealTypeData(bias.shape(), bias_data);
+  }
+  ::mace::ops::arm::fp32::Gemv gemv;
+  gemv.Compute(nullptr,
+               &lhs,
+               &rhs,
+               &bias,
+               batch,
+               height,
+               width,
+               lhs_batched,
+               &output);
+
+  Tensor expected_output(GetCPUAllocator(), DataType::DT_FLOAT);
+  expected_output.Resize({batch, height});
+  ::mace::ops::ref::Gemv<float> gemv_ref;
+  gemv_ref.Compute(nullptr,
+                   &lhs,
+                   &rhs,
+                   &bias,
+                   batch,
+                   height,
+                   width,
+                   lhs_batched,
+                   &expected_output);
+
+  Tensor::MappingGuard output_guard(&output);
+  Tensor::MappingGuard expected_guard(&expected_output);
+  const float *output_data = output.data<float>();
+  const float *expected_data = expected_output.data<float>();
+
+  for (index_t i = 0; i < output.size(); ++i) {
+    EXPECT_NEAR(expected_data[i], output_data[i], 0.001);
+  }
+}
+
+TEST(ArmGemv, TestGemvFloat32) {
+  TestGemvFloat32(1, 16, 4, true);
+  TestGemvFloat32(1, 16, 256, true);
+  TestGemvFloat32(2, 16, 256, true);
+  TestGemvFloat32(3, 63, 257, true);
+
+  TestGemvFloat32(1, 16, 4, false);
+  TestGemvFloat32(1, 16, 256, false);
+  TestGemvFloat32(2, 16, 256, false);
+  TestGemvFloat32(3, 63, 257, false);
+}
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/arm/q8/gemv.cc b/mace/ops/arm/q8/gemv.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7117dcac6cee8e0b75eeffdff9741db0e06cfaac
--- /dev/null
+++ b/mace/ops/arm/q8/gemv.cc
@@ -0,0 +1,470 @@
+// Copyright 2019 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ + +#include "mace/ops/arm/q8/gemv.h" + +#include +#include + +#include "mace/utils/utils.h" +#include "mace/utils/quantize.h" + +#if !defined(__aarch64__) + +#define vmlal_high_s16(c, a, b) vmlal_s16(c, vget_high_s16(a), vget_high_s16(b)) + +#define vaddvq_s32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3]) + +#endif + +namespace mace { +namespace ops { +namespace arm { +namespace q8 { + +template +MaceStatus Gemv::Compute(const OpContext *context, + const Tensor *lhs, + const Tensor *rhs, + const Tensor *bias, + const index_t batch, + const index_t lhs_height, + const index_t lhs_width, + const bool lhs_batched, + Tensor *output) { + MACE_UNUSED(context); + + bool is_output_type_uint8 = + DataTypeToEnum::value == DataType::DT_UINT8; + Tensor::MappingGuard lhs_guard(lhs); + Tensor::MappingGuard rhs_guard(rhs); + Tensor::MappingGuard bias_guard(bias); + Tensor::MappingGuard output_guard(output); + + float output_multiplier_float = 0.0; + int32_t output_multiplier = 0; + int32_t output_shift = 0; + if (is_output_type_uint8) { + MACE_CHECK(output->scale() > 0, "output scale must not be zero"); + output_multiplier_float = lhs->scale() * rhs->scale() / output->scale(); + GetOutputMultiplierAndShift(lhs->scale(), + rhs->scale(), + output->scale(), + &output_multiplier, + &output_shift); + } + const index_t h_block_size = 4; + const index_t h_block_count = RoundUpDiv(lhs_height, h_block_size); + +#pragma omp parallel for collapse(2) schedule(runtime) + for (index_t b = 0; b < batch; ++b) { + for (index_t h_block_idx = 0; h_block_idx < h_block_count; ++h_block_idx) { + // TODO(liyin): it can be put it outside the loop, + // but openmp limits param count + const index_t w_block_size = 16; + const index_t w_block_count = lhs_width / w_block_size; + const index_t w_remain = lhs_width - w_block_size * w_block_count; + + uint8_t lhs_zero_point = static_cast(lhs->zero_point()); + uint8_t rhs_zero_point = static_cast(rhs->zero_point()); + + const uint8_t *lhs_data = lhs->data(); + const uint8_t *rhs_data = rhs->data(); + const int32_t *bias_data = nullptr; + if (bias) { + bias_data = bias->data(); + } + OUTPUT_TYPE *output_data = output->mutable_data(); + + int32x4_t voutput_multiplier = vdupq_n_s32(output_multiplier); + int32x4_t voutput_shift_left = vdupq_n_s32(-output_shift); + + uint8x8_t + vlhs_zero_point = vdup_n_u8(lhs_zero_point); + uint8x8_t + vrhs_zero_point = vdup_n_u8(rhs_zero_point); + + const uint8_t + *lhs_ptr = lhs_data + + static_cast(lhs_batched) * b * lhs_height * lhs_width + + lhs_width * h_block_idx * h_block_size; + const uint8_t *rhs_ptr = rhs_data + b * lhs_width; + OUTPUT_TYPE + *ret_ptr = output_data + b * lhs_height + h_block_idx * h_block_size; + + const index_t h_block_len = + std::min(h_block_size, lhs_height - h_block_idx * h_block_size); + const index_t h_offset = h_block_idx * h_block_size; + + if (h_block_len == 4) { + int32x4_t vo0 = vdupq_n_s32(0); + int32x4_t vo1 = vdupq_n_s32(0); + int32x4_t vo2 = vdupq_n_s32(0); + int32x4_t vo3 = vdupq_n_s32(0); + + index_t r_w_block_count = w_block_count; + // just make compiler happy + MACE_UNUSED(r_w_block_count); + + // Register layout: (4x16) x (16x1) + // + // +----+ + // |d16 | + // | . | + // | . | + // | . | + // Rhs +----+ + // |d17 | + // | . | + // | . | + // | . | + // +----+ + // |d18 | + // | . | + // | . | + // | . | + // +----+ + // |d19 | + // | . | + // | . | + // | . | + // +----+ + // + // | | + // + // Lhs | | + // + // +--------+--------+--------+--------+ - - - - +----+ + // | d0 ... | d1 ... | d2 ... | d3 ... 
| |vo0 | + // | d4 ... | d5 ... | d6 ... | d7 ... | |vo1 | + // | d8 ... | d9 ... | d10... | d11... | |vo2 | + // | d12... | d13... | d14... | d15... | |vo3 | + // +--------+--------+--------+--------+ - - - - +----+ + // + // Accumulator + // + +#if not defined(__aarch64__) + asm volatile( + "cmp %[r_w_block_count], #0\n" + "beq 0f\n" + + "mov r0, %[rhs_ptr]\n" + "mov r1, %[lhs_ptr]\n" + "add r2, r1, %[lhs_width]\n" + "add r3, r2, %[lhs_width]\n" + "add r4, r3, %[lhs_width]\n" + + "vdup.u8 d20, %[rhs_zero_point]\n" + "vdup.u8 d21, %[lhs_zero_point]\n" + + // prelogue + "vld1.8 d16, [r0]!\n" + "vld1.8 d18, [r0]!\n" + + "vld1.8 d0, [r1]!\n" + "vld1.8 d2, [r1]!\n" + "vld1.8 d4, [r2]!\n" + "vld1.8 d6, [r2]!\n" + "vld1.8 d8, [r3]!\n" + "vld1.8 d10, [r3]!\n" + "vld1.8 d12, [r4]!\n" + "vld1.8 d14, [r4]!\n" + + "subs %[r_w_block_count], #1\n" + "beq 1f\n" + + "2: \n" + "vsubl.u8 q8, d16, d20\n" + "vsubl.u8 q9, d18, d20\n" + + "vsubl.u8 q0, d0, d21\n" + "vsubl.u8 q1, d2, d21\n" + "vsubl.u8 q2, d4, d21\n" + "vsubl.u8 q3, d6, d21\n" + "vsubl.u8 q4, d8, d21\n" + "vsubl.u8 q5, d10, d21\n" + "vsubl.u8 q6, d12, d21\n" + "vsubl.u8 q7, d14, d21\n" + + "vmlal.s16 %q[vo0], d0, d16\n" + "vmlal.s16 %q[vo1], d4, d16\n" + "vmlal.s16 %q[vo2], d8, d16\n" + "vmlal.s16 %q[vo3], d12, d16\n" + + "vld1.8 d0, [r1]!\n" + "vld1.8 d4, [r2]!\n" + "vld1.8 d8, [r3]!\n" + "vld1.8 d12, [r4]!\n" + "vld1.8 d16, [r0]!\n" + + "vmlal.s16 %q[vo0], d2, d18\n" + "vmlal.s16 %q[vo1], d6, d18\n" + "vmlal.s16 %q[vo2], d10, d18\n" + "vmlal.s16 %q[vo3], d14, d18\n" + + "vld1.8 d2, [r1]!\n" + "vld1.8 d6, [r2]!\n" + "vld1.8 d10, [r3]!\n" + "vld1.8 d14, [r4]!\n" + "vld1.8 d18, [r0]!\n" + + "vmlal.s16 %q[vo0], d1, d17\n" + "vmlal.s16 %q[vo1], d5, d17\n" + "vmlal.s16 %q[vo2], d9, d17\n" + "vmlal.s16 %q[vo3], d13, d17\n" + + "subs %[r_w_block_count], #1\n" + "vmlal.s16 %q[vo0], d3, d19\n" + "vmlal.s16 %q[vo1], d7, d19\n" + "vmlal.s16 %q[vo2], d11, d19\n" + "vmlal.s16 %q[vo3], d15, d19\n" + + "bne 2b\n" + + // prologue + "1:\n" + "vsubl.u8 q8, d16, d20\n" + "vsubl.u8 q9, d18, d20\n" + + "vsubl.u8 q0, d0, d21\n" + "vsubl.u8 q1, d2, d21\n" + "vsubl.u8 q2, d4, d21\n" + "vsubl.u8 q3, d6, d21\n" + "vsubl.u8 q4, d8, d21\n" + "vsubl.u8 q5, d10, d21\n" + "vsubl.u8 q6, d12, d21\n" + "vsubl.u8 q7, d14, d21\n" + + "vmlal.s16 %q[vo0], d0, d16\n" + "vmlal.s16 %q[vo1], d4, d16\n" + "vmlal.s16 %q[vo2], d8, d16\n" + "vmlal.s16 %q[vo3], d12, d16\n" + + "vmlal.s16 %q[vo0], d1, d17\n" + "vmlal.s16 %q[vo1], d5, d17\n" + "vmlal.s16 %q[vo2], d9, d17\n" + "vmlal.s16 %q[vo3], d13, d17\n" + + "vmlal.s16 %q[vo0], d2, d18\n" + "vmlal.s16 %q[vo1], d6, d18\n" + "vmlal.s16 %q[vo2], d10, d18\n" + "vmlal.s16 %q[vo3], d14, d18\n" + + "vmlal.s16 %q[vo0], d3, d19\n" + "vmlal.s16 %q[vo1], d7, d19\n" + "vmlal.s16 %q[vo2], d11, d19\n" + "vmlal.s16 %q[vo3], d15, d19\n" + + "0:\n" + : // outputs + [vo0] "+w"(vo0), + [vo1] "+w"(vo1), + [vo2] "+w"(vo2), + [vo3] "+w"(vo3), + [r_w_block_count] "+r"(r_w_block_count) + : // inputs + [lhs_ptr] "r"(lhs_ptr), [rhs_ptr] "r"(rhs_ptr), + [lhs_width] "r"(lhs_width), + [lhs_zero_point] "r"(lhs_zero_point), + [rhs_zero_point] "r"(rhs_zero_point) + : // clobbers + "cc", "memory", "r0", "r1", "r2", "r3", "r4", + "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9", "d10", + "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18", "d19", "d20", + "d21"); + + lhs_ptr += w_block_count * w_block_size; + rhs_ptr += w_block_count * w_block_size; +#else + for (index_t w_block_index = 0; w_block_index < w_block_count; + ++w_block_index) { + uint8x8_t vr0 = 
vld1_u8(rhs_ptr); + int16x8_t + vxr0 = vreinterpretq_s16_u16(vsubl_u8(vr0, vrhs_zero_point)); + uint8x8_t vr0n = vld1_u8(rhs_ptr + 8); + int16x8_t + vxr0n = vreinterpretq_s16_u16(vsubl_u8(vr0n, vrhs_zero_point)); + + uint8x8_t vl0 = vld1_u8(lhs_ptr); + int16x8_t + vxl0 = vreinterpretq_s16_u16(vsubl_u8(vl0, vlhs_zero_point)); + uint8x8_t vl0n = vld1_u8(lhs_ptr + 8); + int16x8_t + vxl0n = vreinterpretq_s16_u16(vsubl_u8(vl0n, vlhs_zero_point)); + + vo0 = vmlal_s16(vo0, vget_low_s16(vxl0), vget_low_s16(vxr0)); + vo0 = vmlal_high_s16(vo0, vxl0, vxr0); + vo0 = vmlal_s16(vo0, vget_low_s16(vxl0n), vget_low_s16(vxr0n)); + vo0 = vmlal_high_s16(vo0, vxl0n, vxr0n); + + const uint8_t *lhs_ptr1 = lhs_ptr + lhs_width; + + uint8x8_t vl1 = vld1_u8(lhs_ptr1); + int16x8_t + vxl1 = vreinterpretq_s16_u16(vsubl_u8(vl1, vlhs_zero_point)); + uint8x8_t vl1n = vld1_u8(lhs_ptr1 + 8); + int16x8_t + vxl1n = vreinterpretq_s16_u16(vsubl_u8(vl1n, vlhs_zero_point)); + + vo1 = vmlal_s16(vo1, vget_low_s16(vxl1), vget_low_s16(vxr0)); + vo1 = vmlal_high_s16(vo1, vxl1, vxr0); + vo1 = vmlal_s16(vo1, vget_low_s16(vxl1n), vget_low_s16(vxr0n)); + vo1 = vmlal_high_s16(vo1, vxl1n, vxr0n); + + const uint8_t *lhs_ptr2 = lhs_ptr1 + lhs_width; + + uint8x8_t vl2 = vld1_u8(lhs_ptr2); + int16x8_t + vxl2 = vreinterpretq_s16_u16(vsubl_u8(vl2, vlhs_zero_point)); + uint8x8_t vl2n = vld1_u8(lhs_ptr2 + 8); + int16x8_t + vxl2n = vreinterpretq_s16_u16(vsubl_u8(vl2n, vlhs_zero_point)); + + vo2 = vmlal_s16(vo2, vget_low_s16(vxl2), vget_low_s16(vxr0)); + vo2 = vmlal_high_s16(vo2, vxl2, vxr0); + vo2 = vmlal_s16(vo2, vget_low_s16(vxl2n), vget_low_s16(vxr0n)); + vo2 = vmlal_high_s16(vo2, vxl2n, vxr0n); + + const uint8_t *lhs_ptr3 = lhs_ptr2 + lhs_width; + + uint8x8_t vl3 = vld1_u8(lhs_ptr3); + int16x8_t + vxl3 = vreinterpretq_s16_u16(vsubl_u8(vl3, vlhs_zero_point)); + uint8x8_t vl3n = vld1_u8(lhs_ptr3 + 8); + int16x8_t + vxl3n = vreinterpretq_s16_u16(vsubl_u8(vl3n, vlhs_zero_point)); + + vo3 = vmlal_s16(vo3, vget_low_s16(vxl3), vget_low_s16(vxr0)); + vo3 = vmlal_high_s16(vo3, vxl3, vxr0); + vo3 = vmlal_s16(vo3, vget_low_s16(vxl3n), vget_low_s16(vxr0n)); + vo3 = vmlal_high_s16(vo3, vxl3n, vxr0n); + + lhs_ptr += 16; + rhs_ptr += 16; + } +#endif // __aarch64__ + int32x4_t vo = {vaddvq_s32(vo0), + vaddvq_s32(vo1), + vaddvq_s32(vo2), + vaddvq_s32(vo3)}; + + for (index_t w = 0; w < w_remain; ++w) { + vo[0] += + (lhs_ptr[0] - lhs_zero_point) * (rhs_ptr[0] - rhs_zero_point); + vo[1] += (lhs_ptr[lhs_width] - lhs_zero_point) + * (rhs_ptr[0] - rhs_zero_point); + vo[2] += (lhs_ptr[lhs_width * 2] - lhs_zero_point) + * (rhs_ptr[0] - rhs_zero_point); + vo[3] += (lhs_ptr[lhs_width * 3] - lhs_zero_point) + * (rhs_ptr[0] - rhs_zero_point); + ++lhs_ptr; + ++rhs_ptr; + } + + int32x4_t vbias = vdupq_n_s32(0); + if (bias) { + vbias = vld1q_s32(bias_data + h_offset); + } + vo = vaddq_s32(vo, vbias); + + if (is_output_type_uint8) { + int32x4_t vo_mul = vqrdmulhq_s32(vo, voutput_multiplier); + int32x4_t + fixup = vshrq_n_s32(vandq_s32(vo_mul, voutput_shift_left), 31); + int32x4_t fixed_up_x = vqaddq_s32(vo_mul, fixup); + int32x4_t + vo_rescale_int32 = vrshlq_s32(fixed_up_x, voutput_shift_left); + + int16x4_t vo_rescale_int16 = vqmovn_s32(vo_rescale_int32); + uint8x8_t vo_rescale_uint8 = + vqmovun_s16(vcombine_s16(vo_rescale_int16, vo_rescale_int16)); + + ret_ptr[0] = vo_rescale_uint8[0]; + ret_ptr[1] = vo_rescale_uint8[1]; + ret_ptr[2] = vo_rescale_uint8[2]; + ret_ptr[3] = vo_rescale_uint8[3]; + } else { + ret_ptr[0] = vo[0]; + ret_ptr[1] = vo[1]; + ret_ptr[2] = vo[2]; + 
ret_ptr[3] = vo[3]; + } + } else { // h_block_len < 4 + // TODO(liyin): handle here case by case (1,2,3) to accelerate + const uint8_t *tmp_lhs_ptr = lhs_ptr; + const uint8_t *tmp_rhs_ptr = rhs_ptr; + for (index_t h = 0; h < h_block_len; ++h) { + lhs_ptr = tmp_lhs_ptr + h * lhs_width; + rhs_ptr = tmp_rhs_ptr; + int32x4_t vo0 = vdupq_n_s32(0); + for (index_t w = 0; w < w_block_count; ++w) { + uint8x8_t vr0 = vld1_u8(rhs_ptr); + int16x8_t + vxr0 = vreinterpretq_s16_u16(vsubl_u8(vr0, vrhs_zero_point)); + uint8x8_t vr0n = vld1_u8(rhs_ptr + 8); + int16x8_t + vxr0n = vreinterpretq_s16_u16(vsubl_u8(vr0n, vrhs_zero_point)); + + uint8x8_t vl0 = vld1_u8(lhs_ptr); + int16x8_t + vxl0 = vreinterpretq_s16_u16(vsubl_u8(vl0, vlhs_zero_point)); + uint8x8_t vl0n = vld1_u8(lhs_ptr + 8); + int16x8_t + vxl0n = vreinterpretq_s16_u16(vsubl_u8(vl0n, vlhs_zero_point)); + + vo0 = vmlal_s16(vo0, vget_low_s16(vxl0), vget_low_s16(vxr0)); + vo0 = vmlal_high_s16(vo0, vxl0, vxr0); + vo0 = vmlal_s16(vo0, vget_low_s16(vxl0n), vget_low_s16(vxr0n)); + vo0 = vmlal_high_s16(vo0, vxl0n, vxr0n); + + lhs_ptr += 16; + rhs_ptr += 16; + } // w + int32_t s0 = vaddvq_s32(vo0) + (bias ? bias_data[h_offset + h] : 0); + for (index_t w = 0; w < w_remain; ++w) { + s0 += (lhs_ptr[0] - lhs_zero_point) * (rhs_ptr[0] - rhs_zero_point); + ++lhs_ptr; + ++rhs_ptr; + } // w + + if (is_output_type_uint8) { + ret_ptr[h] = + Saturate(std::roundf(s0 * output_multiplier_float)); + } else { + ret_ptr[h] = s0; + } + } // h + } // if + } // h_block_idx + } // b + + return MaceStatus::MACE_SUCCESS; +} + +template +class Gemv; +template +class Gemv; + +} // namespace q8 +} // namespace arm +} // namespace ops +} // namespace mace + +#if defined(vmlal_high_s16) +#undef vmlal_high_s16 +#undef vaddvq_s32 +#endif diff --git a/mace/ops/arm/q8/gemv.h b/mace/ops/arm/q8/gemv.h new file mode 100644 index 0000000000000000000000000000000000000000..2269ec98cc7e05b845b7f7af949d1133afe27414 --- /dev/null +++ b/mace/ops/arm/q8/gemv.h @@ -0,0 +1,50 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#ifndef MACE_OPS_ARM_Q8_GEMV_H_
+#define MACE_OPS_ARM_Q8_GEMV_H_
+
+#include "mace/public/mace.h"
+#include "mace/core/tensor.h"
+#include "mace/core/op_context.h"
+
+namespace mace {
+namespace ops {
+namespace arm {
+namespace q8 {
+
+template<typename OUTPUT_TYPE>
+class Gemv {
+ public:
+  Gemv() {}
+  ~Gemv() {}
+  // Always row-major after transpose
+  MaceStatus Compute(
+      const OpContext *context,
+      const Tensor *lhs,
+      const Tensor *rhs,
+      const Tensor *bias,
+      const index_t batch,
+      const index_t lhs_height,
+      const index_t lhs_width,
+      const bool lhs_batched,
+      Tensor *output);
+};
+
+}  // namespace q8
+}  // namespace arm
+}  // namespace ops
+}  // namespace mace
+
+#endif  // MACE_OPS_ARM_Q8_GEMV_H_
diff --git a/mace/ops/arm/q8/gemv_test.cc b/mace/ops/arm/q8/gemv_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..10cab2165b8e92775ce9e6a7edbd49f315f15f37
--- /dev/null
+++ b/mace/ops/arm/q8/gemv_test.cc
@@ -0,0 +1,183 @@
+// Copyright 2019 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include <gtest/gtest.h>
+
+#include "mace/core/tensor.h"
+#include "mace/core/op_context.h"
+#include "mace/ops/arm/q8/gemv.h"
+#include "mace/ops/ref/gemv.h"
+#include "mace/ops/testing/test_utils.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+void TestGemvInt32(const index_t batch,
+                   const index_t height,
+                   const index_t width,
+                   const bool lhs_batched) {
+  Tensor lhs(GetCPUAllocator(), DataType::DT_UINT8);
+  Tensor rhs(GetCPUAllocator(), DataType::DT_UINT8);
+  Tensor bias(GetCPUAllocator(), DataType::DT_INT32);
+  Tensor output(GetCPUAllocator(), DataType::DT_INT32);
+  lhs.SetScale(0.5);
+  rhs.SetScale(0.3);
+  lhs.SetZeroPoint(0);
+  rhs.SetZeroPoint(0);
+  lhs.Resize({lhs_batched ? batch : 1, height, width});
+  rhs.Resize({batch, width});
+  bias.Resize({height});
+  output.Resize({batch, height});
+  {
+    Tensor::MappingGuard lhs_guard(&lhs);
+    Tensor::MappingGuard rhs_guard(&rhs);
+    Tensor::MappingGuard bias_guard(&bias);
+    uint8_t *lhs_data = lhs.mutable_data<uint8_t>();
+    uint8_t *rhs_data = rhs.mutable_data<uint8_t>();
+    int32_t *bias_data = bias.mutable_data<int32_t>();
+    GenerateRandomIntTypeData(lhs.shape(), lhs_data);
+    GenerateRandomIntTypeData(rhs.shape(), rhs_data);
+    GenerateRandomIntTypeData(bias.shape(), bias_data);
+  }
+
+  mace::ops::arm::q8::Gemv<int32_t> gemv;
+  gemv.Compute(nullptr,
+               &lhs,
+               &rhs,
+               &bias,
+               batch,
+               height,
+               width,
+               lhs_batched,
+               &output);
+
+  Tensor expected_output(GetCPUAllocator(), DataType::DT_INT32);
+  expected_output.Resize({batch, height});
+  mace::ops::ref::Gemv<int32_t> gemv_ref;
+  gemv_ref.Compute(nullptr,
+                   &lhs,
+                   &rhs,
+                   &bias,
+                   batch,
+                   height,
+                   width,
+                   lhs_batched,
+                   &expected_output);
+
+  Tensor::MappingGuard output_guard(&output);
+  Tensor::MappingGuard expected_guard(&expected_output);
+  const int32_t *output_data = output.data<int32_t>();
+  const int32_t *expected_data = expected_output.data<int32_t>();
+
+  for (index_t i = 0; i < output.size(); ++i) {
+    EXPECT_EQ(expected_data[i], output_data[i]);
+  }
+}
+
+void TestGemvUint8(const index_t batch,
+                   const index_t height,
+                   const index_t width,
+                   const bool lhs_batched) {
+  Tensor lhs(GetCPUAllocator(), DataType::DT_UINT8);
+  Tensor rhs(GetCPUAllocator(), DataType::DT_UINT8);
+  Tensor bias(GetCPUAllocator(), DataType::DT_INT32);
+  Tensor output(GetCPUAllocator(), DataType::DT_UINT8);
+  lhs.SetScale(0.5);
+  rhs.SetScale(0.3);
+  output.SetScale(0.6);
+  lhs.SetZeroPoint(23);
+  rhs.SetZeroPoint(45);
+  output.SetZeroPoint(57);
+  lhs.Resize({batch, height, width});
+  rhs.Resize({batch, width});
+  bias.Resize({height});
+  output.Resize({batch, height});
+  {
+    Tensor::MappingGuard lhs_guard(&lhs);
+    Tensor::MappingGuard rhs_guard(&rhs);
+    Tensor::MappingGuard bias_guard(&bias);
+
+    uint8_t *lhs_data = lhs.mutable_data<uint8_t>();
+    uint8_t *rhs_data = rhs.mutable_data<uint8_t>();
+    int32_t *bias_data = bias.mutable_data<int32_t>();
+    GenerateRandomIntTypeData(lhs.shape(), lhs_data);
+    GenerateRandomIntTypeData(rhs.shape(), rhs_data);
+    GenerateRandomIntTypeData(bias.shape(), bias_data);
+  }
+
+  mace::ops::arm::q8::Gemv<uint8_t> gemv;
+  gemv.Compute(nullptr,
+               &lhs,
+               &rhs,
+               &bias,
+               batch,
+               height,
+               width,
+               lhs_batched,
+               &output);
+
+  Tensor expected_output(GetCPUAllocator(), DataType::DT_INT32);
+  expected_output.SetScale(0.6);
+  expected_output.SetZeroPoint(57);
+  expected_output.Resize({batch, height});
+  mace::ops::ref::Gemv<uint8_t> gemv_ref;
+  gemv_ref.Compute(nullptr,
+                   &lhs,
+                   &rhs,
+                   &bias,
+                   batch,
+                   height,
+                   width,
+                   lhs_batched,
+                   &expected_output);
+
+  Tensor::MappingGuard output_guard(&output);
+  Tensor::MappingGuard expected_guard(&expected_output);
+  const uint8_t *output_data = output.data<uint8_t>();
+  const uint8_t *expected_data = expected_output.data<uint8_t>();
+
+  for (index_t i = 0; i < output.size(); ++i) {
+    EXPECT_EQ(expected_data[i], output_data[i]);
+  }
+}
+
+TEST(ArmGemv, TestGemvInt32) {
+  TestGemvInt32(1, 16, 4, true);
+  TestGemvInt32(1, 16, 256, true);
+  TestGemvInt32(2, 16, 256, true);
+  TestGemvInt32(3, 63, 257, true);
+
+  TestGemvInt32(1, 16, 4, false);
+  TestGemvInt32(1, 16, 256, false);
+  TestGemvInt32(2, 16, 256, false);
+  TestGemvInt32(3, 63, 257, false);
+}
+
+TEST(ArmGemv, TestGemvUint8) {
+  TestGemvUint8(1, 16, 4, true);
+  TestGemvUint8(1, 16, 256, true);
+  TestGemvUint8(2, 16, 256, true);
+  TestGemvUint8(3, 63, 257, true);
+
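+  // lhs_batched=false: a single LHS matrix is shared across the whole batch,
+  // exercising the non-batched indexing path in Gemv::Compute.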
TestGemvUint8(1, 16, 4, false); + TestGemvUint8(1, 16, 256, false); + TestGemvUint8(2, 16, 256, false); + TestGemvUint8(3, 63, 257, false); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/common/activation_type.h b/mace/ops/common/activation_type.h new file mode 100644 index 0000000000000000000000000000000000000000..de8f6e8b7cef4697c61749edcb88039c0f788667 --- /dev/null +++ b/mace/ops/common/activation_type.h @@ -0,0 +1,34 @@ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_OPS_COMMON_ACTIVATION_TYPE_H_ +#define MACE_OPS_COMMON_ACTIVATION_TYPE_H_ + +namespace mace { +namespace ops { + +enum ActivationType { + NOOP = 0, + RELU = 1, + RELUX = 2, + PRELU = 3, + TANH = 4, + SIGMOID = 5, + LEAKYRELU = 6, +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_COMMON_ACTIVATION_TYPE_H_ diff --git a/mace/ops/conv_pool_2d_util.cc b/mace/ops/common/conv_pool_2d_util.cc similarity index 99% rename from mace/ops/conv_pool_2d_util.cc rename to mace/ops/common/conv_pool_2d_util.cc index 92d88d446e52d6870a625c97a9af8942a38b1b42..8634cf2cb8333d03a97b131692c84d5f5249cab5 100644 --- a/mace/ops/conv_pool_2d_util.cc +++ b/mace/ops/common/conv_pool_2d_util.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "mace/ops/conv_pool_2d_util.h" +#include "mace/ops/common/conv_pool_2d_util.h" #include #include diff --git a/mace/ops/conv_pool_2d_util.h b/mace/ops/common/conv_pool_2d_util.h similarity index 97% rename from mace/ops/conv_pool_2d_util.h rename to mace/ops/common/conv_pool_2d_util.h index a644f5f355e6bca654c71c7c7febc57b5c280e39..db359ee92b02a88c48555ada851047f3ebe7f2e5 100644 --- a/mace/ops/conv_pool_2d_util.h +++ b/mace/ops/common/conv_pool_2d_util.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef MACE_OPS_CONV_POOL_2D_UTIL_H_ -#define MACE_OPS_CONV_POOL_2D_UTIL_H_ +#ifndef MACE_OPS_COMMON_CONV_POOL_2D_UTIL_H_ +#define MACE_OPS_COMMON_CONV_POOL_2D_UTIL_H_ #include "mace/core/tensor.h" @@ -116,4 +116,4 @@ MaceStatus ConstructNHWCInputWithPadding(const Tensor *input, } // namespace ops } // namespace mace -#endif // MACE_OPS_CONV_POOL_2D_UTIL_H_ +#endif // MACE_OPS_COMMON_CONV_POOL_2D_UTIL_H_ diff --git a/mace/ops/common/transpose.cc b/mace/ops/common/transpose.cc new file mode 100644 index 0000000000000000000000000000000000000000..469456a1c4424445ba836261c0f9bd71db878155 --- /dev/null +++ b/mace/ops/common/transpose.cc @@ -0,0 +1,218 @@ +// Copyright 2019 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/common/transpose.h" + +#include + +#if defined(MACE_ENABLE_NEON) +#include +#endif + +#include "mace/core/types.h" +#include "mace/utils/logging.h" + +namespace mace { +namespace ops { + +namespace { +void TransposeNHWCToNCHWC3(const float *input, + float *output, + const index_t height, + const index_t width) { + index_t image_size = height * width; + +#pragma omp parallel for + for (index_t h = 0; h < height; ++h) { + index_t in_offset = h * width * 3; + index_t out_offset = h * width; + +#if defined(MACE_ENABLE_NEON) + index_t w; + for (w = 0; w + 3 < width; w += 4) { + float32x4x3_t vi = vld3q_f32(input + in_offset); + vst1q_f32(output + out_offset, vi.val[0]); + vst1q_f32(output + out_offset + image_size, vi.val[1]); + vst1q_f32(output + out_offset + image_size * 2, vi.val[2]); + + in_offset += 12; + out_offset += 4; + } + for (; w < width; ++w) { + for (index_t c = 0; c < 3; ++c) { + output[h * width + image_size * c + w] = + input[h * width * 3 + w * 3 + c]; + } + } +#else + for (index_t w = 0; w < width; ++w) { + for (index_t c = 0; c < 3; ++c) { + output[out_offset + c * image_size + w] = input[in_offset + w * 3 + c]; + } + } +#endif + } +} + +void TransposeNCHWToNHWCC2(const float *input, + float *output, + const index_t height, + const index_t width) { + index_t image_size = height * width; +#pragma omp parallel for + for (index_t h = 0; h < height; ++h) { + index_t in_offset = h * width; + index_t out_offset = h * width * 2; + +#if defined(MACE_ENABLE_NEON) + index_t w; + for (w = 0; w + 3 < width; w += 4) { + float32x4_t vi0 = vld1q_f32(input + in_offset); + float32x4_t vi1 = vld1q_f32(input + in_offset + image_size); + float32x4x2_t vi = {vi0, vi1}; + vst2q_f32(output + out_offset, vi); + in_offset += 4; + out_offset += 8; + } + for (; w < width; ++w) { + for (index_t c = 0; c < 2; ++c) { + output[h * width * 2 + w * 2 + c] = + input[h * width + image_size * c + w]; + } + } +#else + for (index_t w = 0; w < width; ++w) { + for (index_t c = 0; c < 2; ++c) { + output[out_offset + w * 2 + c] = input[in_offset + c * image_size + w]; + } + } +#endif + } +} +} // namespace + +MaceStatus Transpose(const float *input, + const std::vector &input_shape, + const std::vector &dst_dims, + float *output) { + MACE_CHECK((input_shape.size() == 2 && dst_dims.size() == 2) || + (input_shape.size() == 4 && dst_dims.size() == 4), + "Only support 2D or 4D transpose"); + + std::vector output_shape; + for (size_t i = 0; i < dst_dims.size(); ++i) { + output_shape.push_back(input_shape[dst_dims[i]]); + } + + if (input_shape.size() == 2) { + MACE_CHECK(dst_dims[0] == 1 && dst_dims[1] == 0, "no need transform"); + index_t height = input_shape[0]; + index_t width = input_shape[1]; + index_t stride_i = height; + index_t stride_j = width; + index_t tile_size = height > 512 || width > 512 ? 
64 : 32; +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < height; i += tile_size) { + for (index_t j = 0; j < width; j += tile_size) { + index_t end_i = std::min(i + tile_size, height); + index_t end_j = std::min(j + tile_size, width); + for (index_t tile_i = i; tile_i < end_i; ++tile_i) { + for (index_t tile_j = j; tile_j < end_j; ++tile_j) { + output[tile_j * stride_i + tile_i] = + input[tile_i * stride_j + tile_j]; + } + } + } + } + } else if (input_shape.size() == 4) { + std::vector transpose_order_from_NHWC_to_NCHW{0, 3, 1, 2}; + std::vector transpose_order_from_NCHW_to_NHWC{0, 2, 3, 1}; + index_t batch_size = input_shape[1] * input_shape[2] * input_shape[3]; + + if (dst_dims == transpose_order_from_NHWC_to_NCHW && input_shape[3] == 3) { + for (index_t b = 0; b < input_shape[0]; ++b) { + TransposeNHWCToNCHWC3(input + b * batch_size, + output + b * batch_size, + input_shape[1], + input_shape[2]); + } + } else if (dst_dims == transpose_order_from_NCHW_to_NHWC + && input_shape[1] == 2) { + for (index_t b = 0; b < input_shape[0]; ++b) { + TransposeNCHWToNHWCC2(input + b * batch_size, + output + b * batch_size, + input_shape[2], + input_shape[3]); + } + } else if (dst_dims == std::vector{0, 2, 1, 3}) { + index_t height = input_shape[1]; + index_t width = input_shape[2]; + index_t channel = input_shape[3]; + index_t channel_raw_size = channel * sizeof(float); + index_t stride_i = height; + index_t stride_j = width; + index_t tile_size = std::max(static_cast(1), + static_cast(std::sqrt( + 8 * 1024 / channel))); +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < height; i += tile_size) { + for (index_t j = 0; j < width; j += tile_size) { + index_t end_i = std::min(i + tile_size, height); + index_t end_j = std::min(j + tile_size, width); + for (index_t tile_i = i; tile_i < end_i; ++tile_i) { + for (index_t tile_j = j; tile_j < end_j; ++tile_j) { + memcpy(output + (tile_j * stride_i + tile_i) * channel, + input + (tile_i * stride_j + tile_j) * channel, + channel_raw_size); + } + } + } + } + } else { + std::vector + in_stride{input_shape[1] * input_shape[2] * input_shape[3], + input_shape[2] * input_shape[3], input_shape[3], 1}; + std::vector + out_stride{output_shape[1] * output_shape[2] * output_shape[3], + output_shape[2] * output_shape[3], output_shape[3], 1}; + + std::vector idim(4, 0); + std::vector odim(4, 0); + for (odim[0] = 0; odim[0] < output_shape[0]; ++odim[0]) { + for (odim[1] = 0; odim[1] < output_shape[1]; ++odim[1]) { + for (odim[2] = 0; odim[2] < output_shape[2]; ++odim[2]) { + for (odim[3] = 0; odim[3] < output_shape[3]; ++odim[3]) { + idim[dst_dims[0]] = odim[0]; + idim[dst_dims[1]] = odim[1]; + idim[dst_dims[2]] = odim[2]; + idim[dst_dims[3]] = odim[3]; + + output[odim[0] * out_stride[0] + odim[1] * out_stride[1] + + odim[2] * out_stride[2] + odim[3]] = + input[idim[0] * in_stride[0] + idim[1] * in_stride[1] + + idim[2] * in_stride[2] + idim[3]]; + } + } + } + } + } + } else { + MACE_NOT_IMPLEMENTED; + } + + return MaceStatus::MACE_SUCCESS; +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/transpose.h b/mace/ops/common/transpose.h similarity index 89% rename from mace/ops/transpose.h rename to mace/ops/common/transpose.h index 9ab09d1753577b3ee1a67c01034dedf1ad8723d6..5f8e23490698ab71439d2486bc32d269e8d5ee0b 100644 --- a/mace/ops/transpose.h +++ b/mace/ops/common/transpose.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef MACE_OPS_TRANSPOSE_H_
-#define MACE_OPS_TRANSPOSE_H_
+#ifndef MACE_OPS_COMMON_TRANSPOSE_H_
+#define MACE_OPS_COMMON_TRANSPOSE_H_
 
 #include <vector>
 
@@ -30,4 +30,4 @@ MaceStatus Transpose(const float *input,
 }  // namespace ops
 }  // namespace mace
 
-#endif  // MACE_OPS_TRANSPOSE_H_
+#endif  // MACE_OPS_COMMON_TRANSPOSE_H_
diff --git a/mace/ops/conv_2d.cc b/mace/ops/conv_2d.cc
index 0cfcf64f9211c602dd66ca61f1e3da4ab45d39b7..d3c0697378b6d7a098c6ebbcc2888fe9f8d9e668 100644
--- a/mace/ops/conv_2d.cc
+++ b/mace/ops/conv_2d.cc
@@ -30,7 +30,7 @@
 #include "mace/ops/arm/conv_2d_neon.h"
 #include "mace/ops/arm/conv_winograd.h"
 #include "mace/ops/conv_pool_2d_base.h"
-#include "mace/ops/conv_pool_2d_util.h"
+#include "mace/ops/common/conv_pool_2d_util.h"
 #include "mace/utils/utils.h"
 
 #ifdef MACE_ENABLE_QUANTIZE
diff --git a/mace/ops/conv_2d_benchmark.cc b/mace/ops/conv_2d_benchmark.cc
index 1cfe048dabdee5be64d726aa9cb1c87c02385a99..a930ea909d507ca4b48c21ec8420be05fb617092 100644
--- a/mace/ops/conv_2d_benchmark.cc
+++ b/mace/ops/conv_2d_benchmark.cc
@@ -16,7 +16,7 @@
 
 #include "mace/benchmark/statistics.h"
 #include "mace/core/testing/test_benchmark.h"
-#include "mace/ops/conv_pool_2d_util.h"
+#include "mace/ops/common/conv_pool_2d_util.h"
 #include "mace/ops/ops_test_util.h"
 
 namespace mace {
diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc
index 6965ca121d92c2f9609b7081e85695b9a0a0a42a..59eb532bcb241fd5a484766077384a1e771ef721 100644
--- a/mace/ops/conv_2d_test.cc
+++ b/mace/ops/conv_2d_test.cc
@@ -15,7 +15,7 @@
 #include <fstream>
 #include <vector>
 
-#include "mace/ops/conv_pool_2d_util.h"
+#include "mace/ops/common/conv_pool_2d_util.h"
 #include "mace/ops/ops_test_util.h"
 
 namespace mace {
diff --git a/mace/ops/conv_pool_2d_base.h b/mace/ops/conv_pool_2d_base.h
index 99fe1edf4072e9455a2d026eaadfcf60ec2a3d79..b5ad48aea307a138fbbea234b6f44465055817c4 100644
--- a/mace/ops/conv_pool_2d_base.h
+++ b/mace/ops/conv_pool_2d_base.h
@@ -18,7 +18,7 @@
 #include <vector>
 
 #include "mace/core/operator.h"
-#include "mace/ops/conv_pool_2d_util.h"
+#include "mace/ops/common/conv_pool_2d_util.h"
 
 namespace mace {
 namespace ops {
diff --git a/mace/ops/deconv_2d.h b/mace/ops/deconv_2d.h
index 9da5213ff1fe1b1292a1f0d27b4735d861fe8a84..008c6a5b5ea2cb9cc14c7c40940206e81c4f7aed 100644
--- a/mace/ops/deconv_2d.h
+++ b/mace/ops/deconv_2d.h
@@ -22,7 +22,7 @@
 #include "mace/core/operator.h"
 #include "mace/core/types.h"
 #include "mace/ops/activation.h"
-#include "mace/ops/conv_pool_2d_util.h"
+#include "mace/ops/common/conv_pool_2d_util.h"
 
 namespace mace {
 namespace ops {
diff --git a/mace/ops/deconv_2d_benchmark.cc b/mace/ops/deconv_2d_benchmark.cc
index 2c0f3018e6340f4f6d00bfb1621fdf5e708e5484..0144bc595c04ab7decd2bd543846b8b575f4c55c 100644
--- a/mace/ops/deconv_2d_benchmark.cc
+++ b/mace/ops/deconv_2d_benchmark.cc
@@ -16,7 +16,7 @@
 
 #include "mace/benchmark/statistics.h"
 #include "mace/core/testing/test_benchmark.h"
-#include "mace/ops/conv_pool_2d_util.h"
+#include "mace/ops/common/conv_pool_2d_util.h"
 #include "mace/ops/ops_test_util.h"
 
 namespace mace {
diff --git a/mace/ops/deconv_2d_test.cc b/mace/ops/deconv_2d_test.cc
index eb2719d2b90c42472cc52c1617e3ad1945ef9c6f..d8a1c621a49656a845319e1c849b9037e618fec4 100644
--- a/mace/ops/deconv_2d_test.cc
+++ b/mace/ops/deconv_2d_test.cc
@@ -16,7 +16,7 @@
 #include <vector>
 
 #include "mace/ops/deconv_2d.h"
-#include "mace/ops/conv_pool_2d_util.h"
+#include "mace/ops/common/conv_pool_2d_util.h"
 #include "mace/ops/ops_test_util.h"
 
 namespace mace {
diff --git a/mace/ops/depthwise_conv2d_benchmark.cc b/mace/ops/depthwise_conv2d_benchmark.cc
index d159e90b4f8c2c9f6077b57ea5c02485024b9051..f0adb412fe7afcc86b848963566e193553160e9b 100644
--- a/mace/ops/depthwise_conv2d_benchmark.cc
+++ b/mace/ops/depthwise_conv2d_benchmark.cc
@@ -16,7 +16,7 @@
 
 #include "mace/benchmark/statistics.h"
 #include "mace/core/testing/test_benchmark.h"
-#include "mace/ops/conv_pool_2d_util.h"
+#include "mace/ops/common/conv_pool_2d_util.h"
 #include "mace/ops/ops_test_util.h"
 
 namespace mace {
diff --git a/mace/ops/depthwise_conv2d_test.cc b/mace/ops/depthwise_conv2d_test.cc
index 321d6456dcff263d4433d1b7d4c1db909c6fa34d..72a50f24ce868da3ab5344062e3fa5ebeefbda2f 100644
--- a/mace/ops/depthwise_conv2d_test.cc
+++ b/mace/ops/depthwise_conv2d_test.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "mace/ops/conv_pool_2d_util.h"
+#include "mace/ops/common/conv_pool_2d_util.h"
 #include "mace/ops/ops_test_util.h"
 
 namespace mace {
diff --git a/mace/ops/fully_connected.cc b/mace/ops/fully_connected.cc
index 352d3e567c8be8c59db5ba497b6bb15b49a98654..185d627897eb7c23018f3d19caa49ab80da9c4f7 100644
--- a/mace/ops/fully_connected.cc
+++ b/mace/ops/fully_connected.cc
@@ -20,13 +20,19 @@
 #include "mace/core/operator.h"
 #include "mace/core/tensor.h"
 #include "mace/ops/activation.h"
-#include "mace/ops/gemm.h"
+
+#ifdef MACE_ENABLE_NEON
+
+#include "mace/ops/arm/fp32/gemv.h"
 
 #ifdef MACE_ENABLE_QUANTIZE
-#include "mace/ops/gemmlowp_util.h"
-#include "mace/ops/quantization_util.h"
+#include "mace/ops/arm/q8/gemv.h"
 #endif  // MACE_ENABLE_QUANTIZE
 
+#else
+#include "mace/ops/ref/gemv.h"
+#endif  // MACE_ENABLE_NEON
+
 #ifdef MACE_ENABLE_OPENCL
 #include "mace/ops/opencl/buffer_transformer.h"
 #include "mace/ops/opencl/image/fully_connected.h"
@@ -41,10 +47,10 @@ class FullyConnectedOpBase : public Operation {
       : Operation(context),
         activation_(ops::StringToActivationType(
             Operation::GetOptionalArg<std::string>("activation",
-                                                   "NOOP"))),
+                                                       "NOOP"))),
         relux_max_limit_(Operation::GetOptionalArg<float>("max_limit", 0.0f)),
         leakyrelu_coefficient_(Operation::GetOptionalArg<float>(
-            "leakyrelu_coefficient", 0.0f)) {}
+        "leakyrelu_coefficient", 0.0f)) {}
  protected:
   const ActivationType activation_;
   const float relux_max_limit_;
@@ -54,10 +60,10 @@ class FullyConnectedOpBase : public Operation {
   MACE_OP_OUTPUT_TAGS(OUTPUT);
 };
 
-template <DeviceType D, class T>
+template<DeviceType D, class T>
 class FullyConnectedOp;
 
-template <>
+template<>
 class FullyConnectedOp<DeviceType::CPU, float> : public FullyConnectedOpBase {
  public:
   explicit FullyConnectedOp(OpConstructContext *context)
@@ -84,38 +90,37 @@ class FullyConnectedOp<DeviceType::CPU, float> : public FullyConnectedOpBase {
     }
     std::vector<index_t> output_shape = {input->dim(0), weight->dim(0), 1, 1};
     MACE_RETURN_IF_ERROR(output->Resize(output_shape));
-    const index_t N = output->dim(0);
+    const index_t batch = output->dim(0);
     const index_t input_size =
         weight->dim(1) * weight->dim(2) * weight->dim(3);
     const index_t output_size = weight->dim(0);
-    Tensor::MappingGuard guard_input(input);
-    Tensor::MappingGuard guard_weight(weight);
+    gemv_.Compute(context,
+                  weight,
+                  input,
+                  bias,
+                  batch,
+                  output_size,
+                  input_size,
+                  false,
+                  output);
     Tensor::MappingGuard guard_output(output);
-    const float *input_ptr = input->data<float>();
-    const float *weight_ptr = weight->data<float>();
     float *output_ptr = output->mutable_data<float>();
-
-    Gemv(weight_ptr, input_ptr, N, input_size, output_size, output_ptr);
-
-    if (bias) {
-      Tensor::MappingGuard guard_bias(bias);
-      const float *bias_ptr = bias == nullptr ? nullptr : bias->data<float>();
-      for (int i = 0; i < N; ++i) {
-        for (int j = 0; j < output_size; ++j) {
-          output_ptr[j + i * output_size] += bias_ptr[j];
-        }
-      }
-    }
-
     DoActivation(output_ptr, output_ptr, output->size(), activation_,
                  relux_max_limit_, leakyrelu_coefficient_);
 
     return MaceStatus::MACE_SUCCESS;
   }
+
+ private:
+#ifdef MACE_ENABLE_NEON
+  arm::fp32::Gemv gemv_;
+#else
+  ref::Gemv<float> gemv_;
+#endif  // MACE_ENABLE_NEON
 };
 
 #ifdef MACE_ENABLE_QUANTIZE
-template <>
+template<>
 class FullyConnectedOp<DeviceType::CPU, uint8_t>
     : public FullyConnectedOpBase {
  public:
@@ -145,44 +150,28 @@ class FullyConnectedOp<DeviceType::CPU, uint8_t>
     std::vector<index_t> output_shape = {input->dim(0), 1, 1, weight->dim(0)};
     MACE_RETURN_IF_ERROR(output->Resize(output_shape));
-    const int N = static_cast<int>(output->dim(0));
+    const int batch = static_cast<int>(output->dim(0));
     const int input_size =
         static_cast<int>(weight->dim(1) * weight->dim(2) * weight->dim(3));
     const int output_size = static_cast<int>(weight->dim(0));
-
-    Tensor::MappingGuard guard_input(input);
-    Tensor::MappingGuard guard_weight(weight);
-    Tensor::MappingGuard guard_output(output);
-    auto input_ptr = input->data<uint8_t>();
-    auto weight_ptr = weight->data<uint8_t>();
-    auto output_ptr = output->mutable_data<uint8_t>();
-    auto bias_ptr = GetBiasData(bias,
-                                input->scale(),
-                                weight->scale(),
-                                output_size,
-                                &bias_);
-
-    gemmlowp::MatrixMap<const uint8_t, gemmlowp::MapOrder::RowMajor>
-        weight_matrix(weight_ptr, output_size, input_size);
-    gemmlowp::MatrixMap<const uint8_t, gemmlowp::MapOrder::ColMajor>
-        input_matrix(input_ptr, input_size, N);
-    gemmlowp::MatrixMap<uint8_t, gemmlowp::MapOrder::ColMajor>
-        output_matrix(output_ptr, output_size, N);
-
-    const auto &output_pipeline = GemmlowpOutputPipeline::Make(
-        bias_ptr, output_size, weight->scale(), input->scale(), output->scale(),
-        output->zero_point());
-
-    using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams;
-    gemmlowp::GemmWithOutputPipeline<uint8_t, uint8_t, BitDepthParams>(
-        gemm_context, weight_matrix, input_matrix, &output_matrix,
-        -weight->zero_point(), -input->zero_point(), output_pipeline);
-
+    gemv_.Compute(context,
+                  weight,
+                  input,
+                  bias,
+                  batch,
+                  output_size,
+                  input_size,
+                  false,
+                  output);
     return MaceStatus::MACE_SUCCESS;
   }
 
  private:
-  std::vector<int32_t> bias_;
+#ifdef MACE_ENABLE_NEON
+  ::mace::ops::arm::q8::Gemv<uint8_t> gemv_;
+#else
+  ref::Gemv<uint8_t> gemv_;
+#endif  // MACE_ENABLE_NEON
 };
 
 #endif  // MACE_ENABLE_QUANTIZE
diff --git a/mace/ops/gemm.cc b/mace/ops/gemm.cc
deleted file mode 100644
index b24ccf95f235a5716a4f7466c6f0695fdc95f085..0000000000000000000000000000000000000000
--- a/mace/ops/gemm.cc
+++ /dev/null
@@ -1,1544 +0,0 @@
-// Copyright 2018 The MACE Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <algorithm>
-#include <cstring>
-#include <vector>
-
-#include "mace/core/tensor.h"
-#include "mace/core/runtime/cpu/cpu_runtime.h"
-#include "mace/ops/gemm.h"
-
-/**
- * Gemm does fast matrix multiplications with batch.
- * It is optimized for arm64-v8 and armeabi-v7a using neon.
- *
- * We adopt two-level tiling to make better use of l1 cache and register.
- * For register tiling, function like GemmXYZ computes gemm for
- * matrix[X, Y] * matrix[Y, Z] with all data being able to fit in register.
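
With gemm.cc removed, FullyConnectedOp above no longer adds the bias by hand: the Gemv kernels fold bias and batching into a single call. A reference rendering of the float path's contract, with illustrative names (the real kernels live under mace/ops/arm/fp32 and mace/ops/ref):

// out[b, o] = dot(weight[o, :], input[b, :]) + bias[o] -- the semantics that
// the gemv_.Compute(...) call above delegates to (sketch only).
void GemvRefSketch(const float *weight,  // [out_size, in_size], row-major
                   const float *input,   // [batch, in_size]
                   const float *bias,    // [out_size], may be null
                   int batch, int out_size, int in_size,
                   float *output) {      // [batch, out_size]
  for (int b = 0; b < batch; ++b) {
    for (int o = 0; o < out_size; ++o) {
      float acc = bias != nullptr ? bias[o] : 0.f;
      for (int i = 0; i < in_size; ++i) {
        acc += weight[o * in_size + i] * input[b * in_size + i];
      }
      output[b * out_size + o] = acc;
    }
  }
}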
- * For cache tiling, we try to compute one block of multiplication with - * two input matrices and one output matrix fit in l1 cache. - */ - -#if defined(MACE_ENABLE_NEON) -#include -#endif - -#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__) -#define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3]) -#endif - -namespace mace { -namespace ops { - -namespace { -inline void GemmBlock(const float *A, - const float *B, - const index_t height, - const index_t K, - const index_t width, - const index_t stride_a, - const index_t stride_b, - const index_t stride_c, - float *C) { - for (int i = 0; i < height; ++i) { - for (int j = 0; j < width; ++j) { - for (int k = 0; k < K; ++k) { - C[i * stride_c + j] += A[i * stride_a + k] * B[k * stride_b + j]; - } - } - } -} - -#if defined(MACE_ENABLE_NEON) -#if defined(__aarch64__) -#define MACE_GEMM_PART_CAL_8(RC, RA, RAN) \ - c##RC = vfmaq_laneq_f32(c##RC, b0, a##RA, 0); \ - c##RC = vfmaq_laneq_f32(c##RC, b1, a##RA, 1); \ - c##RC = vfmaq_laneq_f32(c##RC, b2, a##RA, 2); \ - c##RC = vfmaq_laneq_f32(c##RC, b3, a##RA, 3); \ - c##RC = vfmaq_laneq_f32(c##RC, b4, a##RAN, 0); \ - c##RC = vfmaq_laneq_f32(c##RC, b5, a##RAN, 1); \ - c##RC = vfmaq_laneq_f32(c##RC, b6, a##RAN, 2); \ - c##RC = vfmaq_laneq_f32(c##RC, b7, a##RAN, 3); -#else -#define MACE_GEMM_PART_CAL_8(RC, RA, RAN) \ - c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RA), 0); \ - c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RA), 1); \ - c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RA), 0); \ - c##RC = vmlaq_lane_f32(c##RC, b3, vget_high_f32(a##RA), 1); \ - c##RC = vmlaq_lane_f32(c##RC, b4, vget_low_f32(a##RAN), 0); \ - c##RC = vmlaq_lane_f32(c##RC, b5, vget_low_f32(a##RAN), 1); \ - c##RC = vmlaq_lane_f32(c##RC, b6, vget_high_f32(a##RAN), 0); \ - c##RC = vmlaq_lane_f32(c##RC, b7, vget_high_f32(a##RAN), 1); -#endif -#endif - -#if defined(MACE_ENABLE_NEON) -#if defined(__aarch64__) -#define MACE_GEMM_PART_CAL_4(RC) \ - c##RC = vfmaq_laneq_f32(c##RC, b0, a##RC, 0); \ - c##RC = vfmaq_laneq_f32(c##RC, b1, a##RC, 1); \ - c##RC = vfmaq_laneq_f32(c##RC, b2, a##RC, 2); \ - c##RC = vfmaq_laneq_f32(c##RC, b3, a##RC, 3); -#else -#define MACE_GEMM_PART_CAL_4(RC) \ - c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RC), 0); \ - c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RC), 1); \ - c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RC), 0); \ - c##RC = vmlaq_lane_f32(c##RC, b3, vget_high_f32(a##RC), 1); -#endif -#endif - -inline void Gemm144(const float *a_ptr, - const float *b_ptr, - const index_t stride_a, - const index_t stride_b, - const index_t stride_c, - float *c_ptr) { -#if defined(MACE_ENABLE_NEON) - MACE_UNUSED(stride_a); - MACE_UNUSED(stride_c); - float32x4_t a0; - float32x4_t b0, b1, b2, b3; - float32x4_t c0; - - a0 = vld1q_f32(a_ptr); - - b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_b); - b2 = vld1q_f32(b_ptr + 2 * stride_b); - b3 = vld1q_f32(b_ptr + 3 * stride_b); - - c0 = vld1q_f32(c_ptr); - - MACE_GEMM_PART_CAL_4(0); - - vst1q_f32(c_ptr, c0); -#else - GemmBlock(a_ptr, b_ptr, 1, 4, 4, stride_a, stride_b, stride_c, c_ptr); -#endif -} - -inline void Gemm244(const float *a_ptr, - const float *b_ptr, - const index_t stride_a, - const index_t stride_b, - const index_t stride_c, - float *c_ptr) { -#if defined(MACE_ENABLE_NEON) - float32x4_t a0, a1; - float32x4_t b0, b1, b2, b3; - float32x4_t c0, c1; - - a0 = vld1q_f32(a_ptr); - a1 = vld1q_f32(a_ptr + 1 * stride_a); - - b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_b); - b2 = vld1q_f32(b_ptr + 2 * stride_b); 
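
// Reader's note on the register-tiling kernels deleted in this file: every
// GemmXYZ body funnels into MACE_GEMM_PART_CAL_4 / _8, a chain of
// lane-broadcast fused multiply-adds. On aarch64, MACE_GEMM_PART_CAL_4(0)
// expands to the four statements below -- c0 accumulates one C row as
// sum_k b_k * a0[k]:
//
//   c0 = vfmaq_laneq_f32(c0, b0, a0, 0);
//   c0 = vfmaq_laneq_f32(c0, b1, a0, 1);
//   c0 = vfmaq_laneq_f32(c0, b2, a0, 2);
//   c0 = vfmaq_laneq_f32(c0, b3, a0, 3);
//
// The armv7 variant uses vmlaq_lane_f32 on vget_low_f32/vget_high_f32 halves
// instead, because vfmaq_laneq_f32 is aarch64-only.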
- b3 = vld1q_f32(b_ptr + 3 * stride_b); - - c0 = vld1q_f32(c_ptr); - c1 = vld1q_f32(c_ptr + 1 * stride_c); - - MACE_GEMM_PART_CAL_4(0); - MACE_GEMM_PART_CAL_4(1); - - vst1q_f32(c_ptr, c0); - vst1q_f32(c_ptr + 1 * stride_c, c1); -#else - GemmBlock(a_ptr, b_ptr, 2, 4, 4, stride_a, stride_b, stride_c, c_ptr); -#endif -} - -inline void Gemm344(const float *a_ptr, - const float *b_ptr, - const index_t stride_a, - const index_t stride_b, - const index_t stride_c, - float *c_ptr) { -#if defined(MACE_ENABLE_NEON) - float32x4_t a0, a1, a2; - float32x4_t b0, b1, b2, b3; - float32x4_t c0, c1, c2; - - a0 = vld1q_f32(a_ptr); - a1 = vld1q_f32(a_ptr + 1 * stride_a); - a2 = vld1q_f32(a_ptr + 2 * stride_a); - - b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_b); - b2 = vld1q_f32(b_ptr + 2 * stride_b); - b3 = vld1q_f32(b_ptr + 3 * stride_b); - - c0 = vld1q_f32(c_ptr); - c1 = vld1q_f32(c_ptr + 1 * stride_c); - c2 = vld1q_f32(c_ptr + 2 * stride_c); - - MACE_GEMM_PART_CAL_4(0); - MACE_GEMM_PART_CAL_4(1); - MACE_GEMM_PART_CAL_4(2); - - vst1q_f32(c_ptr, c0); - vst1q_f32(c_ptr + 1 * stride_c, c1); - vst1q_f32(c_ptr + 2 * stride_c, c2); -#else - GemmBlock(a_ptr, b_ptr, 3, 4, 4, stride_a, stride_b, stride_c, c_ptr); -#endif -} - -inline void Gemm444(const float *a_ptr, - const float *b_ptr, - const index_t stride_a, - const index_t stride_b, - const index_t stride_c, - float *c_ptr) { -#if defined(MACE_ENABLE_NEON) - float32x4_t a0, a1, a2, a3; - float32x4_t b0, b1, b2, b3; - float32x4_t c0, c1, c2, c3; - - a0 = vld1q_f32(a_ptr); - a1 = vld1q_f32(a_ptr + 1 * stride_a); - a2 = vld1q_f32(a_ptr + 2 * stride_a); - a3 = vld1q_f32(a_ptr + 3 * stride_a); - - b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_b); - b2 = vld1q_f32(b_ptr + 2 * stride_b); - b3 = vld1q_f32(b_ptr + 3 * stride_b); - - c0 = vld1q_f32(c_ptr); - c1 = vld1q_f32(c_ptr + 1 * stride_c); - c2 = vld1q_f32(c_ptr + 2 * stride_c); - c3 = vld1q_f32(c_ptr + 3 * stride_c); - - MACE_GEMM_PART_CAL_4(0); - MACE_GEMM_PART_CAL_4(1); - MACE_GEMM_PART_CAL_4(2); - MACE_GEMM_PART_CAL_4(3); - - vst1q_f32(c_ptr, c0); - vst1q_f32(c_ptr + 1 * stride_c, c1); - vst1q_f32(c_ptr + 2 * stride_c, c2); - vst1q_f32(c_ptr + 3 * stride_c, c3); -#else - GemmBlock(a_ptr, b_ptr, 4, 4, 4, stride_a, stride_b, stride_c, c_ptr); -#endif -} - -inline void Gemm544(const float *a_ptr, - const float *b_ptr, - const index_t stride_a, - const index_t stride_b, - const index_t stride_c, - float *c_ptr) { -#if defined(MACE_ENABLE_NEON) - float32x4_t a0, a1, a2, a3, a4; - float32x4_t b0, b1, b2, b3; - float32x4_t c0, c1, c2, c3, c4; - - a0 = vld1q_f32(a_ptr); - a1 = vld1q_f32(a_ptr + 1 * stride_a); - a2 = vld1q_f32(a_ptr + 2 * stride_a); - a3 = vld1q_f32(a_ptr + 3 * stride_a); - a4 = vld1q_f32(a_ptr + 4 * stride_a); - - b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_b); - b2 = vld1q_f32(b_ptr + 2 * stride_b); - b3 = vld1q_f32(b_ptr + 3 * stride_b); - - c0 = vld1q_f32(c_ptr); - c1 = vld1q_f32(c_ptr + 1 * stride_c); - c2 = vld1q_f32(c_ptr + 2 * stride_c); - c3 = vld1q_f32(c_ptr + 3 * stride_c); - c4 = vld1q_f32(c_ptr + 4 * stride_c); - - MACE_GEMM_PART_CAL_4(0); - MACE_GEMM_PART_CAL_4(1); - MACE_GEMM_PART_CAL_4(2); - MACE_GEMM_PART_CAL_4(3); - MACE_GEMM_PART_CAL_4(4); - - vst1q_f32(c_ptr, c0); - vst1q_f32(c_ptr + 1 * stride_c, c1); - vst1q_f32(c_ptr + 2 * stride_c, c2); - vst1q_f32(c_ptr + 3 * stride_c, c3); - vst1q_f32(c_ptr + 4 * stride_c, c4); -#else - GemmBlock(a_ptr, b_ptr, 5, 4, 4, stride_a, stride_b, stride_c, c_ptr); -#endif -} - -inline void Gemm644(const float 
*a_ptr, - const float *b_ptr, - const index_t stride_a, - const index_t stride_b, - const index_t stride_c, - float *c_ptr) { -#if defined(MACE_ENABLE_NEON) - float32x4_t a0, a1, a2, a3, a4, a5; - float32x4_t b0, b1, b2, b3; - float32x4_t c0, c1, c2, c3, c4, c5; - - a0 = vld1q_f32(a_ptr); - a1 = vld1q_f32(a_ptr + 1 * stride_a); - a2 = vld1q_f32(a_ptr + 2 * stride_a); - a3 = vld1q_f32(a_ptr + 3 * stride_a); - a4 = vld1q_f32(a_ptr + 4 * stride_a); - a5 = vld1q_f32(a_ptr + 5 * stride_a); - - b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_b); - b2 = vld1q_f32(b_ptr + 2 * stride_b); - b3 = vld1q_f32(b_ptr + 3 * stride_b); - - c0 = vld1q_f32(c_ptr); - c1 = vld1q_f32(c_ptr + 1 * stride_c); - c2 = vld1q_f32(c_ptr + 2 * stride_c); - c3 = vld1q_f32(c_ptr + 3 * stride_c); - c4 = vld1q_f32(c_ptr + 4 * stride_c); - c5 = vld1q_f32(c_ptr + 5 * stride_c); - - MACE_GEMM_PART_CAL_4(0); - MACE_GEMM_PART_CAL_4(1); - MACE_GEMM_PART_CAL_4(2); - MACE_GEMM_PART_CAL_4(3); - MACE_GEMM_PART_CAL_4(4); - MACE_GEMM_PART_CAL_4(5); - - vst1q_f32(c_ptr, c0); - vst1q_f32(c_ptr + 1 * stride_c, c1); - vst1q_f32(c_ptr + 2 * stride_c, c2); - vst1q_f32(c_ptr + 3 * stride_c, c3); - vst1q_f32(c_ptr + 4 * stride_c, c4); - vst1q_f32(c_ptr + 5 * stride_c, c5); -#else - GemmBlock(a_ptr, b_ptr, 6, 4, 4, stride_a, stride_b, stride_c, c_ptr); -#endif -} - -inline void Gemm884(const float *a_ptr, - const float *b_ptr, - const index_t stride_a, - const index_t stride_b, - const index_t stride_c, - float *c_ptr) { -#if defined(MACE_ENABLE_NEON) - float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, - a15; - float32x4_t b0, b1, b2, b3, b4, b5, b6, b7; - float32x4_t c0, c1, c2, c3, c4, c5, c6, c7; - - a0 = vld1q_f32(a_ptr); - a1 = vld1q_f32(a_ptr + 4); - a2 = vld1q_f32(a_ptr + 1 * stride_a); - a3 = vld1q_f32(a_ptr + 1 * stride_a + 4); - a4 = vld1q_f32(a_ptr + 2 * stride_a); - a5 = vld1q_f32(a_ptr + 2 * stride_a + 4); - a6 = vld1q_f32(a_ptr + 3 * stride_a); - a7 = vld1q_f32(a_ptr + 3 * stride_a + 4); - a8 = vld1q_f32(a_ptr + 4 * stride_a); - a9 = vld1q_f32(a_ptr + 4 * stride_a + 4); - a10 = vld1q_f32(a_ptr + 5 * stride_a); - a11 = vld1q_f32(a_ptr + 5 * stride_a + 4); - a12 = vld1q_f32(a_ptr + 6 * stride_a); - a13 = vld1q_f32(a_ptr + 6 * stride_a + 4); - a14 = vld1q_f32(a_ptr + 7 * stride_a); - a15 = vld1q_f32(a_ptr + 7 * stride_a + 4); - - b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_b); - b2 = vld1q_f32(b_ptr + 2 * stride_b); - b3 = vld1q_f32(b_ptr + 3 * stride_b); - b4 = vld1q_f32(b_ptr + 4 * stride_b); - b5 = vld1q_f32(b_ptr + 5 * stride_b); - b6 = vld1q_f32(b_ptr + 6 * stride_b); - b7 = vld1q_f32(b_ptr + 7 * stride_b); - - c0 = vld1q_f32(c_ptr); - c1 = vld1q_f32(c_ptr + 1 * stride_c); - c2 = vld1q_f32(c_ptr + 2 * stride_c); - c3 = vld1q_f32(c_ptr + 3 * stride_c); - c4 = vld1q_f32(c_ptr + 4 * stride_c); - c5 = vld1q_f32(c_ptr + 5 * stride_c); - c6 = vld1q_f32(c_ptr + 6 * stride_c); - c7 = vld1q_f32(c_ptr + 7 * stride_c); - - MACE_GEMM_PART_CAL_8(0, 0, 1); - MACE_GEMM_PART_CAL_8(1, 2, 3); - MACE_GEMM_PART_CAL_8(2, 4, 5); - MACE_GEMM_PART_CAL_8(3, 6, 7); - MACE_GEMM_PART_CAL_8(4, 8, 9); - MACE_GEMM_PART_CAL_8(5, 10, 11); - MACE_GEMM_PART_CAL_8(6, 12, 13); - MACE_GEMM_PART_CAL_8(7, 14, 15); - - vst1q_f32(c_ptr, c0); - vst1q_f32(c_ptr + 1 * stride_c, c1); - vst1q_f32(c_ptr + 2 * stride_c, c2); - vst1q_f32(c_ptr + 3 * stride_c, c3); - vst1q_f32(c_ptr + 4 * stride_c, c4); - vst1q_f32(c_ptr + 5 * stride_c, c5); - vst1q_f32(c_ptr + 6 * stride_c, c6); - vst1q_f32(c_ptr + 7 * stride_c, c7); -#else - 
GemmBlock(a_ptr, b_ptr, 8, 8, 4, stride_a, stride_b, stride_c, c_ptr); -#endif -} - -inline void Gemm184(const float *a_ptr, - const float *b_ptr, - const index_t stride_a, - const index_t stride_b, - const index_t stride_c, - float *c_ptr) { -#if defined(MACE_ENABLE_NEON) - MACE_UNUSED(stride_a); - MACE_UNUSED(stride_c); - - float32x4_t a0, a1; - float32x4_t b0, b1, b2, b3, b4, b5, b6, b7; - float32x4_t c0; - - a0 = vld1q_f32(a_ptr); - a1 = vld1q_f32(a_ptr + 4); - - b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_b); - b2 = vld1q_f32(b_ptr + 2 * stride_b); - b3 = vld1q_f32(b_ptr + 3 * stride_b); - b4 = vld1q_f32(b_ptr + 4 * stride_b); - b5 = vld1q_f32(b_ptr + 5 * stride_b); - b6 = vld1q_f32(b_ptr + 6 * stride_b); - b7 = vld1q_f32(b_ptr + 7 * stride_b); - - c0 = vld1q_f32(c_ptr); - - MACE_GEMM_PART_CAL_8(0, 0, 1); - - vst1q_f32(c_ptr, c0); -#else - GemmBlock(a_ptr, b_ptr, 1, 8, 4, stride_a, stride_b, stride_c, c_ptr); -#endif -} - -inline void Gemm284(const float *a_ptr, - const float *b_ptr, - const index_t stride_a, - const index_t stride_b, - const index_t stride_c, - float *c_ptr) { -#if defined(MACE_ENABLE_NEON) - float32x4_t a0, a1, a2, a3; - float32x4_t b0, b1, b2, b3, b4, b5, b6, b7; - float32x4_t c0, c1; - - a0 = vld1q_f32(a_ptr); - a1 = vld1q_f32(a_ptr + 4); - a2 = vld1q_f32(a_ptr + 1 * stride_a); - a3 = vld1q_f32(a_ptr + 1 * stride_a + 4); - - b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_b); - b2 = vld1q_f32(b_ptr + 2 * stride_b); - b3 = vld1q_f32(b_ptr + 3 * stride_b); - b4 = vld1q_f32(b_ptr + 4 * stride_b); - b5 = vld1q_f32(b_ptr + 5 * stride_b); - b6 = vld1q_f32(b_ptr + 6 * stride_b); - b7 = vld1q_f32(b_ptr + 7 * stride_b); - - c0 = vld1q_f32(c_ptr); - c1 = vld1q_f32(c_ptr + 1 * stride_c); - - MACE_GEMM_PART_CAL_8(0, 0, 1); - MACE_GEMM_PART_CAL_8(1, 2, 3); - - vst1q_f32(c_ptr, c0); - vst1q_f32(c_ptr + 1 * stride_c, c1); -#else - GemmBlock(a_ptr, b_ptr, 2, 8, 4, stride_a, stride_b, stride_c, c_ptr); -#endif -} - -inline void Gemm384(const float *a_ptr, - const float *b_ptr, - const index_t stride_a, - const index_t stride_b, - const index_t stride_c, - float *c_ptr) { -#if defined(MACE_ENABLE_NEON) - float32x4_t a0, a1, a2, a3, a4, a5; - float32x4_t b0, b1, b2, b3, b4, b5, b6, b7; - float32x4_t c0, c1, c2; - - a0 = vld1q_f32(a_ptr); - a1 = vld1q_f32(a_ptr + 4); - a2 = vld1q_f32(a_ptr + 1 * stride_a); - a3 = vld1q_f32(a_ptr + 1 * stride_a + 4); - a4 = vld1q_f32(a_ptr + 2 * stride_a); - a5 = vld1q_f32(a_ptr + 2 * stride_a + 4); - - b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_b); - b2 = vld1q_f32(b_ptr + 2 * stride_b); - b3 = vld1q_f32(b_ptr + 3 * stride_b); - b4 = vld1q_f32(b_ptr + 4 * stride_b); - b5 = vld1q_f32(b_ptr + 5 * stride_b); - b6 = vld1q_f32(b_ptr + 6 * stride_b); - b7 = vld1q_f32(b_ptr + 7 * stride_b); - - c0 = vld1q_f32(c_ptr); - c1 = vld1q_f32(c_ptr + 1 * stride_c); - c2 = vld1q_f32(c_ptr + 2 * stride_c); - - MACE_GEMM_PART_CAL_8(0, 0, 1); - MACE_GEMM_PART_CAL_8(1, 2, 3); - MACE_GEMM_PART_CAL_8(2, 4, 5); - - vst1q_f32(c_ptr, c0); - vst1q_f32(c_ptr + 1 * stride_c, c1); - vst1q_f32(c_ptr + 2 * stride_c, c2); -#else - GemmBlock(a_ptr, b_ptr, 3, 8, 4, stride_a, stride_b, stride_c, c_ptr); -#endif -} - -inline void Gemm484(const float *a_ptr, - const float *b_ptr, - const index_t stride_a, - const index_t stride_b, - const index_t stride_c, - float *c_ptr) { -#if defined(MACE_ENABLE_NEON) - float32x4_t a0, a1, a2, a3, a4, a5, a6, a7; - float32x4_t b0, b1, b2, b3, b4, b5, b6, b7; - float32x4_t c0, c1, c2, c3; - - a0 = 
vld1q_f32(a_ptr); - a1 = vld1q_f32(a_ptr + 4); - a2 = vld1q_f32(a_ptr + 1 * stride_a); - a3 = vld1q_f32(a_ptr + 1 * stride_a + 4); - a4 = vld1q_f32(a_ptr + 2 * stride_a); - a5 = vld1q_f32(a_ptr + 2 * stride_a + 4); - a6 = vld1q_f32(a_ptr + 3 * stride_a); - a7 = vld1q_f32(a_ptr + 3 * stride_a + 4); - - b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_b); - b2 = vld1q_f32(b_ptr + 2 * stride_b); - b3 = vld1q_f32(b_ptr + 3 * stride_b); - b4 = vld1q_f32(b_ptr + 4 * stride_b); - b5 = vld1q_f32(b_ptr + 5 * stride_b); - b6 = vld1q_f32(b_ptr + 6 * stride_b); - b7 = vld1q_f32(b_ptr + 7 * stride_b); - - c0 = vld1q_f32(c_ptr); - c1 = vld1q_f32(c_ptr + 1 * stride_c); - c2 = vld1q_f32(c_ptr + 2 * stride_c); - c3 = vld1q_f32(c_ptr + 3 * stride_c); - - MACE_GEMM_PART_CAL_8(0, 0, 1); - MACE_GEMM_PART_CAL_8(1, 2, 3); - MACE_GEMM_PART_CAL_8(2, 4, 5); - MACE_GEMM_PART_CAL_8(3, 6, 7); - - vst1q_f32(c_ptr, c0); - vst1q_f32(c_ptr + 1 * stride_c, c1); - vst1q_f32(c_ptr + 2 * stride_c, c2); - vst1q_f32(c_ptr + 3 * stride_c, c3); -#else - GemmBlock(a_ptr, b_ptr, 4, 8, 4, stride_a, stride_b, stride_c, c_ptr); -#endif -} - -inline void Gemm584(const float *a_ptr, - const float *b_ptr, - const index_t stride_a, - const index_t stride_b, - const index_t stride_c, - float *c_ptr) { -#if defined(MACE_ENABLE_NEON) - float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9; - float32x4_t b0, b1, b2, b3, b4, b5, b6, b7; - float32x4_t c0, c1, c2, c3, c4; - - a0 = vld1q_f32(a_ptr); - a1 = vld1q_f32(a_ptr + 4); - a2 = vld1q_f32(a_ptr + 1 * stride_a); - a3 = vld1q_f32(a_ptr + 1 * stride_a + 4); - a4 = vld1q_f32(a_ptr + 2 * stride_a); - a5 = vld1q_f32(a_ptr + 2 * stride_a + 4); - a6 = vld1q_f32(a_ptr + 3 * stride_a); - a7 = vld1q_f32(a_ptr + 3 * stride_a + 4); - a8 = vld1q_f32(a_ptr + 4 * stride_a); - a9 = vld1q_f32(a_ptr + 4 * stride_a + 4); - - b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_b); - b2 = vld1q_f32(b_ptr + 2 * stride_b); - b3 = vld1q_f32(b_ptr + 3 * stride_b); - b4 = vld1q_f32(b_ptr + 4 * stride_b); - b5 = vld1q_f32(b_ptr + 5 * stride_b); - b6 = vld1q_f32(b_ptr + 6 * stride_b); - b7 = vld1q_f32(b_ptr + 7 * stride_b); - - c0 = vld1q_f32(c_ptr); - c1 = vld1q_f32(c_ptr + 1 * stride_c); - c2 = vld1q_f32(c_ptr + 2 * stride_c); - c3 = vld1q_f32(c_ptr + 3 * stride_c); - c4 = vld1q_f32(c_ptr + 4 * stride_c); - - MACE_GEMM_PART_CAL_8(0, 0, 1); - MACE_GEMM_PART_CAL_8(1, 2, 3); - MACE_GEMM_PART_CAL_8(2, 4, 5); - MACE_GEMM_PART_CAL_8(3, 6, 7); - MACE_GEMM_PART_CAL_8(4, 8, 9); - - vst1q_f32(c_ptr, c0); - vst1q_f32(c_ptr + 1 * stride_c, c1); - vst1q_f32(c_ptr + 2 * stride_c, c2); - vst1q_f32(c_ptr + 3 * stride_c, c3); - vst1q_f32(c_ptr + 4 * stride_c, c4); -#else - GemmBlock(a_ptr, b_ptr, 5, 8, 4, stride_a, stride_b, stride_c, c_ptr); -#endif -} - -inline void Gemm684(const float *a_ptr, - const float *b_ptr, - const index_t stride_a, - const index_t stride_b, - const index_t stride_c, - float *c_ptr) { -#if defined(MACE_ENABLE_NEON) - float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11; - float32x4_t b0, b1, b2, b3, b4, b5, b6, b7; - float32x4_t c0, c1, c2, c3, c4, c5; - - a0 = vld1q_f32(a_ptr); - a1 = vld1q_f32(a_ptr + 4); - a2 = vld1q_f32(a_ptr + 1 * stride_a); - a3 = vld1q_f32(a_ptr + 1 * stride_a + 4); - a4 = vld1q_f32(a_ptr + 2 * stride_a); - a5 = vld1q_f32(a_ptr + 2 * stride_a + 4); - a6 = vld1q_f32(a_ptr + 3 * stride_a); - a7 = vld1q_f32(a_ptr + 3 * stride_a + 4); - a8 = vld1q_f32(a_ptr + 4 * stride_a); - a9 = vld1q_f32(a_ptr + 4 * stride_a + 4); - a10 = vld1q_f32(a_ptr + 5 * stride_a); - a11 = 
vld1q_f32(a_ptr + 5 * stride_a + 4); - - b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_b); - b2 = vld1q_f32(b_ptr + 2 * stride_b); - b3 = vld1q_f32(b_ptr + 3 * stride_b); - b4 = vld1q_f32(b_ptr + 4 * stride_b); - b5 = vld1q_f32(b_ptr + 5 * stride_b); - b6 = vld1q_f32(b_ptr + 6 * stride_b); - b7 = vld1q_f32(b_ptr + 7 * stride_b); - - c0 = vld1q_f32(c_ptr); - c1 = vld1q_f32(c_ptr + 1 * stride_c); - c2 = vld1q_f32(c_ptr + 2 * stride_c); - c3 = vld1q_f32(c_ptr + 3 * stride_c); - c4 = vld1q_f32(c_ptr + 4 * stride_c); - c5 = vld1q_f32(c_ptr + 5 * stride_c); - - MACE_GEMM_PART_CAL_8(0, 0, 1); - MACE_GEMM_PART_CAL_8(1, 2, 3); - MACE_GEMM_PART_CAL_8(2, 4, 5); - MACE_GEMM_PART_CAL_8(3, 6, 7); - MACE_GEMM_PART_CAL_8(4, 8, 9); - MACE_GEMM_PART_CAL_8(5, 10, 11); - - vst1q_f32(c_ptr, c0); - vst1q_f32(c_ptr + 1 * stride_c, c1); - vst1q_f32(c_ptr + 2 * stride_c, c2); - vst1q_f32(c_ptr + 3 * stride_c, c3); - vst1q_f32(c_ptr + 4 * stride_c, c4); - vst1q_f32(c_ptr + 5 * stride_c, c5); - -#else - GemmBlock(a_ptr, b_ptr, 6, 8, 4, stride_a, stride_b, stride_c, c_ptr); -#endif -} - -inline void Gemm784(const float *a_ptr, - const float *b_ptr, - const index_t stride_a, - const index_t stride_b, - const index_t stride_c, - float *c_ptr) { -#if defined(MACE_ENABLE_NEON) - float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13; - float32x4_t b0, b1, b2, b3, b4, b5, b6, b7; - float32x4_t c0, c1, c2, c3, c4, c5, c6; - - a0 = vld1q_f32(a_ptr); - a1 = vld1q_f32(a_ptr + 4); - a2 = vld1q_f32(a_ptr + 1 * stride_a); - a3 = vld1q_f32(a_ptr + 1 * stride_a + 4); - a4 = vld1q_f32(a_ptr + 2 * stride_a); - a5 = vld1q_f32(a_ptr + 2 * stride_a + 4); - a6 = vld1q_f32(a_ptr + 3 * stride_a); - a7 = vld1q_f32(a_ptr + 3 * stride_a + 4); - a8 = vld1q_f32(a_ptr + 4 * stride_a); - a9 = vld1q_f32(a_ptr + 4 * stride_a + 4); - a10 = vld1q_f32(a_ptr + 5 * stride_a); - a11 = vld1q_f32(a_ptr + 5 * stride_a + 4); - a12 = vld1q_f32(a_ptr + 6 * stride_a); - a13 = vld1q_f32(a_ptr + 6 * stride_a + 4); - - b0 = vld1q_f32(b_ptr); - b1 = vld1q_f32(b_ptr + 1 * stride_b); - b2 = vld1q_f32(b_ptr + 2 * stride_b); - b3 = vld1q_f32(b_ptr + 3 * stride_b); - b4 = vld1q_f32(b_ptr + 4 * stride_b); - b5 = vld1q_f32(b_ptr + 5 * stride_b); - b6 = vld1q_f32(b_ptr + 6 * stride_b); - b7 = vld1q_f32(b_ptr + 7 * stride_b); - - c0 = vld1q_f32(c_ptr); - c1 = vld1q_f32(c_ptr + 1 * stride_c); - c2 = vld1q_f32(c_ptr + 2 * stride_c); - c3 = vld1q_f32(c_ptr + 3 * stride_c); - c4 = vld1q_f32(c_ptr + 4 * stride_c); - c5 = vld1q_f32(c_ptr + 5 * stride_c); - c6 = vld1q_f32(c_ptr + 6 * stride_c); - - MACE_GEMM_PART_CAL_8(0, 0, 1); - MACE_GEMM_PART_CAL_8(1, 2, 3); - MACE_GEMM_PART_CAL_8(2, 4, 5); - MACE_GEMM_PART_CAL_8(3, 6, 7); - MACE_GEMM_PART_CAL_8(4, 8, 9); - MACE_GEMM_PART_CAL_8(5, 10, 11); - MACE_GEMM_PART_CAL_8(6, 12, 13); - - vst1q_f32(c_ptr, c0); - vst1q_f32(c_ptr + 1 * stride_c, c1); - vst1q_f32(c_ptr + 2 * stride_c, c2); - vst1q_f32(c_ptr + 3 * stride_c, c3); - vst1q_f32(c_ptr + 4 * stride_c, c4); - vst1q_f32(c_ptr + 5 * stride_c, c5); - vst1q_f32(c_ptr + 6 * stride_c, c6); - -#else - GemmBlock(a_ptr, b_ptr, 7, 8, 4, stride_a, stride_b, stride_c, c_ptr); -#endif -} - -inline void GemmTile(const float *A, - const float *B, - const index_t height, - const index_t K, - const index_t width, - const index_t stride_a, - const index_t stride_b, - const index_t stride_c, - float *C) { -#if defined(MACE_ENABLE_NEON) - index_t h = 0; - index_t w = 0; - index_t k = 0; -#if defined(__aarch64__) - int reg_height_tile = 8; - int reg_K_tile = 8; -#else - int 
reg_height_tile = 6; - int reg_K_tile = 4; -#endif - - for (h = 0; h < height - reg_height_tile + 1; h += reg_height_tile) { - for (k = 0; k < K - reg_K_tile + 1; k += reg_K_tile) { - const float *a_ptr = A + (h * stride_a + k); -#if defined(__aarch64__) && defined(__clang__) - int nw = width >> 2; - if (nw > 0) { - // load A - float32x4_t a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, - a14, a15; - a0 = vld1q_f32(a_ptr); - a1 = vld1q_f32(a_ptr + 4); - a2 = vld1q_f32(a_ptr + 1 * stride_a); - a3 = vld1q_f32(a_ptr + 1 * stride_a + 4); - a4 = vld1q_f32(a_ptr + 2 * stride_a); - a5 = vld1q_f32(a_ptr + 2 * stride_a + 4); - a6 = vld1q_f32(a_ptr + 3 * stride_a); - a7 = vld1q_f32(a_ptr + 3 * stride_a + 4); - a8 = vld1q_f32(a_ptr + 4 * stride_a); - a9 = vld1q_f32(a_ptr + 4 * stride_a + 4); - a10 = vld1q_f32(a_ptr + 5 * stride_a); - a11 = vld1q_f32(a_ptr + 5 * stride_a + 4); - a12 = vld1q_f32(a_ptr + 6 * stride_a); - a13 = vld1q_f32(a_ptr + 6 * stride_a + 4); - a14 = vld1q_f32(a_ptr + 7 * stride_a); - a15 = vld1q_f32(a_ptr + 7 * stride_a + 4); - - const float *b_ptr0 = B + k * stride_b; - const float *b_ptr1 = B + (k + 1) * stride_b; - const float *b_ptr2 = B + (k + 2) * stride_b; - const float *b_ptr3 = B + (k + 3) * stride_b; - const float *b_ptr4 = B + (k + 4) * stride_b; - const float *b_ptr5 = B + (k + 5) * stride_b; - const float *b_ptr6 = B + (k + 6) * stride_b; - const float *b_ptr7 = B + (k + 7) * stride_b; - - float *c_ptr0 = C + h * stride_c; - float *c_ptr1 = C + (h + 1) * stride_c; - float *c_ptr2 = C + (h + 2) * stride_c; - float *c_ptr3 = C + (h + 3) * stride_c; - float *c_ptr4 = C + (h + 4) * stride_c; - float *c_ptr5 = C + (h + 5) * stride_c; - float *c_ptr6 = C + (h + 6) * stride_c; - float *c_ptr7 = C + (h + 7) * stride_c; - - asm volatile( - "0: \n" - - "prfm pldl1keep, [%9, #128] \n" - "ld1 {v16.4s}, [%9], #16 \n" - - "prfm pldl1keep, [%1, #128] \n" - "ld1 {v18.4s}, [%1] \n" - - "prfm pldl1keep, [%2, #128] \n" - "ld1 {v19.4s}, [%2] \n" - - "prfm pldl1keep, [%3, #128] \n" - "ld1 {v20.4s}, [%3] \n" - "prfm pldl1keep, [%4, #128] \n" - "ld1 {v21.4s}, [%4] \n" - "prfm pldl1keep, [%5, #128] \n" - "ld1 {v22.4s}, [%5] \n" - "prfm pldl1keep, [%6, #128] \n" - "ld1 {v23.4s}, [%6] \n" - "prfm pldl1keep, [%7, #128] \n" - "ld1 {v24.4s}, [%7] \n" - "prfm pldl1keep, [%8, #128] \n" - "ld1 {v25.4s}, [%8] \n" - "prfm pldl1keep, [%10, #128] \n" - "ld1 {v17.4s}, [%10], #16 \n" - - "fmla v18.4s, v16.4s, %34.s[0] \n" - "fmla v19.4s, v16.4s, %35.s[0] \n" - "fmla v20.4s, v16.4s, %36.s[0] \n" - "fmla v21.4s, v16.4s, %37.s[0] \n" - - "fmla v22.4s, v16.4s, %38.s[0] \n" - "fmla v23.4s, v16.4s, %39.s[0] \n" - "fmla v24.4s, v16.4s, %40.s[0] \n" - "fmla v25.4s, v16.4s, %41.s[0] \n" - - "fmla v18.4s, v17.4s, %34.s[1] \n" - "fmla v19.4s, v17.4s, %35.s[1] \n" - "fmla v20.4s, v17.4s, %36.s[1] \n" - "fmla v21.4s, v17.4s, %37.s[1] \n" - - "prfm pldl1keep, [%11, #128] \n" - "ld1 {v16.4s}, [%11], #16 \n" - - "fmla v22.4s, v17.4s, %38.s[1] \n" - "fmla v23.4s, v17.4s, %39.s[1] \n" - "fmla v24.4s, v17.4s, %40.s[1] \n" - "fmla v25.4s, v17.4s, %41.s[1] \n" - - "fmla v18.4s, v16.4s, %34.s[2] \n" - "fmla v19.4s, v16.4s, %35.s[2] \n" - "fmla v20.4s, v16.4s, %36.s[2] \n" - "fmla v21.4s, v16.4s, %37.s[2] \n" - - "prfm pldl1keep, [%12, #128] \n" - "ld1 {v17.4s}, [%12], #16 \n" - - "fmla v22.4s, v16.4s, %38.s[2] \n" - "fmla v23.4s, v16.4s, %39.s[2] \n" - "fmla v24.4s, v16.4s, %40.s[2] \n" - "fmla v25.4s, v16.4s, %41.s[2] \n" - - "fmla v18.4s, v17.4s, %34.s[3] \n" - "fmla v19.4s, v17.4s, %35.s[3] \n" - "fmla v20.4s, 
v17.4s, %36.s[3] \n" - "fmla v21.4s, v17.4s, %37.s[3] \n" - - "prfm pldl1keep, [%13, #128] \n" - "ld1 {v16.4s}, [%13], #16 \n" - - "fmla v22.4s, v17.4s, %38.s[3] \n" - "fmla v23.4s, v17.4s, %39.s[3] \n" - "fmla v24.4s, v17.4s, %40.s[3] \n" - "fmla v25.4s, v17.4s, %41.s[3] \n" - - "fmla v18.4s, v16.4s, %42.s[0] \n" - "fmla v19.4s, v16.4s, %43.s[0] \n" - "fmla v20.4s, v16.4s, %44.s[0] \n" - "fmla v21.4s, v16.4s, %45.s[0] \n" - - "prfm pldl1keep, [%14, #128] \n" - "ld1 {v17.4s}, [%14], #16 \n" - - "fmla v22.4s, v16.4s, %46.s[0] \n" - "fmla v23.4s, v16.4s, %47.s[0] \n" - "fmla v24.4s, v16.4s, %48.s[0] \n" - "fmla v25.4s, v16.4s, %49.s[0] \n" - - "fmla v18.4s, v17.4s, %42.s[1] \n" - "fmla v19.4s, v17.4s, %43.s[1] \n" - "fmla v20.4s, v17.4s, %44.s[1] \n" - "fmla v21.4s, v17.4s, %45.s[1] \n" - - "prfm pldl1keep, [%15, #128] \n" - "ld1 {v16.4s}, [%15], #16 \n" - - "fmla v22.4s, v17.4s, %46.s[1] \n" - "fmla v23.4s, v17.4s, %47.s[1] \n" - "fmla v24.4s, v17.4s, %48.s[1] \n" - "fmla v25.4s, v17.4s, %49.s[1] \n" - - "fmla v18.4s, v16.4s, %42.s[2] \n" - "fmla v19.4s, v16.4s, %43.s[2] \n" - "fmla v20.4s, v16.4s, %44.s[2] \n" - "fmla v21.4s, v16.4s, %45.s[2] \n" - - "prfm pldl1keep, [%16, #128] \n" - "ld1 {v17.4s}, [%16], #16 \n" - - "fmla v22.4s, v16.4s, %46.s[2] \n" - "fmla v23.4s, v16.4s, %47.s[2] \n" - "fmla v24.4s, v16.4s, %48.s[2] \n" - "fmla v25.4s, v16.4s, %49.s[2] \n" - - "fmla v18.4s, v17.4s, %42.s[3] \n" - "fmla v19.4s, v17.4s, %43.s[3] \n" - "fmla v20.4s, v17.4s, %44.s[3] \n" - "fmla v21.4s, v17.4s, %45.s[3] \n" - - "st1 {v18.4s}, [%1], #16 \n" - "st1 {v19.4s}, [%2], #16 \n" - "st1 {v20.4s}, [%3], #16 \n" - "st1 {v21.4s}, [%4], #16 \n" - - "fmla v22.4s, v17.4s, %46.s[3] \n" - "fmla v23.4s, v17.4s, %47.s[3] \n" - "fmla v24.4s, v17.4s, %48.s[3] \n" - "fmla v25.4s, v17.4s, %49.s[3] \n" - - "subs %w0, %w0, #1 \n" - - "st1 {v22.4s}, [%5], #16 \n" - "st1 {v23.4s}, [%6], #16 \n" - "st1 {v24.4s}, [%7], #16 \n" - "st1 {v25.4s}, [%8], #16 \n" - - "bne 0b \n" - : "=r"(nw), // 0 - "=r"(c_ptr0), // 1 - "=r"(c_ptr1), // 2 - "=r"(c_ptr2), // 3 - "=r"(c_ptr3), // 4 - "=r"(c_ptr4), // 5 - "=r"(c_ptr5), // 6 - "=r"(c_ptr6), // 7 - "=r"(c_ptr7), // 8 - "=r"(b_ptr0), // 9 - "=r"(b_ptr1), // 10 - "=r"(b_ptr2), // 11 - "=r"(b_ptr3), // 12 - "=r"(b_ptr4), // 13 - "=r"(b_ptr5), // 14 - "=r"(b_ptr6), // 15 - "=r"(b_ptr7) // 16 - : "0"(nw), // 17 - "1"(c_ptr0), // 18 - "2"(c_ptr1), // 19 - "3"(c_ptr2), // 20 - "4"(c_ptr3), // 21 - "5"(c_ptr4), // 22 - "6"(c_ptr5), // 23 - "7"(c_ptr6), // 24 - "8"(c_ptr7), // 25 - "9"(b_ptr0), // 26 - "10"(b_ptr1), // 27 - "11"(b_ptr2), // 28 - "12"(b_ptr3), // 29 - "13"(b_ptr4), // 30 - "14"(b_ptr5), // 31 - "15"(b_ptr6), // 32 - "16"(b_ptr7), // 33 - "w"(a0), // 34 - "w"(a2), // 35 - "w"(a4), // 36 - "w"(a6), // 37 - "w"(a8), // 38 - "w"(a10), // 39 - "w"(a12), // 40 - "w"(a14), // 41 - "w"(a1), // 42 - "w"(a3), // 43 - "w"(a5), // 44 - "w"(a7), // 45 - "w"(a9), // 46 - "w"(a11), // 47 - "w"(a13), // 48 - "w"(a15) // 49 - : "cc", "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", - "v23", "v24", "v25"); - - w = (width >> 2) << 2; - } -#elif defined(__aarch64__) // gcc - for (w = 0; w + 3 < width; w += 4) { - const float *b_ptr = B + (k * stride_b + w); - float *c_ptr = C + (h * stride_c + w); - Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); - } -#else // armv7 - int nw = width >> 2; - if (nw > 0) { - float32x4_t a0, a1, a2, a3, a4, a5; - - a0 = vld1q_f32(a_ptr); - a1 = vld1q_f32(a_ptr + 1 * stride_a); - a2 = vld1q_f32(a_ptr + 2 * stride_a); - a3 = vld1q_f32(a_ptr 
+ 3 * stride_a); - a4 = vld1q_f32(a_ptr + 4 * stride_a); - a5 = vld1q_f32(a_ptr + 5 * stride_a); - - const float *b_ptr0 = B + k * stride_b; - const float *b_ptr1 = B + (k + 1) * stride_b; - const float *b_ptr2 = B + (k + 2) * stride_b; - const float *b_ptr3 = B + (k + 3) * stride_b; - - float *c_ptr0 = C + h * stride_c; - float *c_ptr1 = C + (h + 1) * stride_c; - float *c_ptr2 = C + (h + 2) * stride_c; - float *c_ptr3 = C + (h + 3) * stride_c; - float *c_ptr4 = C + (h + 4) * stride_c; - float *c_ptr5 = C + (h + 5) * stride_c; - - asm volatile( - "0: \n" - - "pld [%7, #128] \n" - "vld1.f32 {d12-d13}, [%7]! \n" - "pld [%1, #128] \n" - "vld1.f32 {d16-d17}, [%1] \n" - "pld [%2, #128] \n" - "vld1.f32 {d18-d19}, [%2] \n" - - "pld [%3, #128] \n" - "vld1.f32 {d20-d21}, [%3] \n" - "pld [%4, #128] \n" - "vld1.f32 {d22-d23}, [%4] \n" - "pld [%5, #128] \n" - "vld1.f32 {d24-d25}, [%5] \n" - "pld [%6, #128] \n" - "vld1.f32 {d26-d27}, [%6] \n" - - "pld [%8, #128] \n" - "vld1.f32 {d14-d15}, [%8]! \n" - - "vmla.f32 q8, q6, %e22[0] \n" - "vmla.f32 q9, q6, %e23[0] \n" - "vmla.f32 q10, q6, %e24[0] \n" - "vmla.f32 q11, q6, %e25[0] \n" - "vmla.f32 q12, q6, %e26[0] \n" - "vmla.f32 q13, q6, %e27[0] \n" - - "pld [%9, #128] \n" - "vld1.f32 {d12-d13}, [%9]! \n" - - "vmla.f32 q8, q7, %e22[1] \n" - "vmla.f32 q9, q7, %e23[1] \n" - "vmla.f32 q10, q7, %e24[1] \n" - "vmla.f32 q11, q7, %e25[1] \n" - "vmla.f32 q12, q7, %e26[1] \n" - "vmla.f32 q13, q7, %e27[1] \n" - - "pld [%10, #128] \n" - "vld1.f32 {d14-d15}, [%10]! \n" - - "vmla.f32 q8, q6, %f22[0] \n" - "vmla.f32 q9, q6, %f23[0] \n" - "vmla.f32 q10, q6, %f24[0] \n" - "vmla.f32 q11, q6, %f25[0] \n" - "vmla.f32 q12, q6, %f26[0] \n" - "vmla.f32 q13, q6, %f27[0] \n" - - "vmla.f32 q8, q7, %f22[1] \n" - "vmla.f32 q9, q7, %f23[1] \n" - "vmla.f32 q10, q7, %f24[1] \n" - "vmla.f32 q11, q7, %f25[1] \n" - "vmla.f32 q12, q7, %f26[1] \n" - "vmla.f32 q13, q7, %f27[1] \n" - - "vst1.f32 {d16-d17}, [%1]! \n" - "vst1.f32 {d18-d19}, [%2]! \n" - "vst1.f32 {d20-d21}, [%3]! \n" - "vst1.f32 {d22-d23}, [%4]! \n" - "vst1.f32 {d24-d25}, [%5]! \n" - "vst1.f32 {d26-d27}, [%6]! 
\n" - - "subs %0, #1 \n" - "bne 0b \n" - : "=r"(nw), // 0 - "=r"(c_ptr0), // 1 - "=r"(c_ptr1), // 2 - "=r"(c_ptr2), // 3 - "=r"(c_ptr3), // 4 - "=r"(c_ptr4), // 5 - "=r"(c_ptr5), // 6 - "=r"(b_ptr0), // 7 - "=r"(b_ptr1), // 8 - "=r"(b_ptr2), // 9 - "=r"(b_ptr3) // 10 - : "0"(nw), // 11 - "1"(c_ptr0), // 12 - "2"(c_ptr1), // 13 - "3"(c_ptr2), // 14 - "4"(c_ptr3), // 15 - "5"(c_ptr4), // 16 - "6"(c_ptr5), // 17 - "7"(b_ptr0), // 18 - "8"(b_ptr1), // 19 - "9"(b_ptr2), // 20 - "10"(b_ptr3), // 21 - "w"(a0), // 22 - "w"(a1), // 23 - "w"(a2), // 24 - "w"(a3), // 25 - "w"(a4), // 26 - "w"(a5) // 27 - : "cc", "memory", "q6", "q7", "q8", "q9", "q10", "q11", "q12", - "q13", "q14", "q15"); - - w = (width >> 2) << 2; - } -#endif - if (w < width) { - const float *b_ptr = B + (k * stride_b + w); - float *c_ptr = C + (h * stride_c + w); - GemmBlock(a_ptr, b_ptr, reg_height_tile, reg_K_tile, width - w, - stride_a, stride_b, stride_c, c_ptr); - } - } - if (k < K) { - const float *a_ptr = A + (h * stride_a + k); - const float *b_ptr = B + k * stride_b; - float *c_ptr = C + h * stride_c; - GemmBlock(a_ptr, b_ptr, reg_height_tile, K - k, width, stride_a, stride_b, - stride_c, c_ptr); - } - } - if (h < height) { - index_t remain_h = height - h; - - auto gemm_fn = Gemm184; - switch (remain_h) { - case 1: - #if defined(__aarch64__) - gemm_fn = Gemm184; - #else - gemm_fn = Gemm144; - #endif - break; - case 2: - #if defined(__aarch64__) - gemm_fn = Gemm284; - #else - gemm_fn = Gemm244; - #endif - break; - case 3: - #if defined(__aarch64__) - gemm_fn = Gemm384; - #else - gemm_fn = Gemm344; - #endif - break; - case 4: - #if defined(__aarch64__) - gemm_fn = Gemm484; - #else - gemm_fn = Gemm444; - #endif - break; - case 5: - #if defined(__aarch64__) - gemm_fn = Gemm584; - #else - gemm_fn = Gemm544; - #endif - break; - case 6: - #if defined(__aarch64__) - gemm_fn = Gemm684; - #else - LOG(FATAL) << "remain_h should < 6"; - #endif - break; - case 7: - #if defined(__aarch64__) - gemm_fn = Gemm784; - #else - LOG(FATAL) << "remain_h should < 6"; - #endif - break; - default: - LOG(FATAL) << "remain_h should < 8"; - } - - for (k = 0; k < K - reg_K_tile; k += reg_K_tile) { - const float *a_ptr = A + (h * stride_a + k); - index_t w; - for (w = 0; w + 3 < width; w += 4) { - const float *b_ptr = B + (k * stride_b + w); - float *c_ptr = C + (h * stride_c + w); - gemm_fn(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); - } - if (w < width) { - const float *b_ptr = B + (k * stride_b + w); - float *c_ptr = C + (h * stride_c + w); - GemmBlock(a_ptr, b_ptr, remain_h, reg_K_tile, width - w, stride_a, - stride_b, stride_c, c_ptr); - } - } - if (k < K) { - const float *a_ptr = A + (h * stride_a + k); - const float *b_ptr = B + k * stride_b; - float *c_ptr = C + h * stride_c; - GemmBlock(a_ptr, b_ptr, remain_h, K - k, width, stride_a, stride_b, - stride_c, c_ptr); - } - } -#else - GemmBlock(A, B, height, K, width, stride_a, stride_b, stride_c, C); -#endif // MACE_ENABLE_NEON -} - -} // namespace - -void Transpose(const float *src, - index_t height, - index_t width, - index_t stride_w, - float *dst) { - index_t tile_size = height > 512 || width > 512 ? 
64 : 32; - for (index_t i = 0; i < height; i += tile_size) { - for (index_t j = 0; j < width; j += tile_size) { - index_t end_i = std::min(i + tile_size, height); - index_t end_j = std::min(j + tile_size, width); - for (index_t tile_i = i; tile_i < end_i; ++tile_i) { - for (index_t tile_j = j; tile_j < end_j; ++tile_j) { - dst[tile_j * height + tile_i] = src[tile_i * stride_w + tile_j]; - } - } - } - } -} - -// A: height x K, B: K x width, C: height x width -void Gemm(const float *A, - const float *B, - const index_t batch, - const index_t height, - const index_t K, - const index_t width, - float *C, - const bool transpose_a, - const bool transpose_b) { - if (width == 1 && !transpose_a) { - for (index_t b = 0; b < batch; ++b) { - Gemv(A + b * height * K, B + b * K, 1, K, height, C + b * height); - } - return; - } - memset(C, 0, sizeof(float) * batch * height * width); - - std::vector block_size_dims {height, width, K}; - index_t thread_count = MaceOpenMPThreadCount; - MACE_CHECK(thread_count >= 1, "thread should be ge 1"); - // TODO(liyin): apply gcd ? - if (height % thread_count == 0) { - block_size_dims[0] = height / thread_count; - } else if (thread_count == 4 && (height & 1) == 0 && (width & 1) == 0) { - block_size_dims[0] = height >> 1; - block_size_dims[1] = width >> 1; - } else if (width % thread_count == 0) { - block_size_dims[1] = width / thread_count; - } else { - if (height >= thread_count) { - block_size_dims[0] = height / thread_count; - } else { - thread_count = std::min(thread_count, height * width); - index_t thread_h = height; - index_t thread_w = RoundUpDiv(thread_count, thread_h); - block_size_dims[0] = 1; - block_size_dims[1] = std::max(static_cast(1), width / thread_w); - } - } - - const index_t block_tile[3] = {height / block_size_dims[0], - width / block_size_dims[1], - K / block_size_dims[2]}; - block_size_dims[0] = height / block_tile[0]; - block_size_dims[1] = width / block_tile[1]; - block_size_dims[2] = K / block_tile[2]; - - const index_t remain[3] = {height % block_tile[0], - width % block_tile[1], - K % block_tile[2]}; - - -#pragma omp parallel for collapse(3) - for (index_t n = 0; n < batch; ++n) { - for (index_t bh = 0; bh < block_tile[0]; ++bh) { - for (index_t bw = 0; bw < block_tile[1]; ++bw) { - const index_t remain_height = remain[0]; - const index_t remain_width = remain[1]; - const index_t remain_k = remain[2]; - - const index_t block_size_height = block_size_dims[0]; - const index_t block_size_width = block_size_dims[1]; - const index_t block_size_k = block_size_dims[2]; - - const index_t this_block_size_height = - block_size_height + (bh < remain_height ? 1 : 0); - const index_t this_block_size_width = - block_size_width + (bw < remain_width ? 1 : 0); - - const float *a_base = A + n * height * K; - const float *b_base = B + n * K * width; - float *c_base = C + n * height * width; - - const index_t ih_begin = - bh * block_size_height + (bh < remain_height ? bh : remain_height); - const index_t - ih_end = std::min(height, ih_begin + this_block_size_height); - const index_t iw_begin = - bw * block_size_width + (bw < remain_width ? bw : remain_width); - const index_t - iw_end = std::min(width, iw_begin + this_block_size_width); - - for (index_t bk = 0; bk < block_tile[2]; ++bk) { - const index_t - this_block_size_k = block_size_k + (bk < remain_k ? 1 : 0); - - const index_t - ik_begin = bk * block_size_k + (bk < remain_k ? 
bk : remain_k); - const index_t ik_end = std::min(K, ik_begin + this_block_size_k); - - Tensor trans_a(GetCPUAllocator(), DataType::DT_FLOAT); - Tensor trans_b(GetCPUAllocator(), DataType::DT_FLOAT); - const float *real_a = nullptr; - const float *real_b = nullptr; - float *real_c = c_base + (ih_begin * width + iw_begin); - index_t stride_a; - index_t stride_b; - index_t stride_c = width; - - if (transpose_a) { - trans_a.Resize({this_block_size_height, this_block_size_k}); - float *trans_a_data = trans_a.mutable_data(); - // A[K, H] -> A[H, K] - Transpose(a_base + (ik_begin * height + ih_begin), - ik_end - ik_begin, ih_end - ih_begin, height, - trans_a_data); - real_a = trans_a_data; - stride_a = ik_end - ik_begin; - } else { - real_a = a_base + (ih_begin * K + ik_begin); - stride_a = K; - } - - if (transpose_b) { - trans_b.Resize({this_block_size_k, this_block_size_width}); - float *trans_b_data = trans_b.mutable_data(); - // B[W, K] -> B[K, W] - Transpose(b_base + (iw_begin * K + ik_begin), iw_end - iw_begin, - ik_end - ik_begin, K, trans_b_data); - real_b = trans_b_data; - stride_b = iw_end - iw_begin; - } else { - real_b = b_base + (ik_begin * width + iw_begin); - stride_b = width; - } - - // inside block: - // calculate C[bh, bw] += A[bh, bk] * B[bk, bw] for one k - GemmTile(real_a, real_b, ih_end - ih_begin, ik_end - ik_begin, - iw_end - iw_begin, stride_a, stride_b, stride_c, real_c); - } // bk - } // bw - } // bh - } // n -} - -// A: height x K, B: K x width, C: height x width -void GemmRef(const float *A, - const float *B, - const index_t batch, - const index_t height, - const index_t K, - const index_t width, - float *C, - const bool transpose_a, - const bool transpose_b) { - memset(C, 0, sizeof(float) * batch * height * width); - - Tensor trans_a(GetCPUAllocator(), DataType::DT_FLOAT); - Tensor trans_b(GetCPUAllocator(), DataType::DT_FLOAT); - float *trans_a_data = nullptr; - float *trans_b_data = nullptr; - if (transpose_a) { - trans_a.Resize({height, K}); - trans_a_data = trans_a.mutable_data(); - } - if (transpose_b) { - trans_b.Resize({K, width}); - trans_b_data = trans_b.mutable_data(); - } - - for (index_t b = 0; b < batch; ++b) { - const float *real_a = nullptr; - const float *real_b = nullptr; - float *real_c = C + b * height * width; - if (transpose_a) { - // A[K, H] -> A[H, K] - Transpose(A + b * height * K, K, height, height, trans_a_data); - real_a = trans_a_data; - } else { - real_a = A + b * height * K; - } - if (transpose_b) { - // B[W, K] -> B[K, W] - Transpose(B + b * width * K, width, K, K, trans_b_data); - real_b = trans_b_data; - } else { - real_b = B + b * width * K; - } - - for (index_t i = 0; i < height; ++i) { - for (index_t j = 0; j < width; ++j) { - for (index_t k = 0; k < K; ++k) { - real_c[i * width + j] += real_a[i * K + k] * real_b[k * width + j]; - } - } - } - } -} - -void GemvRef(const float *m_ptr, - const float *v_ptr, - const index_t batch, - const index_t width, - const index_t height, - float *out_ptr) { - memset(out_ptr, 0, batch * height * sizeof(float)); -#pragma omp parallel for collapse(2) - for (int b = 0; b < batch; ++b) { - for (int h = 0; h < height; ++h) { - for (int w = 0; w < width; ++w) { - out_ptr[b * height + h] += v_ptr[b * width + w] * m_ptr[h * width + w]; - } - } - } -} - -void Gemv(const float *m_ptr, - const float *v_ptr, - const index_t batch, - const index_t width, - const index_t height, - float *out_ptr) { -#if defined(MACE_ENABLE_NEON) - -#pragma omp parallel for collapse(2) - for (index_t b = 0; b < batch; ++b) { - 
for (index_t h = 0; h < height; ++h) { - const float *m_ptr0 = m_ptr + h * width; - const float *v_ptr0 = v_ptr + b * width; - float *out_ptr0 = out_ptr + b * height + h; - - float32x4_t vm0, vm1, vm2, vm3; - float32x4_t vv0, vv1, vv2, vv3; - float32x4_t vsum0 = vdupq_n_f32(0.f); - float32x4_t vsum1 = vdupq_n_f32(0.f); - float32x4_t vsum2 = vdupq_n_f32(0.f); - float32x4_t vsum3 = vdupq_n_f32(0.f); - - index_t w; - for (w = 0; w + 15 < width; w += 16) { - vm0 = vld1q_f32(m_ptr0); - vv0 = vld1q_f32(v_ptr0); - vm1 = vld1q_f32(m_ptr0 + 4); - vv1 = vld1q_f32(v_ptr0 + 4); - vm2 = vld1q_f32(m_ptr0 + 8); - vv2 = vld1q_f32(v_ptr0 + 8); - vm3 = vld1q_f32(m_ptr0 + 12); - vv3 = vld1q_f32(v_ptr0 + 12); - - vsum0 = vmlaq_f32(vsum0, vm0, vv0); - vsum1 = vmlaq_f32(vsum1, vm1, vv1); - vsum2 = vmlaq_f32(vsum2, vm2, vv2); - vsum3 = vmlaq_f32(vsum3, vm3, vv3); - - m_ptr0 += 16; - v_ptr0 += 16; - } - - for (; w + 7 < width; w += 8) { - vm0 = vld1q_f32(m_ptr0); - vv0 = vld1q_f32(v_ptr0); - vm1 = vld1q_f32(m_ptr0 + 4); - vv1 = vld1q_f32(v_ptr0 + 4); - - vsum0 = vmlaq_f32(vsum0, vm0, vv0); - vsum1 = vmlaq_f32(vsum1, vm1, vv1); - - m_ptr0 += 8; - v_ptr0 += 8; - } - - for (; w + 3 < width; w += 4) { - vm0 = vld1q_f32(m_ptr0); - vv0 = vld1q_f32(v_ptr0); - vsum0 = vmlaq_f32(vsum0, vm0, vv0); - - m_ptr0 += 4; - v_ptr0 += 4; - } - vsum0 += vsum1; - vsum2 += vsum3; - vsum0 += vsum2; - float sum0 = vaddvq_f32(vsum0); - - // handle remaining w - for (; w < width; ++w) { - sum0 += m_ptr0[0] * v_ptr0[0]; - m_ptr0++; - v_ptr0++; - } - *out_ptr0++ = sum0; - } // h - } // b -#else - GemvRef(m_ptr, v_ptr, batch, width, height, out_ptr); -#endif -} - -} // namespace ops -} // namespace mace diff --git a/mace/ops/gemm.h b/mace/ops/gemm.h deleted file mode 100644 index ecd228e91f22bac4e6f090879a0080b30ac73a7f..0000000000000000000000000000000000000000 --- a/mace/ops/gemm.h +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2018 The MACE Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef MACE_OPS_GEMM_H_ -#define MACE_OPS_GEMM_H_ - -#if defined(MACE_ENABLE_NEON) && defined(__aarch64__) -#include -#endif - -#include "mace/core/types.h" - -// Gemm function does fast matrix-matrix multiplications with batch. -// Gemv function does fast matrix-vector multiplications with batch. - -namespace mace { -namespace ops { - -// Gemm calculates A[batch, height, K] dot B[batch, K, width] within each batch, -// and output to C[batch, height, width]. 
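
// Concretely, the batched contract documented here reads as follows (usage
// sketch of the removed API; sizes are made up):
//
//   float A[2 * 3 * 4], B[2 * 4 * 5], C[2 * 3 * 5];
//   ops::Gemm(A, B, /*batch=*/2, /*height=*/3, /*K=*/4, /*width=*/5, C);
//   // C[b] = A[b] . B[b] for b = 0, 1; callers now reach the same math
//   // through SGemm (matmul.cc) and the new Gemv kernels instead.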
-// height, K, width correspond to matrix dimension size after transpose (if any) -void Gemm(const float *A, - const float *B, - const index_t batch, - const index_t height, - const index_t K, - const index_t width, - float *C, - const bool transpose_a = false, - const bool transpose_b = false); - -void GemmRef(const float *A, - const float *B, - const index_t batch, - const index_t height, - const index_t K, - const index_t width, - float *C, - const bool transpose_a = false, - const bool transpose_b = false); - -// Gemm calculates M[height, width] dot V[batch, height] within each batch of V, -// and output to out[batch, width]. -void Gemv(const float *m_ptr, - const float *v_ptr, - const index_t batch, - const index_t width, - const index_t height, - float *out_ptr); - -void GemvRef(const float *m_ptr, - const float *v_ptr, - const index_t batch, - const index_t width, - const index_t height, - float *out_ptr); - -void Transpose(const float *src, - index_t height, - index_t width, - index_t stride_w, - float *dst); - -} // namespace ops -} // namespace mace - -#endif // MACE_OPS_GEMM_H_ diff --git a/mace/ops/gemm_test.cc b/mace/ops/gemm_test.cc deleted file mode 100644 index 9b2adefb59406966ef0cd1876c1905a0f8e168f4..0000000000000000000000000000000000000000 --- a/mace/ops/gemm_test.cc +++ /dev/null @@ -1,208 +0,0 @@ -// Copyright 2018 The MACE Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-
-#include <gtest/gtest.h>
-#include <algorithm>
-#include <memory>
-#include <random>
-
-#include "mace/core/types.h"
-#include "mace/ops/gemm.h"
-#include "mace/ops/sgemm.h"
-
-namespace mace {
-
-namespace {
-
-void GemmTest(index_t batch,
-              index_t N,
-              index_t K,
-              index_t M,
-              bool transpose_a,
-              bool transpose_b) {
-  std::unique_ptr<float[]> A(new float[batch * N * K]);
-  std::unique_ptr<float[]> B(new float[batch * K * M]);
-  std::unique_ptr<float[]> C(new float[batch * N * M]);
-  std::unique_ptr<float[]> C_ref(new float[batch * N * M]);
-
-  std::random_device rd;
-  std::mt19937 gen(rd());
-  std::normal_distribution<float> nd(0, 1);
-
-  std::generate(A.get(), A.get() + batch * N * K,
-                [&gen, &nd] { return nd(gen); });
-  std::generate(B.get(), B.get() + batch * K * M,
-                [&gen, &nd] { return nd(gen); });
-  ops::Gemm(A.get(), B.get(), batch, N, K, M, C.get(), transpose_a,
-            transpose_b);
-  ops::GemmRef(A.get(), B.get(), batch, N, K, M, C_ref.get(), transpose_a,
-               transpose_b);
-
-  for (int i = 0; i < batch * N * M; ++i) {
-    EXPECT_NEAR(C_ref[i], C[i], 0.1);
-  }
-}
-
-void GemvTest(index_t batch, index_t N, index_t M) {
-  std::unique_ptr<float[]> A(new float[N * M]);
-  std::unique_ptr<float[]> B(new float[batch * M]);
-  std::unique_ptr<float[]> C(new float[batch * N]);
-  std::unique_ptr<float[]> C_ref(new float[batch * N]);
-
-  std::random_device rd;
-  std::mt19937 gen(rd());
-  std::normal_distribution<float> nd(0, 1);
-
-  std::generate(A.get(), A.get() + N * M, [&gen, &nd] { return nd(gen); });
-  std::generate(B.get(), B.get() + batch * M, [&gen, &nd] { return nd(gen); });
-  ops::Gemv(A.get(), B.get(), batch, M, N, C.get());
-  ops::GemvRef(A.get(), B.get(), batch, M, N, C_ref.get());
-
-  for (int i = 0; i < batch * N; ++i) {
-    EXPECT_NEAR(C_ref[i], C[i], 0.1);
-  }
-}
-
-void SGemmTest(index_t batch,
-               index_t N,
-               index_t K,
-               index_t M,
-               bool transpose_a,
-               bool transpose_b) {
-  std::unique_ptr<float[]> A(new float[batch * N * K]);
-  std::unique_ptr<float[]> B(new float[batch * K * M]);
-  std::unique_ptr<float[]> C(new float[batch * N * M]);
-  std::unique_ptr<float[]> C_ref(new float[batch * N * M]);
-
-  std::random_device rd;
-  std::mt19937 gen(rd());
-  std::normal_distribution<float> nd(0, 1);
-
-  std::generate(A.get(), A.get() + batch * N * K,
-                [&gen, &nd] { return nd(gen); });
-  std::generate(B.get(), B.get() + batch * K * M,
-                [&gen, &nd] { return nd(gen); });
-  ops::GemmRef(A.get(), B.get(), batch, N, K, M, C_ref.get(), transpose_a,
-               transpose_b);
-
-  ops::MatrixMap<const float> matrix_a;
-  ops::MatrixMap<const float> matrix_b;
-
-  if (!transpose_a) {
-    matrix_a =
-        ops::MatrixMap<const float>(batch,
-                                    N,
-                                    K,
-                                    ops::RowMajor,
-                                    A.get());
-  } else {
-    matrix_a =
-        ops::MatrixMap<const float>(batch,
-                                    K,
-                                    N,
-                                    ops::RowMajor,
-                                    A.get());
-    matrix_a = matrix_a.transpose();
-  }
-
-  if (!transpose_b) {
-    matrix_b =
-        ops::MatrixMap<const float>(batch,
-                                    K,
-                                    M,
-                                    ops::RowMajor,
-                                    B.get());
-  } else {
-    matrix_b =
-        ops::MatrixMap<const float>(batch,
-                                    M,
-                                    K,
-                                    ops::RowMajor,
-                                    B.get());
-    matrix_b = matrix_b.transpose();
-  }
-  ops::MatrixMap<float> matrix_c(batch, N, M, ops::RowMajor, C.get());
-
-  ops::SGemm sgemm;
-  sgemm(matrix_a, matrix_b, &matrix_c);
-
-  for (int i = 0; i < N * M; ++i) {
-    EXPECT_NEAR(C_ref[i], C[i], 0.1);
-  }
-}
-
-}  // namespace
-
-TEST(GEMMTest, AlignedWithoutBatch) {
-  GemmTest(1, 1, 64, 128, false, false);
-  GemmTest(1, 2, 64, 128, false, true);
-  GemmTest(1, 3, 64, 128, true, false);
-  GemmTest(1, 4, 64, 128, true, true);
-  GemmTest(1, 5, 64, 128, false, false);
-  GemmTest(1, 6, 64, 128, false, true);
-  GemmTest(1, 7, 64, 128, true, false);
-  GemmTest(1, 17, 64, 128, true, true);
-  GemmTest(1, 256, 128, 4096, false, false);
-  GemmTest(1, 256, 128, 4104, false, false);
-}
-
-TEST(GEMMTest, UnalignedWithoutBatch) {
-  GemmTest(1, 1, 63, 127, false, false);
-  GemmTest(1, 2, 63, 127, false, true);
-  GemmTest(1, 3, 63, 127, true, false);
-  GemmTest(1, 4, 63, 127, true, true);
-  GemmTest(1, 5, 63, 127, false, false);
-  GemmTest(1, 6, 63, 127, false, true);
-  GemmTest(1, 7, 63, 127, true, false);
-  GemmTest(1, 17, 63, 127, true, true);
-}
-
-TEST(GEMMTest, UnalignedWithBatch) {
-  GemmTest(3, 1, 63, 127, false, false);
-  GemmTest(3, 2, 63, 127, false, true);
-  GemmTest(3, 3, 63, 127, true, false);
-  GemmTest(3, 4, 63, 127, true, true);
-  GemmTest(3, 5, 63, 127, false, false);
-  GemmTest(3, 6, 63, 127, false, true);
-  GemmTest(3, 7, 63, 127, true, false);
-  GemmTest(3, 17, 63, 127, true, true);
-}
-
-TEST(GEMMTest, gemv) {
-  GemvTest(1, 17, 63);
-  GemvTest(3, 17, 63);
-}
-
-namespace {
-void TestSGemmTranspose(index_t batch, index_t N, index_t K, index_t M) {
-  SGemmTest(batch, N, K, M, false, false);
-  SGemmTest(batch, N, K, M, true, false);
-  SGemmTest(batch, N, K, M, false, true);
-  SGemmTest(batch, N, K, M, true, true);
-}
-}
-
-TEST(SGEMMTest, UnalignedWithoutBatch) {
-  std::vector<index_t> tests{1, 5, 14, 31, 47};
-  for (index_t N : tests) {
-    for (index_t K : tests) {
-      for (index_t M : tests) {
-        TestSGemmTranspose(1, N, K, M);
-        TestSGemmTranspose(16, N, K, M);
-      }
-    }
-  }
-}
-
-}  // namespace mace
diff --git a/mace/ops/infer_conv2d_shape.cc b/mace/ops/infer_conv2d_shape.cc
index f3fe1e073f98f33d73a5731ca1df39f4311d7d4b..cd0d96b8cef49abad2e97cd60a81619065d51ebb 100644
--- a/mace/ops/infer_conv2d_shape.cc
+++ b/mace/ops/infer_conv2d_shape.cc
@@ -14,7 +14,7 @@
 
 #include "mace/core/operator.h"
-#include "mace/ops/conv_pool_2d_util.h"
+#include "mace/ops/common/conv_pool_2d_util.h"
 
 namespace mace {
 namespace ops {
diff --git a/mace/ops/infer_conv2d_shape_test.cc b/mace/ops/infer_conv2d_shape_test.cc
index 2b39a7ef49e8bbb6d5a5600ee81c76d582a793ce..feaaecff8364d9f1a3270105bc03ddb36e3f5be2 100644
--- a/mace/ops/infer_conv2d_shape_test.cc
+++ b/mace/ops/infer_conv2d_shape_test.cc
@@ -13,7 +13,7 @@
 // limitations under the License.
#include "mace/ops/ops_test_util.h" -#include "mace/ops/conv_pool_2d_util.h" +#include "mace/ops/common/conv_pool_2d_util.h" namespace mace { namespace ops { diff --git a/mace/ops/matmul.cc b/mace/ops/matmul.cc index 23db4e5208c20243b404a034dc3aba1ae58d903f..7ae79569d26e8655e5bae4c6aab8e11a6b87f207 100644 --- a/mace/ops/matmul.cc +++ b/mace/ops/matmul.cc @@ -21,13 +21,23 @@ #include "mace/core/operator.h" #include "mace/core/tensor.h" -#include "mace/ops/gemm.h" #include "mace/ops/sgemm.h" #include "mace/utils/utils.h" +#ifdef MACE_ENABLE_NEON + +#include "mace/ops/arm/fp32/gemv.h" + +#ifdef MACE_ENABLE_QUANTIZE +#include "mace/ops/arm/q8/gemv.h" +#endif // MACE_ENABLE_QUANTIZE + +#else +#include "mace/ops/ref/gemv.h" +#endif // MACE_ENABLE_NEON + #ifdef MACE_ENABLE_QUANTIZE #include "mace/ops/gemmlowp_util.h" -#include "mace/ops/arm/fixpoint_gemm.h" #endif // MACE_ENABLE_QUANTIZE #ifdef MACE_ENABLE_OPENCL @@ -106,7 +116,6 @@ class MatMulOp : public MatMulOpBase { } batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1, std::multiplies()); - std::vector c_shape = A->shape(); c_shape[rank - 2] = height; c_shape[rank - 1] = width; @@ -141,11 +150,17 @@ class MatMulOp : public MatMulOpBase { B->is_weight(), c_ptr_base, context->device()->scratch_buffer()); + return MaceStatus::MACE_SUCCESS; } private: SGemm sgemm_; +#ifdef MACE_ENABLE_NEON + arm::fp32::Gemv gemv_; +#else + ref::Gemv gemv_; +#endif // MACE_ENABLE_NEON }; #ifdef MACE_ENABLE_QUANTIZE @@ -163,38 +178,54 @@ class MatMulFixpointImpl { const index_t K, const index_t width, Tensor *C) { - Tensor::MappingGuard guarda(A); - Tensor::MappingGuard guardb(B); - Tensor::MappingGuard guardc(C); - auto a_ptr_base = A->data(); - auto b_ptr_base = B->data(); - auto c_ptr_base = C->mutable_data(); index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1, std::multiplies()); - auto gemm_context = context->device()->cpu_runtime()->GetGemmlowpContext(); - MACE_CHECK_NOTNULL(gemm_context); - - index_t a_size = height * K; - index_t b_size = K * width; - index_t c_size = height * width; - - const auto &output_pipeline = GemmlowpOutputPipeline::MakeNoBias( - A->scale(), B->scale(), C->scale(), C->zero_point()); - - for (index_t i = 0; i < batch; ++i) { - gemmlowp::MatrixMap - a_matrix(a_ptr_base + i * a_size, height, K); - gemmlowp::MatrixMap - b_matrix(b_ptr_base + i * b_size, K, width); - gemmlowp::MatrixMap - c_matrix(c_ptr_base + i * c_size, height, width); - - using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams; - gemmlowp::GemmWithOutputPipeline( - gemm_context, a_matrix, b_matrix, &c_matrix, -A->zero_point(), - -B->zero_point(), output_pipeline); + +#if defined(MACE_ENABLE_NEON) + if (width == 1 && AOrder == gemmlowp::MapOrder::RowMajor) { + gemv_kernel_.Compute(context, A, B, nullptr, batch, height, K, true, C); + } else if (height == 1 && BOrder == gemmlowp::MapOrder::ColMajor) { + gemv_kernel_.Compute(context, B, A, nullptr, batch, width, K, true, C); + } else { +#endif // MACE_ENABLE_NEON + Tensor::MappingGuard guarda(A); + Tensor::MappingGuard guardb(B); + Tensor::MappingGuard guardc(C); + auto a_ptr_base = A->data(); + auto b_ptr_base = B->data(); + auto c_ptr_base = C->mutable_data(); + + auto gemm_context = + context->device()->cpu_runtime()->GetGemmlowpContext(); + MACE_CHECK_NOTNULL(gemm_context); + + index_t a_size = height * K; + index_t b_size = K * width; + index_t c_size = height * width; + + const auto &output_pipeline = GemmlowpOutputPipeline::MakeNoBias( + A->scale(), B->scale(), 
C->scale(), C->zero_point()); + + for (index_t i = 0; i < batch; ++i) { + gemmlowp::MatrixMap + a_matrix(a_ptr_base + i * a_size, height, K); + gemmlowp::MatrixMap + b_matrix(b_ptr_base + i * b_size, K, width); + gemmlowp::MatrixMap + c_matrix(c_ptr_base + i * c_size, height, width); + + using BitDepthParams = gemmlowp::L8R8WithLhsNonzeroBitDepthParams; + gemmlowp::GemmWithOutputPipeline( + gemm_context, a_matrix, b_matrix, &c_matrix, -A->zero_point(), + -B->zero_point(), output_pipeline); + } } +#if defined(MACE_ENABLE_NEON) } + + private: + arm::q8::Gemv gemv_kernel_; +#endif // MACE_ENABLE_NEON }; template @@ -207,38 +238,24 @@ class MatMulFixpointImpl { const index_t K, const index_t width, Tensor *C) { - Tensor::MappingGuard guarda(A); - Tensor::MappingGuard guardb(B); - Tensor::MappingGuard guardc(C); - auto a_ptr_base = A->data(); - auto b_ptr_base = B->data(); - auto c_ptr_base = C->mutable_data(); + C->SetScale(A->scale() * B->scale()); + C->SetZeroPoint(0); index_t batch = std::accumulate(A->shape().begin(), A->shape().end() - 2, 1, std::multiplies()); +#if defined(MACE_ENABLE_NEON) if (width == 1 && AOrder == gemmlowp::MapOrder::RowMajor) { - // gemv - for (index_t i = 0; i < batch; ++i) { - FixPointGemv(a_ptr_base + i * height * K, - b_ptr_base + i * K, - A->zero_point(), - B->zero_point(), - height, - K, - c_ptr_base + i * height); - } + gemv_kernel_.Compute(context, A, B, nullptr, batch, height, K, true, C); } else if (height == 1 && BOrder == gemmlowp::MapOrder::ColMajor) { - // gevm - for (index_t i = 0; i < batch; ++i) { - FixPointGemv(b_ptr_base + i * K * width, - a_ptr_base + i * K, - B->zero_point(), - A->zero_point(), - width, - K, - c_ptr_base + i * width); - } + gemv_kernel_.Compute(context, B, A, nullptr, batch, width, K, true, C); } else { +#endif // MACE_ENABLE_NEON + Tensor::MappingGuard guarda(A); + Tensor::MappingGuard guardb(B); + Tensor::MappingGuard guardc(C); + auto a_ptr_base = A->data(); + auto b_ptr_base = B->data(); + auto c_ptr_base = C->mutable_data(); auto gemm_context = context->device()->cpu_runtime()->GetGemmlowpContext(); MACE_CHECK_NOTNULL(gemm_context); @@ -264,9 +281,12 @@ class MatMulFixpointImpl { } } - C->SetScale(A->scale() * B->scale()); - C->SetZeroPoint(0); +#if defined(MACE_ENABLE_NEON) } + + private: + arm::q8::Gemv gemv_kernel_; +#endif // MACE_ENABLE_NEON }; template <> diff --git a/mace/ops/matmul_benchmark.cc b/mace/ops/matmul_benchmark.cc index eac29c5ff0b50f9834e8ee7ef42227c87548ba07..07a51ebf8a1114b37164412deff14eb1c99e8694 100644 --- a/mace/ops/matmul_benchmark.cc +++ b/mace/ops/matmul_benchmark.cc @@ -21,7 +21,6 @@ #include "public/gemmlowp.h" #include "mace/benchmark/statistics.h" #include "mace/core/testing/test_benchmark.h" -#include "mace/ops/gemm.h" #include "mace/ops/sgemm.h" #include "mace/ops/ops_test_util.h" @@ -96,19 +95,6 @@ namespace test { namespace { // Matmul with (m, k) x (k, n) -void MatmulBenchmark_Mace(int iters, int m, int k, int n) { - mace::testing::StopTiming(); - std::vector lhs(m * k); - std::vector rhs(k * n); - std::vector result(m * n); - // warm up - Gemm(lhs.data(), rhs.data(), 1, m, k, n, result.data()); - mace::testing::StartTiming(); - while (iters--) { - Gemm(lhs.data(), rhs.data(), 1, m, k, n, result.data()); - } -} - void MatmulBenchmark_Mace_SGemm(int iters, int m, int k, int n) { mace::testing::StopTiming(); std::vector lhs(m * k); @@ -234,7 +220,6 @@ void MatmulBenchmark_gemmlowp_int32(int iters, int rows, int depth, int cols) { MACE_BENCHMARK(MACE_BM_MATMUL_##M##_##K##_##N##_##FUNC) 
 
 #define MACE_BM_MATMUL(M, K, N) \
-  MACE_BM_MATMUL_FUNC(M, K, N, Mace, float); \
   MACE_BM_MATMUL_FUNC(M, K, N, Mace_SGemm, float); \
   MACE_BM_MATMUL_FUNC(M, K, N, Eigen, float); \
   MACE_BM_MATMUL_FUNC(M, K, N, gemmlowp_uint8, uint8_t); \
@@ -307,6 +292,7 @@ void MatMulBenchmark(
         .Input("A")
         .Input("B")
         .Output("Output")
+        .OutputType({DT_INT32})
         .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
         .Finalize(net.NewOperatorDef());
 
@@ -408,7 +394,7 @@ void MatMulTransposeBenchmark(
   MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, float, CPU); \
   MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, uint8_t, CPU);
 
-MACE_BM_MATMUL_OP(1, 128, 128, 49);
+MACE_BM_MATMUL_OP(1, 30000, 256, 1);
 MACE_BM_MATMUL_OP(2, 128, 128, 49);
 MACE_BM_MATMUL_OP(3, 128, 128, 49);
 MACE_BM_MATMUL_OP(4, 128, 128, 49);
diff --git a/mace/ops/matmul_test.cc b/mace/ops/matmul_test.cc
index 35797052c36d693de1c837b904986d097ad85457..dc15485eee4441bbe650507a825e525b4bc17ba5 100644
--- a/mace/ops/matmul_test.cc
+++ b/mace/ops/matmul_test.cc
@@ -23,7 +23,7 @@ namespace test {
 
 class MatMulOpTest : public OpsTestBase {};
 
 namespace {
-template <DeviceType D>
+template <DeviceType D, typename T = float>
 void Simple(const std::vector<index_t> &A_shape,
             const std::vector<float> &A_value,
             const std::vector<index_t> &B_shape,
@@ -55,12 +55,12 @@ TEST_F(MatMulOpTest, SimpleCPU) {
   Simple<DeviceType::CPU>({1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 3, 2},
                           {1, 2, 3, 4, 5, 6}, {1, 2, 2}, {22, 28, 49, 64});
   Simple<DeviceType::CPU>(
-      {1, 5, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 
+      {1, 5, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
                  14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
-      {1, 5, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 
+      {1, 5, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
                  14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25},
-      {1, 5, 5}, {215, 230, 245, 260, 275, 490, 530, 570, 610, 
-                 650, 765, 830, 895, 960, 1025, 1040, 1130, 1220, 
+      {1, 5, 5}, {215, 230, 245, 260, 275, 490, 530, 570, 610,
+                 650, 765, 830, 895, 960, 1025, 1040, 1130, 1220,
                  1310, 1400, 1315, 1430, 1545, 1660, 1775});
 }
 
@@ -289,9 +289,6 @@ TEST_F(MatMulOpTest, QuantOutputInt32) {
   QuantOutputInt32({2}, 253, 300, 1, false, false);
 }
 
-// TODO(liyin): test transpose after implementing gpu runtime
-// now transpose test is in kernels_test
-
 }  // namespace test
 }  // namespace ops
 }  // namespace mace
diff --git a/mace/ops/opencl/buffer/conv_2d_1x1.cc b/mace/ops/opencl/buffer/conv_2d_1x1.cc
index 49bfb488e81d623760a05e164f357a8b44e73c86..bfe6775e91b0bf673365e2db4b634a57e10029bc 100644
--- a/mace/ops/opencl/buffer/conv_2d_1x1.cc
+++ b/mace/ops/opencl/buffer/conv_2d_1x1.cc
@@ -14,7 +14,7 @@
 
 #include "mace/core/op_context.h"
 #include "mace/core/runtime/opencl/opencl_runtime.h"
-#include "mace/ops/activation.h"
+#include "mace/ops/common/activation_type.h"
 #include "mace/ops/opencl/helper.h"
 
 namespace mace {
diff --git a/mace/ops/opencl/buffer/conv_2d_general.cc b/mace/ops/opencl/buffer/conv_2d_general.cc
index 1c066da30e92832a599236a0822106a3e1657811..f2090a1bb6d5d69b89a14bedb9118470c59c8c01 100644
--- a/mace/ops/opencl/buffer/conv_2d_general.cc
+++ b/mace/ops/opencl/buffer/conv_2d_general.cc
@@ -14,7 +14,7 @@
 
 #include "mace/core/op_context.h"
 #include "mace/core/runtime/opencl/opencl_runtime.h"
-#include "mace/ops/activation.h"
+#include "mace/ops/common/activation_type.h"
 #include "mace/ops/opencl/helper.h"
 
 namespace mace {
diff --git a/mace/ops/opencl/buffer_transformer.h b/mace/ops/opencl/buffer_transformer.h
index ab702b2340fddda0e9d932c054d14e2cbe601a8b..acefd6abdec2e3cfda6b9a25c13a64f2ed87e7b0 100644
--- a/mace/ops/opencl/buffer_transformer.h
+++ b/mace/ops/opencl/buffer_transformer.h
@@ -23,7 +23,7 @@
 #include "mace/ops/opencl/image/buffer_to_image.h"
 #include "mace/ops/opencl/image/image_to_buffer.h"
 #include "mace/ops/opencl/buffer/buffer_transform.h"
-#include "mace/ops/transpose.h"
+#include "mace/ops/common/transpose.h"
 
 namespace mace {
 namespace ops {
diff --git a/mace/ops/opencl/conv_2d.h b/mace/ops/opencl/conv_2d.h
index ab8f876b552fdc43740d1f36570cf1cbcd517a6d..a9ec131d18ef898cb493f4f7ba0bc73fcacc7f07 100644
--- a/mace/ops/opencl/conv_2d.h
+++ b/mace/ops/opencl/conv_2d.h
@@ -18,7 +18,7 @@
 #include <vector>
 
 #include "mace/ops/activation.h"
-#include "mace/ops/conv_pool_2d_util.h"
+#include "mace/ops/common/conv_pool_2d_util.h"
 
 namespace mace {
 class OpContext;
diff --git a/mace/ops/opencl/depthwise_conv2d.h b/mace/ops/opencl/depthwise_conv2d.h
index 5e17ff0759538d70cb22ae234ffbd27cb1d18b2c..98f97a2016eff313494beb79656fbdedfa15c5d4 100644
--- a/mace/ops/opencl/depthwise_conv2d.h
+++ b/mace/ops/opencl/depthwise_conv2d.h
@@ -17,8 +17,8 @@
 
 #include <vector>
 
-#include "mace/ops/activation.h"
-#include "mace/ops/conv_pool_2d_util.h"
+#include "mace/ops/common/activation_type.h"
+#include "mace/ops/common/conv_pool_2d_util.h"
 
 namespace mace {
diff --git a/mace/ops/opencl/depthwise_deconv2d.h b/mace/ops/opencl/depthwise_deconv2d.h
index 700e8d7ca524f7ab75fc9ffa206ac74ca259ec71..b2460fcda74e67ff33c9e3dee10ba53dc840fff4 100644
--- a/mace/ops/opencl/depthwise_deconv2d.h
+++ b/mace/ops/opencl/depthwise_deconv2d.h
@@ -18,7 +18,7 @@
 #include <memory>
 #include <vector>
 
-#include "mace/ops/activation.h"
+#include "mace/ops/common/activation_type.h"
 
 namespace mace {
diff --git a/mace/ops/opencl/image/activation.h b/mace/ops/opencl/image/activation.h
index afddc774b654b073f4764161041a0c348040be0a..6f7c573cec0c3016ac247e095d6148da158e3301 100644
--- a/mace/ops/opencl/image/activation.h
+++ b/mace/ops/opencl/image/activation.h
@@ -23,7 +23,7 @@
 
 #include "mace/core/op_context.h"
 #include "mace/core/tensor.h"
-#include "mace/ops/activation.h"
+#include "mace/ops/common/activation_type.h"
 #include "mace/ops/opencl/helper.h"
 
 namespace mace {
diff --git a/mace/ops/opencl/image/conv_2d_1x1.cc b/mace/ops/opencl/image/conv_2d_1x1.cc
index c13416a591f164898115439df4e178cd091896ec..374d262ae34a4938e40f94dd941e95735bcedd4e 100644
--- a/mace/ops/opencl/image/conv_2d_1x1.cc
+++ b/mace/ops/opencl/image/conv_2d_1x1.cc
@@ -14,7 +14,7 @@
 
 #include "mace/core/op_context.h"
 #include "mace/core/runtime/opencl/opencl_runtime.h"
-#include "mace/ops/activation.h"
+#include "mace/ops/common/activation_type.h"
 #include "mace/ops/opencl/helper.h"
 
 namespace mace {
diff --git a/mace/ops/opencl/image/conv_2d_3x3.cc b/mace/ops/opencl/image/conv_2d_3x3.cc
index 62a3ba023ca127de3ab3c7f3327e770f5ab2eb28..db63300eb7607dead1cc9661533e0e7d463e5e4b 100644
--- a/mace/ops/opencl/image/conv_2d_3x3.cc
+++ b/mace/ops/opencl/image/conv_2d_3x3.cc
@@ -14,7 +14,7 @@
 
 #include "mace/core/op_context.h"
 #include "mace/core/runtime/opencl/opencl_runtime.h"
-#include "mace/ops/activation.h"
+#include "mace/ops/common/activation_type.h"
 #include "mace/ops/opencl/helper.h"
 #include "mace/utils/utils.h"
 
diff --git a/mace/ops/opencl/image/conv_2d_general.cc b/mace/ops/opencl/image/conv_2d_general.cc
index c68faf0140f82f11db933efe9b245d981c81835f..08568a5d9e39d671a2e3d84de8fc1fa22c588f95 100644
--- a/mace/ops/opencl/image/conv_2d_general.cc
+++ b/mace/ops/opencl/image/conv_2d_general.cc
@@ -15,7 +15,7 @@
 #include "mace/core/op_context.h"
 #include "mace/core/runtime/opencl/opencl_runtime.h"
 #include "mace/ops/opencl/helper.h"
-#include "mace/ops/activation.h"
+#include "mace/ops/common/activation_type.h"
"mace/ops/common/activation_type.h" #include "mace/utils/utils.h" namespace mace { diff --git a/mace/ops/opencl/image/winograd_conv2d.cc b/mace/ops/opencl/image/winograd_conv2d.cc index a2fd811ff4363e73606488511c40b08a165e19a0..527d6cc87f0b8e5023100a9d403f363d66db5871 100644 --- a/mace/ops/opencl/image/winograd_conv2d.cc +++ b/mace/ops/opencl/image/winograd_conv2d.cc @@ -14,8 +14,8 @@ #include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/op_context.h" -#include "mace/ops/activation.h" -#include "mace/ops/conv_pool_2d_util.h" +#include "mace/ops/common/activation_type.h" +#include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/opencl/helper.h" #include "mace/utils/utils.h" diff --git a/mace/ops/opencl/pooling.h b/mace/ops/opencl/pooling.h index 411efcbbca3d7e227e7413f9c09f0b08fa14fbf6..78628593f98209b7ab2ec3898e24bf370f573268 100644 --- a/mace/ops/opencl/pooling.h +++ b/mace/ops/opencl/pooling.h @@ -18,7 +18,7 @@ #include #include "mace/ops/pooling.h" -#include "mace/ops/conv_pool_2d_util.h" +#include "mace/ops/common/conv_pool_2d_util.h" namespace mace { diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h index ed0ae8f291b597b294b5147e366435f34b9dfed1..07cbad06bdb57381ca3befada4baf1e1f11b5bed 100644 --- a/mace/ops/ops_test_util.h +++ b/mace/ops/ops_test_util.h @@ -36,6 +36,7 @@ #include "mace/public/mace.h" #include "mace/utils/utils.h" #include "mace/utils/quantize.h" +#include "mace/ops/testing/test_utils.h" namespace mace { namespace ops { @@ -422,230 +423,6 @@ class OpsTestBase : public ::testing::Test { } }; -template -void GenerateRandomRealTypeData(const std::vector &shape, - std::vector *res, - bool positive = true) { - MACE_CHECK_NOTNULL(res); - - std::random_device rd; - std::mt19937 gen(rd()); - std::normal_distribution nd(0, 1); - - index_t size = std::accumulate(shape.begin(), shape.end(), 1, - std::multiplies()); - res->resize(size); - - if (DataTypeToEnum::value == DT_HALF) { - std::generate(res->begin(), res->end(), [&gen, &nd, positive] { - return half_float::half_cast(positive ? std::abs(nd(gen)) - : nd(gen)); - }); - } else { - std::generate(res->begin(), res->end(), [&gen, &nd, positive] { - return positive ? 
std::abs(nd(gen)) : nd(gen); - }); - } -} - -template -void GenerateRandomIntTypeData(const std::vector &shape, - std::vector *res, - const T a = 0, - const T b = std::numeric_limits::max()) { - MACE_CHECK_NOTNULL(res); - - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> nd(a, b); - - index_t size = std::accumulate(shape.begin(), shape.end(), 1, - std::multiplies()); - res->resize(size); - - std::generate(res->begin(), res->end(), [&gen, &nd] { return nd(gen); }); -} - -template -std::vector VectorStaticCast(const std::vector &&src) { - std::vector dest; - dest.reserve(src.size()); - for (float f : src) { - dest.push_back(static_cast(f)); - } - return std::move(dest); -} - -inline bool IsSameSize(const Tensor &x, const Tensor &y) { - if (x.dim_size() != y.dim_size()) return false; - for (int d = 0; d < x.dim_size(); ++d) { - if (x.dim(d) != y.dim(d)) return false; - } - return true; -} - -inline std::string ShapeToString(const Tensor &x) { - std::stringstream stream; - for (int i = 0; i < x.dim_size(); i++) { - if (i > 0) stream << ","; - int64_t dim = x.dim(i); - if (dim < 0) { - stream << "?"; - } else { - stream << dim; - } - } - stream << "]"; - return std::string(stream.str()); -} - -template -struct is_floating_point_type { - static const bool value = std::is_same::value || - std::is_same::value || - std::is_same::value; -}; - -template -inline void ExpectEqual(const T &a, const T &b) { - EXPECT_EQ(a, b); -} - -template <> -inline void ExpectEqual(const float &a, const float &b) { - EXPECT_FLOAT_EQ(a, b); -} - -template <> -inline void ExpectEqual(const double &a, const double &b) { - EXPECT_DOUBLE_EQ(a, b); -} - -inline void AssertSameDims(const Tensor &x, const Tensor &y) { - ASSERT_TRUE(IsSameSize(x, y)) << "x.shape [" << ShapeToString(x) << "] vs " - << "y.shape [ " << ShapeToString(y) << "]"; -} - -template ::value> -struct Expector; - -// Partial specialization for float and double. 
-template <typename EXP_TYPE, typename RES_TYPE>
-struct Expector<EXP_TYPE, RES_TYPE, true> {
-  static void Equal(const EXP_TYPE &a, const RES_TYPE &b) { ExpectEqual(a, b); }
-
-  static void Equal(const Tensor &x, const Tensor &y) {
-    ASSERT_EQ(x.dtype(), DataTypeToEnum<EXP_TYPE>::v());
-    ASSERT_EQ(y.dtype(), DataTypeToEnum<RES_TYPE>::v());
-    AssertSameDims(x, y);
-    Tensor::MappingGuard x_mapper(&x);
-    Tensor::MappingGuard y_mapper(&y);
-    auto a = x.data<EXP_TYPE>();
-    auto b = y.data<RES_TYPE>();
-    for (int i = 0; i < x.size(); ++i) {
-      ExpectEqual(a[i], b[i]);
-    }
-  }
-
-  static void Near(const Tensor &x,
-                   const Tensor &y,
-                   const double rel_err,
-                   const double abs_err) {
-    ASSERT_EQ(x.dtype(), DataTypeToEnum<EXP_TYPE>::v());
-    ASSERT_EQ(y.dtype(), DataTypeToEnum<RES_TYPE>::v());
-    AssertSameDims(x, y);
-    Tensor::MappingGuard x_mapper(&x);
-    Tensor::MappingGuard y_mapper(&y);
-    auto a = x.data<EXP_TYPE>();
-    auto b = y.data<RES_TYPE>();
-    if (x.dim_size() == 4) {
-      for (int n = 0; n < x.dim(0); ++n) {
-        for (int h = 0; h < x.dim(1); ++h) {
-          for (int w = 0; w < x.dim(2); ++w) {
-            for (int c = 0; c < x.dim(3); ++c) {
-              const double error = abs_err + rel_err * std::abs(*a);
-              EXPECT_NEAR(*a, *b, error) << "with index = [" << n << ", " << h
-                                         << ", " << w << ", " << c << "]";
-              a++;
-              b++;
-            }
-          }
-        }
-      }
-    } else {
-      for (int i = 0; i < x.size(); ++i) {
-        const double error = abs_err + rel_err * std::abs(a[i]);
-        EXPECT_NEAR(a[i], b[i], error) << "a = " << a << " b = " << b
-                                       << " index = " << i;
-      }
-    }
-  }
-};
-
-template <typename EXP_TYPE, typename RES_TYPE>
-struct Expector<EXP_TYPE, RES_TYPE, false> {
-  static void Equal(const EXP_TYPE &a, const RES_TYPE &b) { ExpectEqual(a, b); }
-
-  static void Equal(const Tensor &x, const Tensor &y) {
-    ASSERT_EQ(x.dtype(), DataTypeToEnum<EXP_TYPE>::v());
-    ASSERT_EQ(y.dtype(), DataTypeToEnum<RES_TYPE>::v());
-    AssertSameDims(x, y);
-    Tensor::MappingGuard x_mapper(&x);
-    Tensor::MappingGuard y_mapper(&y);
-    auto a = x.data<EXP_TYPE>();
-    auto b = y.data<RES_TYPE>();
-    for (int i = 0; i < x.size(); ++i) {
-      ExpectEqual(a[i], b[i]);
-    }
-  }
-
-  static void Near(const Tensor &x,
-                   const Tensor &y,
-                   const double rel_err,
-                   const double abs_err) {
-    MACE_UNUSED(rel_err);
-    MACE_UNUSED(abs_err);
-    Equal(x, y);
-  }
-};
-
-template <typename EXP_TYPE, typename RES_TYPE>
-void ExpectTensorNear(const Tensor &x,
-                      const Tensor &y,
-                      const double rel_err = 1e-5,
-                      const double abs_err = 1e-8) {
-  Expector<EXP_TYPE, RES_TYPE>::Near(x, y, rel_err, abs_err);
-}
-
-template <typename T>
-void ExpectTensorNear(const Tensor &x,
-                      const Tensor &y,
-                      const double rel_err = 1e-5,
-                      const double abs_err = 1e-8) {
-  Expector<T, T>::Near(x, y, rel_err, abs_err);
-}
-
-template <typename T>
-void ExpectTensorSimilar(const Tensor &x,
-                         const Tensor &y,
-                         const double abs_err = 1e-5) {
-  AssertSameDims(x, y);
-  Tensor::MappingGuard x_mapper(&x);
-  Tensor::MappingGuard y_mapper(&y);
-  auto x_data = x.data<T>();
-  auto y_data = y.data<T>();
-  double dot_product = 0.0, x_norm = 0.0, y_norm = 0.0;
-  for (index_t i = 0; i < x.size(); i++) {
-    dot_product += x_data[i] * y_data[i];
-    x_norm += x_data[i] * x_data[i];
-    y_norm += y_data[i] * y_data[i];
-  }
-  double similarity = dot_product / (sqrt(x_norm) * sqrt(y_norm));
-  EXPECT_NEAR(1.0, similarity, abs_err);
-}
-
 }  // namespace test
 }  // namespace ops
 }  // namespace mace
diff --git a/mace/ops/pooling.cc b/mace/ops/pooling.cc
index 8691978f6d57573b92a082c34150c064aab9799c..8fd87cdfa38771a56636fd7bd54894ea1cbe042e 100644
--- a/mace/ops/pooling.cc
+++ b/mace/ops/pooling.cc
@@ -27,7 +27,7 @@
 #include "mace/core/operator.h"
 #include "mace/core/tensor.h"
 #include "mace/ops/conv_pool_2d_base.h"
-#include "mace/ops/conv_pool_2d_util.h"
+#include "mace/ops/common/conv_pool_2d_util.h"
 #ifdef MACE_ENABLE_OPENCL
 #include "mace/ops/opencl/image/pooling.h"
 #include "mace/ops/opencl/buffer/pooling.h"
diff --git a/mace/ops/pooling_benchmark.cc b/mace/ops/pooling_benchmark.cc
index 7189f48d1b9a6b4256f99c4bec93bc157300eff4..a8b6458c8df4a25cb37cf339248a2e9b9a4ad28a 100644
--- a/mace/ops/pooling_benchmark.cc
+++ b/mace/ops/pooling_benchmark.cc
@@ -13,7 +13,7 @@
 // limitations under the License.
 
 #include "mace/core/testing/test_benchmark.h"
-#include "mace/ops/conv_pool_2d_util.h"
+#include "mace/ops/common/conv_pool_2d_util.h"
 #include "mace/ops/pooling.h"
 #include "mace/ops/ops_test_util.h"
 
diff --git a/mace/ops/pooling_test.cc b/mace/ops/pooling_test.cc
index bf6317bd657cef9ff43540a5c115224135cf5874..f9a83a23e41ef290fde9d8005bcf8419a2b217ea 100644
--- a/mace/ops/pooling_test.cc
+++ b/mace/ops/pooling_test.cc
@@ -15,7 +15,7 @@
 #include <vector>
 
 #include "mace/ops/pooling.h"
-#include "mace/ops/conv_pool_2d_util.h"
+#include "mace/ops/common/conv_pool_2d_util.h"
 #include "mace/ops/ops_test_util.h"
 
 namespace mace {
diff --git a/mace/ops/ref/gemv.cc b/mace/ops/ref/gemv.cc
new file mode 100644
index 0000000000000000000000000000000000000000..555c99e27f655ff5f1d77ee5a76bf39c28a52f58
--- /dev/null
+++ b/mace/ops/ref/gemv.cc
@@ -0,0 +1,161 @@
+// Copyright 2019 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include "mace/ops/ref/gemv.h"
+
+#if defined(MACE_ENABLE_QUANTIZE)
+#include "mace/utils/quantize.h"
+#endif  // MACE_ENABLE_QUANTIZE
+
+namespace mace {
+namespace ops {
+namespace ref {
+
+MaceStatus Gemv<float>::Compute(const OpContext *context,
+                                const Tensor *lhs,
+                                const Tensor *rhs,
+                                const Tensor *bias,
+                                const index_t batch,
+                                const index_t lhs_height,
+                                const index_t lhs_width,
+                                const bool lhs_batched,
+                                Tensor *output) {
+  MACE_UNUSED(context);
+
+  Tensor::MappingGuard lhs_guard(lhs);
+  Tensor::MappingGuard rhs_guard(rhs);
+  Tensor::MappingGuard bias_guard(bias);
+  Tensor::MappingGuard output_guard(output);
+  const float *lhs_data = lhs->data<float>();
+  const float *rhs_data = rhs->data<float>();
+  const float *bias_data = nullptr;
+  if (bias) {
+    bias_data = bias->data<float>();
+  }
+
+  float *output_data = output->mutable_data<float>();
+
+  for (index_t b = 0; b < batch; ++b) {
+    for (index_t h = 0; h < lhs_height; ++h) {
+      float sum = bias ? bias_data[h] : 0;
+      for (index_t w = 0; w < lhs_width; ++w) {
+        sum += lhs_data[
+            static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+                + h * lhs_width + w]
+            * rhs_data[b * lhs_width + w];
+      }  // w
+
+      output_data[b * lhs_height + h] = sum;
+    }  // h
+  }  // b
+
+  return MaceStatus::MACE_SUCCESS;
+}
+
+#if defined(MACE_ENABLE_QUANTIZE)
+MaceStatus Gemv<uint8_t>::Compute(const OpContext *context,
+                                  const Tensor *lhs,
+                                  const Tensor *rhs,
+                                  const Tensor *bias,
+                                  const index_t batch,
+                                  const index_t lhs_height,
+                                  const index_t lhs_width,
+                                  const bool lhs_batched,
+                                  Tensor *output) {
+  MACE_UNUSED(context);
+
+  Tensor::MappingGuard lhs_guard(lhs);
+  Tensor::MappingGuard rhs_guard(rhs);
+  Tensor::MappingGuard bias_guard(bias);
+  Tensor::MappingGuard output_guard(output);
+  const uint8_t *lhs_data = lhs->data<uint8_t>();
+  const uint8_t *rhs_data = rhs->data<uint8_t>();
+  const int32_t *bias_data = nullptr;
+  if (bias) {
+    bias_data = bias->data<int32_t>();
+  }
+
+  uint8_t *output_data = output->mutable_data<uint8_t>();
+
+  MACE_CHECK(output->scale() > 0, "output scale must not be zero");
+  const float
+      output_multiplier_float = lhs->scale() * rhs->scale() / output->scale();
+  int32_t lhs_zero = lhs->zero_point();
+  int32_t rhs_zero = rhs->zero_point();
+
+  for (index_t b = 0; b < batch; ++b) {
+    for (index_t h = 0; h < lhs_height; ++h) {
+      int32_t sum = bias ? bias_data[h] : 0;
+      for (index_t w = 0; w < lhs_width; ++w) {
+        sum += (lhs_data[
+            static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+                + h * lhs_width + w] - lhs_zero)
+            * (rhs_data[b * lhs_width + w] - rhs_zero);
+      }  // w
+
+      output_data[b * lhs_height + h] =
+          Saturate<uint8_t>(std::roundf(sum * output_multiplier_float));
+    }  // h
+  }  // b
+  return MaceStatus::MACE_SUCCESS;
+}
+
+MaceStatus Gemv<int32_t>::Compute(const OpContext *context,
+                                  const Tensor *lhs,
+                                  const Tensor *rhs,
+                                  const Tensor *bias,
+                                  const index_t batch,
+                                  const index_t lhs_height,
+                                  const index_t lhs_width,
+                                  const bool lhs_batched,
+                                  Tensor *output) {
+  MACE_UNUSED(context);
+
+  Tensor::MappingGuard lhs_guard(lhs);
+  Tensor::MappingGuard rhs_guard(rhs);
+  Tensor::MappingGuard bias_guard(bias);
+  Tensor::MappingGuard output_guard(output);
+  const uint8_t *lhs_data = lhs->data<uint8_t>();
+  const uint8_t *rhs_data = rhs->data<uint8_t>();
+  const int32_t *bias_data = nullptr;
+  if (bias) {
+    bias_data = bias->data<int32_t>();
+  }
+
+  int32_t *output_data = output->mutable_data<int32_t>();
+
+  int32_t lhs_zero = lhs->zero_point();
+  int32_t rhs_zero = rhs->zero_point();
+
+  for (index_t b = 0; b < batch; ++b) {
+    for (index_t h = 0; h < lhs_height; ++h) {
+      int32_t sum = bias ? bias_data[h] : 0;
+      for (index_t w = 0; w < lhs_width; ++w) {
+        sum += (lhs_data[
+            static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+                + h * lhs_width + w] - lhs_zero)
+            * (rhs_data[b * lhs_width + w] - rhs_zero);
+      }  // w
+
+      output_data[b * lhs_height + h] = sum;
+    }  // h
+  }  // b
+  return MaceStatus::MACE_SUCCESS;
+}
+#endif  // MACE_ENABLE_QUANTIZE
+
+}  // namespace ref
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/ref/gemv.h b/mace/ops/ref/gemv.h
new file mode 100644
index 0000000000000000000000000000000000000000..46892839b395bdce9d7a9fb5312567b41fd1866f
--- /dev/null
+++ b/mace/ops/ref/gemv.h
@@ -0,0 +1,106 @@
+// Copyright 2019 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#ifndef MACE_OPS_REF_GEMV_H_
+#define MACE_OPS_REF_GEMV_H_
+
+#include "mace/public/mace.h"
+#include "mace/core/tensor.h"
+#include "mace/core/op_context.h"
+
+namespace mace {
+namespace ops {
+namespace ref {
+
+template<typename T>
+class Gemv {
+ public:
+  Gemv() {}
+  ~Gemv() {}
+  // Always row-major after transpose
+  MaceStatus Compute(
+      const OpContext *context,
+      const Tensor *lhs,
+      const Tensor *rhs,
+      const Tensor *bias,
+      const index_t batch,
+      const index_t lhs_height,
+      const index_t lhs_width,
+      const bool lhs_batched,
+      Tensor *output);
+};
+
+template<>
+class Gemv<float> {
+ public:
+  Gemv() {}
+  ~Gemv() {}
+  // Always row-major after transpose
+  MaceStatus Compute(
+      const OpContext *context,
+      const Tensor *lhs,
+      const Tensor *rhs,
+      const Tensor *bias,
+      const index_t batch,
+      const index_t lhs_height,
+      const index_t lhs_width,
+      const bool lhs_batched,
+      Tensor *output);
+};
+
+#if defined(MACE_ENABLE_QUANTIZE)
+template<>
+class Gemv<uint8_t> {
+ public:
+  Gemv() {}
+  ~Gemv() {}
+  // Always row-major after transpose
+  MaceStatus Compute(
+      const OpContext *context,
+      const Tensor *lhs,
+      const Tensor *rhs,
+      const Tensor *bias,
+      const index_t batch,
+      const index_t lhs_height,
+      const index_t lhs_width,
+      const bool lhs_batched,
+      Tensor *output);
+};
+
+template<>
+class Gemv<int32_t> {
+ public:
+  Gemv() {}
+  ~Gemv() {}
+  // Always row-major after transpose
+  MaceStatus Compute(
+      const OpContext *context,
+      const Tensor *lhs,
+      const Tensor *rhs,
+      const Tensor *bias,
+      const index_t batch,
+      const index_t lhs_height,
+      const index_t lhs_width,
+      const bool lhs_batched,
+      Tensor *output);
+};
+#endif  // MACE_ENABLE_QUANTIZE
+
+}  // namespace ref
+}  // namespace ops
+}  // namespace mace
+
+#endif  // MACE_OPS_REF_GEMV_H_
+
diff --git a/mace/ops/testing/test_utils.h b/mace/ops/testing/test_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..852d1b3a2edfc6880f0b4b0ae5fc1f27dcabf97f
--- /dev/null
+++ b/mace/ops/testing/test_utils.h
@@ -0,0 +1,306 @@
+// Copyright 2019 The MACE Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#ifndef MACE_OPS_TESTING_TEST_UTILS_H_
+#define MACE_OPS_TESTING_TEST_UTILS_H_
+
+#include <gtest/gtest.h>
+#include <algorithm>
+#include <functional>
+#include <limits>
+#include <numeric>
+#include <random>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "mace/core/tensor.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+template<typename T>
+void GenerateRandomRealTypeData(const std::vector<index_t> &shape,
+                                T *res,
+                                bool positive = true) {
+  MACE_CHECK_NOTNULL(res);
+
+  std::random_device rd;
+  std::mt19937 gen(rd());
+  std::normal_distribution<float> nd(0, 1);
+
+  index_t size = std::accumulate(shape.begin(), shape.end(), 1,
+                                 std::multiplies<index_t>());
+
+  if (DataTypeToEnum<T>::value == DT_HALF) {
+    std::generate(res, res + size, [&gen, &nd, positive] {
+      return half_float::half_cast<half>(positive ? std::abs(nd(gen))
+                                                  : nd(gen));
+    });
+  } else {
+    std::generate(res, res + size, [&gen, &nd, positive] {
+      return positive ? std::abs(nd(gen)) : nd(gen);
+    });
+  }
+}
+
+template<typename T>
+void GenerateRandomRealTypeData(const std::vector<index_t> &shape,
+                                std::vector<T> *res,
+                                bool positive = true) {
+  MACE_CHECK_NOTNULL(res);
+
+  std::random_device rd;
+  std::mt19937 gen(rd());
+  std::normal_distribution<float> nd(0, 1);
+
+  index_t size = std::accumulate(shape.begin(), shape.end(), 1,
+                                 std::multiplies<index_t>());
+  res->resize(size);
+
+  if (DataTypeToEnum<T>::value == DT_HALF) {
+    std::generate(res->begin(), res->end(), [&gen, &nd, positive] {
+      return half_float::half_cast<half>(positive ? std::abs(nd(gen))
+                                                  : nd(gen));
+    });
+  } else {
+    std::generate(res->begin(), res->end(), [&gen, &nd, positive] {
+      return positive ? std::abs(nd(gen)) : nd(gen);
+    });
+  }
+}
+
+template<typename T>
+void GenerateRandomIntTypeData(const std::vector<index_t> &shape,
+                               T *res,
+                               const T a = 0,
+                               const T b = std::numeric_limits<T>::max()) {
+  MACE_CHECK_NOTNULL(res);
+
+  std::random_device rd;
+  std::mt19937 gen(rd());
+  std::uniform_int_distribution<> nd(a, b);
+
+  index_t size = std::accumulate(shape.begin(), shape.end(), 1,
+                                 std::multiplies<index_t>());
+
+  std::generate(res, res + size, [&gen, &nd] { return nd(gen); });
+}
+
+template<typename T>
+void GenerateRandomIntTypeData(const std::vector<index_t> &shape,
+                               std::vector<T> *res,
+                               const T a = 0,
+                               const T b = std::numeric_limits<T>::max()) {
+  MACE_CHECK_NOTNULL(res);
+
+  std::random_device rd;
+  std::mt19937 gen(rd());
+  std::uniform_int_distribution<> nd(a, b);
+
+  index_t size = std::accumulate(shape.begin(), shape.end(), 1,
+                                 std::multiplies<index_t>());
+  res->resize(size);
+
+  std::generate(res->begin(), res->end(), [&gen, &nd] { return nd(gen); });
+}
+
+template<typename T>
+std::vector<T> VectorStaticCast(const std::vector<float> &&src) {
+  std::vector<T> dest;
+  dest.reserve(src.size());
+  for (float f : src) {
+    dest.push_back(static_cast<T>(f));
+  }
+  return std::move(dest);
+}
+
+inline bool IsSameSize(const Tensor &x, const Tensor &y) {
+  if (x.dim_size() != y.dim_size()) return false;
+  for (int d = 0; d < x.dim_size(); ++d) {
+    if (x.dim(d) != y.dim(d)) return false;
+  }
+  return true;
+}
+
+inline std::string ShapeToString(const Tensor &x) {
+  std::stringstream stream;
+  for (int i = 0; i < x.dim_size(); i++) {
+    if (i > 0) stream << ",";
+    int64_t dim = x.dim(i);
+    if (dim < 0) {
+      stream << "?";
+    } else {
+      stream << dim;
+    }
+  }
+  stream << "]";
+  return std::string(stream.str());
+}
+
+template<typename T>
+struct is_floating_point_type {
+  static const bool value = std::is_same<T, float>::value ||
+                            std::is_same<T, double>::value ||
+                            std::is_same<T, half>::value;
+};
+
+template<typename T>
+inline void ExpectEqual(const T &a, const T &b) {
+  EXPECT_EQ(a, b);
+}
+
+template<>
+inline void ExpectEqual(const float &a, const float &b) {
+  EXPECT_FLOAT_EQ(a, b);
+}
+
+template<>
+inline void ExpectEqual(const double &a, const double &b) {
+  EXPECT_DOUBLE_EQ(a, b);
+}
+
+inline void AssertSameDims(const Tensor &x, const Tensor &y) {
+  ASSERT_TRUE(IsSameSize(x, y)) << "x.shape [" << ShapeToString(x) << "] vs "
+                                << "y.shape [ " << ShapeToString(y) << "]";
+}
+
+template<typename EXP_TYPE,
+         typename RES_TYPE,
+         bool is_fp = is_floating_point_type<EXP_TYPE>::value>
+struct Expector;
+
+// Partial specialization for float and double.
+template<typename EXP_TYPE, typename RES_TYPE>
+struct Expector<EXP_TYPE, RES_TYPE, true> {
+  static void Equal(const EXP_TYPE &a, const RES_TYPE &b) { ExpectEqual(a, b); }
+
+  static void Equal(const Tensor &x, const Tensor &y) {
+    ASSERT_EQ(x.dtype(), DataTypeToEnum<EXP_TYPE>::v());
+    ASSERT_EQ(y.dtype(), DataTypeToEnum<RES_TYPE>::v());
+    AssertSameDims(x, y);
+    Tensor::MappingGuard x_mapper(&x);
+    Tensor::MappingGuard y_mapper(&y);
+    auto a = x.data<EXP_TYPE>();
+    auto b = y.data<RES_TYPE>();
+    for (int i = 0; i < x.size(); ++i) {
+      ExpectEqual(a[i], b[i]);
+    }
+  }
+
+  static void Near(const Tensor &x,
+                   const Tensor &y,
+                   const double rel_err,
+                   const double abs_err) {
+    ASSERT_EQ(x.dtype(), DataTypeToEnum<EXP_TYPE>::v());
+    ASSERT_EQ(y.dtype(), DataTypeToEnum<RES_TYPE>::v());
+    AssertSameDims(x, y);
+    Tensor::MappingGuard x_mapper(&x);
+    Tensor::MappingGuard y_mapper(&y);
+    auto a = x.data<EXP_TYPE>();
+    auto b = y.data<RES_TYPE>();
+    if (x.dim_size() == 4) {
+      for (int n = 0; n < x.dim(0); ++n) {
+        for (int h = 0; h < x.dim(1); ++h) {
+          for (int w = 0; w < x.dim(2); ++w) {
+            for (int c = 0; c < x.dim(3); ++c) {
+              const double error = abs_err + rel_err * std::abs(*a);
+              EXPECT_NEAR(*a, *b, error) << "with index = [" << n << ", " << h
+                                         << ", " << w << ", " << c << "]";
+              a++;
+              b++;
+            }
+          }
+        }
+      }
+    } else {
+      for (int i = 0; i < x.size(); ++i) {
+        const double error = abs_err + rel_err * std::abs(a[i]);
+        EXPECT_NEAR(a[i], b[i], error) << "a = " << a << " b = " << b
+                                       << " index = " << i;
+      }
+    }
+  }
+};
+
+template<typename EXP_TYPE, typename RES_TYPE>
+struct Expector<EXP_TYPE, RES_TYPE, false> {
+  static void Equal(const EXP_TYPE &a, const RES_TYPE &b) { ExpectEqual(a, b); }
+
+  static void Equal(const Tensor &x, const Tensor &y) {
+    ASSERT_EQ(x.dtype(), DataTypeToEnum<EXP_TYPE>::v());
+    ASSERT_EQ(y.dtype(), DataTypeToEnum<RES_TYPE>::v());
+    AssertSameDims(x, y);
+    Tensor::MappingGuard x_mapper(&x);
+    Tensor::MappingGuard y_mapper(&y);
+    auto a = x.data<EXP_TYPE>();
+    auto b = y.data<RES_TYPE>();
+    for (int i = 0; i < x.size(); ++i) {
+      ExpectEqual(a[i], b[i]);
+    }
+  }
+
+  static void Near(const Tensor &x,
+                   const Tensor &y,
+                   const double rel_err,
+                   const double abs_err) {
+    MACE_UNUSED(rel_err);
+    MACE_UNUSED(abs_err);
+    Equal(x, y);
+  }
+};
+
+template<typename EXP_TYPE, typename RES_TYPE>
+void ExpectTensorNear(const Tensor &x,
+                      const Tensor &y,
+                      const double rel_err = 1e-5,
+                      const double abs_err = 1e-8) {
+  Expector<EXP_TYPE, RES_TYPE>::Near(x, y, rel_err, abs_err);
+}
+
+template<typename T>
+void ExpectTensorNear(const Tensor &x,
+                      const Tensor &y,
+                      const double rel_err = 1e-5,
+                      const double abs_err = 1e-8) {
+  Expector<T, T>::Near(x, y, rel_err, abs_err);
+}
+
+template<typename T>
+void ExpectTensorSimilar(const Tensor &x,
+                         const Tensor &y,
+                         const double abs_err = 1e-5) {
+  AssertSameDims(x, y);
+  Tensor::MappingGuard x_mapper(&x);
+  Tensor::MappingGuard y_mapper(&y);
+  auto x_data = x.data<T>();
+  auto y_data = y.data<T>();
+  double dot_product = 0.0, x_norm = 0.0, y_norm = 0.0;
+  for (index_t i = 0; i < x.size(); i++) {
+    dot_product += x_data[i] * y_data[i];
+    x_norm += x_data[i] * x_data[i];
+    y_norm += y_data[i] * y_data[i];
+  }
+  double similarity = dot_product / (sqrt(x_norm) * sqrt(y_norm));
+  EXPECT_NEAR(1.0, similarity, abs_err);
+}
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
+
+#endif  // MACE_OPS_TESTING_TEST_UTILS_H_
+
diff --git a/mace/ops/transpose.cc b/mace/ops/transpose.cc
index a85351d0a4aa468710bf6c0a1e7e9bd558308495..678f3ee642f210083904c189fc4752dcd8c5bd4e 100644
--- a/mace/ops/transpose.cc
+++ b/mace/ops/transpose.cc
@@ -21,199 +21,11 @@
 #include <vector>
 
 #include "mace/core/operator.h"
-#include "mace/ops/transpose.h"
+#include "mace/ops/common/transpose.h"
 
 namespace mace {
 namespace ops {
 
-namespace {
-void TransposeNHWCToNCHWC3(const float *input,
-                           float *output,
-                           const index_t height,
-                           const index_t width) {
-  index_t image_size = height * width;
-
-#pragma omp parallel for
-  for (index_t h = 0; h < height; ++h) {
-    index_t in_offset = h * width * 3;
-    index_t out_offset = h * width;
-
-#if defined(MACE_ENABLE_NEON)
-    index_t w;
-    for (w = 0; w + 3 < width; w += 4) {
-      float32x4x3_t vi = vld3q_f32(input + in_offset);
-      vst1q_f32(output + out_offset, vi.val[0]);
-      vst1q_f32(output + out_offset + image_size, vi.val[1]);
-      vst1q_f32(output + out_offset + image_size * 2, vi.val[2]);
-
-      in_offset += 12;
-      out_offset += 4;
-    }
-    for (; w < width; ++w) {
-      for (index_t c = 0; c < 3; ++c) {
-        output[h * width + image_size * c + w] =
-            input[h * width * 3 + w * 3 + c];
-      }
-    }
-#else
-    for (index_t w = 0; w < width; ++w) {
-      for (index_t c = 0; c < 3; ++c) {
-        output[out_offset + c * image_size + w] = input[in_offset + w * 3 + c];
-      }
-    }
-#endif
-  }
-}
-
-void TransposeNCHWToNHWCC2(const float *input,
-                           float *output,
-                           const index_t height,
-                           const index_t width) {
-  index_t image_size = height * width;
-#pragma omp parallel for
-  for (index_t h = 0; h < height; ++h) {
-    index_t in_offset = h * width;
-    index_t out_offset = h * width * 2;
-
-#if defined(MACE_ENABLE_NEON)
-    index_t w;
-    for (w = 0; w + 3 < width; w += 4) {
-      float32x4_t vi0 = vld1q_f32(input + in_offset);
-      float32x4_t vi1 = vld1q_f32(input + in_offset + image_size);
-      float32x4x2_t vi = {vi0, vi1};
-      vst2q_f32(output + out_offset, vi);
-      in_offset += 4;
-      out_offset += 8;
-    }
-    for (; w < width; ++w) {
-      for (index_t c = 0; c < 2; ++c) {
-        output[h * width * 2 + w * 2 + c] =
-            input[h * width + image_size * c + w];
-      }
-    }
-#else
-    for (index_t w = 0; w < width; ++w) {
-      for (index_t c = 0; c < 2; ++c) {
-        output[out_offset + w * 2 + c] = input[in_offset + c * image_size + w];
-      }
-    }
-#endif
-  }
-}
-}  // namespace
-
-MaceStatus Transpose(const float *input,
-                     const std::vector<index_t> &input_shape,
-                     const std::vector<int> &dst_dims,
-                     float *output) {
-  MACE_CHECK((input_shape.size() == 2 && dst_dims.size() == 2) ||
-                 (input_shape.size() == 4 && dst_dims.size() == 4),
-             "Only support 2D or 4D transpose");
-
-  std::vector<index_t> output_shape;
-  for (size_t i = 0; i < dst_dims.size(); ++i) {
-    output_shape.push_back(input_shape[dst_dims[i]]);
-  }
-
-  if (input_shape.size() == 2) {
-    MACE_CHECK(dst_dims[0] == 1 && dst_dims[1] == 0, "no need transform");
-    index_t height = input_shape[0];
-    index_t width = input_shape[1];
-    index_t stride_i = height;
-    index_t stride_j = width;
-    index_t tile_size = height > 512 || width > 512 ? 64 : 32;
-#pragma omp parallel for collapse(2)
-    for (index_t i = 0; i < height; i += tile_size) {
-      for (index_t j = 0; j < width; j += tile_size) {
-        index_t end_i = std::min(i + tile_size, height);
-        index_t end_j = std::min(j + tile_size, width);
-        for (index_t tile_i = i; tile_i < end_i; ++tile_i) {
-          for (index_t tile_j = j; tile_j < end_j; ++tile_j) {
-            output[tile_j * stride_i + tile_i] =
-                input[tile_i * stride_j + tile_j];
-          }
-        }
-      }
-    }
-  } else if (input_shape.size() == 4) {
-    std::vector<int> transpose_order_from_NHWC_to_NCHW{0, 3, 1, 2};
-    std::vector<int> transpose_order_from_NCHW_to_NHWC{0, 2, 3, 1};
-    index_t batch_size = input_shape[1] * input_shape[2] * input_shape[3];
-
-    if (dst_dims == transpose_order_from_NHWC_to_NCHW && input_shape[3] == 3) {
-      for (index_t b = 0; b < input_shape[0]; ++b) {
-        TransposeNHWCToNCHWC3(input + b * batch_size,
-                              output + b * batch_size,
-                              input_shape[1],
-                              input_shape[2]);
-      }
-    } else if (dst_dims == transpose_order_from_NCHW_to_NHWC
-        && input_shape[1] == 2) {
-      for (index_t b = 0; b < input_shape[0]; ++b) {
-        TransposeNCHWToNHWCC2(input + b * batch_size,
-                              output + b * batch_size,
-                              input_shape[2],
-                              input_shape[3]);
-      }
-    } else if (dst_dims == std::vector<int>{0, 2, 1, 3}) {
-      index_t height = input_shape[1];
-      index_t width = input_shape[2];
-      index_t channel = input_shape[3];
-      index_t channel_raw_size = channel * sizeof(float);
-      index_t stride_i = height;
-      index_t stride_j = width;
-      index_t tile_size = std::max(static_cast<index_t>(1),
-                                   static_cast<index_t>(std::sqrt(
-                                       8 * 1024 / channel)));
-#pragma omp parallel for collapse(2)
-      for (index_t i = 0; i < height; i += tile_size) {
-        for (index_t j = 0; j < width; j += tile_size) {
-          index_t end_i = std::min(i + tile_size, height);
-          index_t end_j = std::min(j + tile_size, width);
-          for (index_t tile_i = i; tile_i < end_i; ++tile_i) {
-            for (index_t tile_j = j; tile_j < end_j; ++tile_j) {
-              memcpy(output + (tile_j * stride_i + tile_i) * channel,
-                     input + (tile_i * stride_j + tile_j) * channel,
-                     channel_raw_size);
-            }
-          }
-        }
-      }
-    } else {
-      std::vector<index_t>
-          in_stride{input_shape[1] * input_shape[2] * input_shape[3],
-                    input_shape[2] * input_shape[3], input_shape[3], 1};
-      std::vector<index_t>
-          out_stride{output_shape[1] * output_shape[2] * output_shape[3],
-                     output_shape[2] * output_shape[3], output_shape[3], 1};
-
-      std::vector<index_t> idim(4, 0);
-      std::vector<index_t> odim(4, 0);
-      for (odim[0] = 0; odim[0] < output_shape[0]; ++odim[0]) {
-        for (odim[1] = 0; odim[1] < output_shape[1]; ++odim[1]) {
-          for (odim[2] = 0; odim[2] < output_shape[2]; ++odim[2]) {
-            for (odim[3] = 0; odim[3] < output_shape[3]; ++odim[3]) {
-              idim[dst_dims[0]] = odim[0];
-              idim[dst_dims[1]] = odim[1];
-              idim[dst_dims[2]] = odim[2];
-              idim[dst_dims[3]] = odim[3];
-
-              output[odim[0] * out_stride[0] + odim[1] * out_stride[1]
-                  + odim[2] * out_stride[2] + odim[3]] =
-                  input[idim[0] * in_stride[0] + idim[1] * in_stride[1]
-                      + idim[2] * in_stride[2] + idim[3]];
-            }
-          }
-        }
-      }
-    }
-  } else {
-    MACE_NOT_IMPLEMENTED;
-  }
-
-  return MaceStatus::MACE_SUCCESS;
-}
-
 template <DeviceType D, class T>
 class TransposeOp;
diff --git a/mace/python/tools/quantization/quantize_stat.py b/mace/python/tools/quantization/quantize_stat.py
index f9a8656d8790ce77befc25c98103f4d4a2196d46..31fd110f4a197ed47a4a1d1d57f833290b181f6b 100644
--- a/mace/python/tools/quantization/quantize_stat.py
+++ b/mace/python/tools/quantization/quantize_stat.py
@@ -31,6 +31,13 @@ class QuantizeStat(object):
             if not enhance or samples <= 1:
                 res[tensor_name] = (tensor_min, tensor_max)
             else:
+                """
+                Enhancement mode:
+                This policy eliminates outliers that cause a long-tailed
+                statistical range. We try to shrink the range as much as
+                possible while still retaining most of the samples;
+                d(range)/d(sample_quantile) is used to measure this
+                trade-off quantitatively.
+                """
                 tensor_mins = np.sort(tensor_ranges[tensor_name][0])
                 tensor_maxs = np.sort(tensor_ranges[tensor_name][1])[::-1]
                 cur_min_idx = 0
diff --git a/mace/test/mace_api_test.h b/mace/test/mace_api_test.h
index 5d954b1755016e6bb62a1b6bc0a0d8be678d051a..2c2ed7d177fb2b1d834f427a5ecfaa956fe7e648 100644
--- a/mace/test/mace_api_test.h
+++ b/mace/test/mace_api_test.h
@@ -21,7 +21,7 @@
 #include <string>
 #include <vector>
 
-#include "mace/ops/conv_pool_2d_util.h"
+#include "mace/ops/common/conv_pool_2d_util.h"
 #include "mace/ops/ops_test_util.h"
 #include "mace/public/mace.h"
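
Editor's note: the reference gemv added in mace/ops/ref/gemv.cc above is just a batched triple loop. For readers who want to check the indexing without pulling in mace::Tensor and OpContext, here is a minimal standalone sketch under those assumptions; plain arrays stand in for tensors and "ref_gemv" is an illustrative name, not a MACE symbol.

#include <cstdio>
#include <vector>

// lhs is [height, width] row-major (optionally batched), rhs is
// [batch, width], output is [batch, height]; same layout contract as the
// Gemv<float>::Compute added in the patch.
void ref_gemv(const float *lhs, const float *rhs, const float *bias,
              int batch, int height, int width, bool lhs_batched,
              float *output) {
  for (int b = 0; b < batch; ++b) {
    for (int h = 0; h < height; ++h) {
      float sum = bias ? bias[h] : 0.0f;
      for (int w = 0; w < width; ++w) {
        sum += lhs[(lhs_batched ? b : 0) * height * width + h * width + w] *
               rhs[b * width + w];
      }
      output[b * height + h] = sum;  // one dot product per output element
    }
  }
}

int main() {
  // 2x3 matrix times a 3-vector: expect {14, 32}.
  std::vector<float> lhs = {1, 2, 3, 4, 5, 6};
  std::vector<float> rhs = {1, 2, 3};
  std::vector<float> out(2);
  ref_gemv(lhs.data(), rhs.data(), nullptr, 1, 2, 3, false, out.data());
  std::printf("%g %g\n", out[0], out[1]);  // prints: 14 32
  return 0;
}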
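The quantized variant in the same file differs only in the arithmetic: it accumulates around the zero points in int32, rescales by lhs_scale * rhs_scale / output_scale, rounds, and saturates to uint8 (note the patch's reference kernel does not add an output zero point). A standalone sketch of that requantization step, with made-up scales and zero points:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Mirrors the inner loop of Gemv<uint8_t>::Compute: subtract the zero points
// before multiplying, then rescale the int32 accumulator and clamp to [0, 255].
static uint8_t requantize_dot(const uint8_t *lhs_row, const uint8_t *rhs,
                              int width, int32_t lhs_zero, int32_t rhs_zero,
                              float multiplier) {
  int32_t sum = 0;
  for (int w = 0; w < width; ++w) {
    sum += (lhs_row[w] - lhs_zero) * (rhs[w] - rhs_zero);
  }
  float rescaled = std::round(static_cast<float>(sum) * multiplier);
  return static_cast<uint8_t>(std::min(255.0f, std::max(0.0f, rescaled)));
}

int main() {
  const uint8_t lhs_row[4] = {130, 128, 140, 120};
  const uint8_t rhs[4] = {132, 128, 124, 128};
  // multiplier = lhs_scale * rhs_scale / output_scale, e.g. 0.02 * 0.05 / 0.1.
  const float multiplier = 0.02f * 0.05f / 0.1f;
  // sum = 2*4 + 0 + 12*(-4) + 0 = -40; round(-40 * 0.01) = 0 after clamping.
  std::printf("%u\n", requantize_dot(lhs_row, rhs, 4, 128, 128, multiplier));
  return 0;
}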
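Finally, MatMulFixpointImpl in matmul.cc maps the height == 1 case onto the same gemv kernel by swapping its arguments. That works because transpose(a * B) = transpose(B) * transpose(a), and a column-major K x width matrix has exactly the memory layout of a row-major width x K matrix. A quick standalone check of that identity (plain arrays, not MACE code):

#include <cstdio>

int main() {
  const int K = 3, width = 2;
  const float a[K] = {1, 2, 3};          // 1 x K row vector
  const float b_colmajor[K * width] = {  // B is K x width, column-major:
      1, 3, 5,                           // column 0
      2, 4, 6};                          // column 1
  // gevm out[n] = sum_k a[k] * B[k][n], computed as a gemv over the
  // row-major reinterpretation of B (row n of B^T dotted with a).
  for (int n = 0; n < width; ++n) {
    float sum = 0.0f;
    for (int k = 0; k < K; ++k) {
      sum += b_colmajor[n * K + k] * a[k];
    }
    std::printf("out[%d] = %g\n", n, sum);  // prints 22, then 28
  }
  return 0;
}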