From 9fe6761aaf4a2c2ace95885310d7e473fe708cb4 Mon Sep 17 00:00:00 2001 From: Bin Li Date: Fri, 25 Sep 2020 15:51:20 +0800 Subject: [PATCH] feat: Add bf16 kernels for MobileNet --- docs/user_guide/advanced_usage.rst | 3 +- mace/core/bfloat16.h | 57 +- mace/ops/BUILD.bazel | 12 +- mace/ops/CMakeLists.txt | 6 + mace/ops/arm/base/activation.cc | 156 ++++ mace/ops/arm/base/bias_add.cc | 100 +++ mace/ops/arm/base/common_neon.h | 348 +++++++++ mace/ops/arm/base/conv_2d_1x1.cc | 5 + mace/ops/arm/base/conv_2d_3x3.cc | 9 + .../{fp32 => base}/conv_2d_3x3_winograd.cc | 174 ++--- .../arm/{fp32 => base}/conv_2d_3x3_winograd.h | 37 +- mace/ops/arm/base/conv_2d_general.cc | 147 ++++ mace/ops/arm/base/depthwise_conv_2d_3x3.cc | 437 +++++++++++ mace/ops/arm/base/gemm.cc | 679 +++++++++++++++++ mace/ops/arm/base/gemm.h | 48 +- mace/ops/arm/bf16/conv_2d_3x3.cc | 459 ++++++++++++ mace/ops/arm/bf16/gemm.cc | 535 ++++++++++++++ mace/ops/arm/fp32/activation.cc | 162 ----- mace/ops/arm/fp32/bias_add.cc | 123 ---- mace/ops/arm/fp32/common_neon.h | 70 -- mace/ops/arm/fp32/conv_2d_general.cc | 197 ----- mace/ops/arm/fp32/deconv_2d_2x2.cc | 2 +- mace/ops/arm/fp32/deconv_2d_3x3.cc | 2 +- mace/ops/arm/fp32/deconv_2d_4x4.cc | 2 +- mace/ops/arm/fp32/depthwise_conv_2d_3x3.cc | 428 ----------- mace/ops/arm/fp32/depthwise_deconv_2d_3x3.cc | 2 +- mace/ops/arm/fp32/depthwise_deconv_2d_4x4.cc | 2 +- mace/ops/arm/fp32/gemm.cc | 688 ------------------ mace/ops/registry/op_delegators_registry.cc | 4 +- test/ccbenchmark/BUILD.bazel | 3 + .../mace/ops/activation_benchmark.cc | 101 ++- .../mace/ops/bias_add_benchmark.cc | 37 +- .../ccbenchmark/mace/ops/conv_2d_benchmark.cc | 37 +- .../mace/ops/depthwise_conv2d_benchmark.cc | 34 +- test/ccunit/BUILD.bazel | 15 +- test/ccunit/mace/ops/activation_test.cc | 55 ++ test/ccunit/mace/ops/arm/bf16/gemm_test.cc | 106 +++ test/ccunit/mace/ops/bias_add_test.cc | 44 ++ test/ccunit/mace/ops/conv_2d_test.cc | 68 ++ test/ccunit/mace/ops/depthwise_conv2d_test.cc | 65 ++ test/ccutils/mace/ops/ops_test_util.h | 16 + tools/bazel_adb_run.py | 6 + tools/bazel_build_standalone_lib.sh | 17 +- 43 files changed, 3580 insertions(+), 1918 deletions(-) create mode 100644 mace/ops/arm/base/common_neon.h rename mace/ops/arm/{fp32 => base}/conv_2d_3x3_winograd.cc (83%) rename mace/ops/arm/{fp32 => base}/conv_2d_3x3_winograd.h (78%) create mode 100644 mace/ops/arm/bf16/conv_2d_3x3.cc create mode 100644 mace/ops/arm/bf16/gemm.cc delete mode 100644 mace/ops/arm/fp32/activation.cc delete mode 100644 mace/ops/arm/fp32/bias_add.cc delete mode 100644 mace/ops/arm/fp32/common_neon.h delete mode 100644 mace/ops/arm/fp32/conv_2d_general.cc delete mode 100644 mace/ops/arm/fp32/depthwise_conv_2d_3x3.cc create mode 100644 test/ccunit/mace/ops/arm/bf16/gemm_test.cc diff --git a/docs/user_guide/advanced_usage.rst b/docs/user_guide/advanced_usage.rst index 4e34309e..8d8c50df 100644 --- a/docs/user_guide/advanced_usage.rst +++ b/docs/user_guide/advanced_usage.rst @@ -518,7 +518,8 @@ Use ``-h`` to get detailed help. Reduce Library Size ------------------- -* Build for your own usage purpose. +* Build for your own usage purpose. Some configuration variables in tools/bazel_build_standalone_lib.sh + are set to ``true`` by default, you can change them to ``false`` to reduce the library size. 
* **dynamic library** - If the models don't need to run on device ``dsp``, change the build option ``--define hexagon=true`` diff --git a/mace/core/bfloat16.h b/mace/core/bfloat16.h index 21f8ae0b..912674af 100644 --- a/mace/core/bfloat16.h +++ b/mace/core/bfloat16.h @@ -62,27 +62,27 @@ class BFloat16 { } template - BFloat16 operator+(T value) const { - return BFloat16(Sphinx( - static_cast(data_ << 16)).f + static_cast(value)); + float operator+(T value) const { + return Sphinx(static_cast(data_ << 16)).f + + static_cast(value); } template - BFloat16 operator-(T value) const { - return BFloat16(Sphinx( - static_cast(data_ << 16)).f - static_cast(value)); + float operator-(T value) const { + return Sphinx(static_cast(data_ << 16)).f + - static_cast(value); } template - BFloat16 operator*(T value) const { - return BFloat16(Sphinx( - static_cast(data_ << 16)).f * static_cast(value)); + float operator*(T value) const { + return Sphinx(static_cast(data_ << 16)).f + * static_cast(value); } template - BFloat16 operator/(T value) const { - return BFloat16(Sphinx( - static_cast(data_ << 16)).f / static_cast(value)); + float operator/(T value) const { + return Sphinx(static_cast(data_ << 16)).f + / static_cast(value); } template @@ -223,7 +223,6 @@ inline ostream &operator<<(ostream &ss, // NOLINT } // namespace std - inline float operator+(const float &a, const mace::BFloat16 &value) { return a + static_cast(value); } @@ -256,6 +255,38 @@ inline void operator/=(float &a, const mace::BFloat16 &value) { // NOLINT a /= static_cast(value); } +inline double operator+(const double &a, const mace::BFloat16 &value) { + return a + static_cast(value); +} + +inline double operator-(const double &a, const mace::BFloat16 &value) { + return a - static_cast(value); +} + +inline double operator*(const double &a, const mace::BFloat16 &value) { + return a * static_cast(value); +} + +inline double operator/(const double &a, const mace::BFloat16 &value) { + return a / static_cast(value); +} + +inline void operator+=(double &a, const mace::BFloat16 &value) { // NOLINT + a += static_cast(value); +} + +inline void operator-=(double &a, const mace::BFloat16 &value) { // NOLINT + a -= static_cast(value); +} + +inline void operator*=(double &a, const mace::BFloat16 &value) { // NOLINT + a *= static_cast(value); +} + +inline void operator/=(double &a, const mace::BFloat16 &value) { // NOLINT + a /= static_cast(value); +} + #endif // MACE_ENABLE_BFLOAT16 #endif // MACE_CORE_BFLOAT16_H_ diff --git a/mace/ops/BUILD.bazel b/mace/ops/BUILD.bazel index 39d954f8..73a3cc20 100644 --- a/mace/ops/BUILD.bazel +++ b/mace/ops/BUILD.bazel @@ -104,15 +104,13 @@ cc_library( "arm/fp32/*.cc", "arm/fp16/gemv.h", ], - exclude = [ - "arm/fp32/*_test.cc", - ], ) + if_quantize_enabled(glob( [ "arm/q8/*.cc", ], - exclude = [ - "arm/q8/*_test.cc", + )) + if_bfloat16_enabled(glob( + [ + "arm/bf16/*.cc", ], )), hdrs = glob( @@ -124,6 +122,10 @@ cc_library( [ "arm/q8/*.h", ], + )) + if_bfloat16_enabled(glob( + [ + "arm/bf16/*.h", + ], )), copts = [ "-Werror", diff --git a/mace/ops/CMakeLists.txt b/mace/ops/CMakeLists.txt index 61b3b153..c2f18a70 100644 --- a/mace/ops/CMakeLists.txt +++ b/mace/ops/CMakeLists.txt @@ -11,6 +11,9 @@ file(GLOB OPS_ARM_NEON_BASE_KERNELS_SRCS file(GLOB OPS_ARM_NEON_FP32_KERNELS_SRCS arm/fp32/*.cc ) +file(GLOB OPS_ARM_NEON_BF16_KERNELS_SRCS + arm/bf16/*.cc +) file(GLOB OPS_ARM_NEON_Q8_KERNELS_SRCS arm/q8/*.cc ) @@ -39,6 +42,9 @@ if(MACE_ENABLE_NEON) if(MACE_ENABLE_QUANTIZE) set(OPS_SRCS ${OPS_SRCS} 
${OPS_ARM_NEON_Q8_KERNELS_SRCS}) endif(MACE_ENABLE_QUANTIZE) + if(MACE_ENABLE_BFLOAT16) + set(OPS_SRCS ${OPS_SRCS} ${OPS_ARM_NEON_BF16_KERNELS_SRCS}) + endif(MACE_ENABLE_BFLOAT16) endif(MACE_ENABLE_NEON) if(MACE_ENABLE_OPENCL) diff --git a/mace/ops/arm/base/activation.cc b/mace/ops/arm/base/activation.cc index 90135b17..ab9a5336 100644 --- a/mace/ops/arm/base/activation.cc +++ b/mace/ops/arm/base/activation.cc @@ -14,10 +14,25 @@ #include "mace/ops/arm/base/activation.h" +#include + +#include "mace/ops/arm/base/common_neon.h" + namespace mace { namespace ops { namespace arm { +extern template void Activation::ActivateRelu( + utils::ThreadPool *, const Tensor *, Tensor *); +extern template void Activation::ActivateRelux( + utils::ThreadPool *, const Tensor *, Tensor *); +extern template void Activation::ActivateLeakyRelu( + utils::ThreadPool *, const Tensor *, Tensor *); +extern template void Activation::ActivateTanh( + utils::ThreadPool *, const Tensor *, Tensor *); +extern template void Activation::ActivateSigmoid( + utils::ThreadPool *, const Tensor *, Tensor *); + template MaceStatus Activation::Compute(const OpContext *context, const Tensor *input, Tensor *output) { @@ -76,15 +91,156 @@ void Activation::DoActivation(const OpContext *context, } } +template +void Activation::ActivateRelu(utils::ThreadPool *thread_pool, + const Tensor *input, + Tensor *output) { + const auto input_data = input->data(); + auto output_data = output->mutable_data(); + const index_t input_size = input->size(); + const float32x4_t vzero = vdupq_n_f32(0.f); + const index_t block_count = input_size / 4; + thread_pool->Compute1D( + [=](index_t start, index_t end, index_t step) { + const T *input_ptr = input_data + start * 4; + T *output_ptr = output_data + start * 4; + + for (index_t i = start; i < end; i += step) { + float32x4_t v = vld1q(input_ptr); + v = vmaxq_f32(v, vzero); + vst1q(output_ptr, v); + + input_ptr += 4; + output_ptr += 4; + } + }, + 0, block_count, 1); + + // remain + for (index_t i = block_count * 4; i < input_size; ++i) { + output_data[i] = std::max(0.f, input_data[i]); + } +} + +template +void Activation::ActivateRelux(utils::ThreadPool *thread_pool, + const Tensor *input, + Tensor *output) { + const auto input_data = input->data(); + auto output_data = output->mutable_data(); + const index_t input_size = input->size(); + const float32x4_t vzero = vdupq_n_f32(0.f); + const float32x4_t vlimit = vdupq_n_f32(limit_); + const index_t block_count = input_size / 4; + + thread_pool->Compute1D( + [=](index_t start, index_t end, index_t step) { + auto input_ptr = input_data + start * 4; + auto output_ptr = output_data + start * 4; + + for (index_t i = start; i < end; i += step) { + float32x4_t v = vld1q(input_ptr); + v = vmaxq_f32(v, vzero); + v = vminq_f32(v, vlimit); + vst1q(output_ptr, v); + + input_ptr += 4; + output_ptr += 4; + } + }, + 0, block_count, 1); + + // remain + for (index_t i = block_count * 4; i < input_size; ++i) { + output_data[i] = std::max(0.f, std::min(limit_, input_data[i])); + } +} + +template +void Activation::ActivateLeakyRelu(utils::ThreadPool *thread_pool, + const Tensor *input, + Tensor *output) { + const auto input_data = input->data(); + auto output_data = output->mutable_data(); + const index_t input_size = input->size(); + const float32x4_t vzero = vdupq_n_f32(0.f); + const float32x4_t valpha = vdupq_n_f32(leakyrelu_coefficient_); + const index_t block_count = input_size / 4; + + thread_pool->Compute1D( + [=](index_t start, index_t end, index_t step) { + auto input_ptr 
= input_data + start * 4; + auto output_ptr = output_data + start * 4; + + for (index_t i = start; i < end; i += step) { + float32x4_t v = vld1q(input_ptr); + float32x4_t u = vminq_f32(v, vzero); + v = vmaxq_f32(v, vzero); + v = vmlaq_f32(v, valpha, u); + vst1q(output_ptr, v); + + input_ptr += 4; + output_ptr += 4; + } + }, + 0, block_count, 1); + + // remain + for (index_t i = block_count * 4; i < input_size; ++i) { + output_data[i] = std::max(input_data[i], 0.f) + + std::min(input_data[i], 0.f) * leakyrelu_coefficient_; + } +} + +template +void Activation::ActivateTanh(utils::ThreadPool *thread_pool, + const Tensor *input, + Tensor *output) { + const auto input_data = input->data(); + auto output_data = output->mutable_data(); + const index_t input_size = input->size(); + + thread_pool->Compute1D( + [=](index_t start, index_t end, index_t step) { + for (index_t i = start; i < end; i += step) { + output_data[i] = std::tanh(input_data[i]); + } + }, + 0, input_size, 1); +} + +template +void Activation::ActivateSigmoid(utils::ThreadPool *thread_pool, + const Tensor *input, + Tensor *output) { + const auto input_data = input->data(); + auto output_data = output->mutable_data(); + const index_t input_size = input->size(); + + thread_pool->Compute1D( + [=](index_t start, index_t end, index_t step) { + for (index_t i = start; i < end; i += step) { + output_data[i] = 1 / (1 + std::exp(-(input_data[i]))); + } + }, + 0, input_size, 1); +} + void RegisterActivationDelegator(OpDelegatorRegistry *registry) { MACE_REGISTER_DELEGATOR( registry, Activation, delegator::ActivationParam, MACE_DELEGATOR_KEY(Activation, DeviceType::CPU, float, ImplType::NEON)); + #ifdef MACE_ENABLE_QUANTIZE MACE_REGISTER_DELEGATOR( registry, Activation, delegator::ActivationParam, MACE_DELEGATOR_KEY(Activation, DeviceType::CPU, uint8_t, ImplType::NEON)); #endif // MACE_ENABLE_QUANTIZE + + MACE_REGISTER_BF16_DELEGATOR( + registry, Activation, delegator::ActivationParam, + MACE_DELEGATOR_KEY(Activation, DeviceType::CPU, BFloat16, + ImplType::NEON)); } } // namespace arm diff --git a/mace/ops/arm/base/bias_add.cc b/mace/ops/arm/base/bias_add.cc index 3ae9e416..42526157 100644 --- a/mace/ops/arm/base/bias_add.cc +++ b/mace/ops/arm/base/bias_add.cc @@ -14,10 +14,24 @@ #include "mace/ops/arm/base/bias_add.h" +#include +#include + +#include "mace/ops/arm/base/common_neon.h" + namespace mace { namespace ops { namespace arm { +extern template void BiasAdd::AddBiasNCHW<1>( + utils::ThreadPool *, const Tensor *, const Tensor *, Tensor *); +extern template void BiasAdd::AddBiasNCHW<2>( + utils::ThreadPool *, const Tensor *, const Tensor *, Tensor *); +extern template void BiasAdd::AddBiasNHWC<1>( + utils::ThreadPool *, const Tensor *, const Tensor *, Tensor *); +extern template void BiasAdd::AddBiasNHWC<2>( + utils::ThreadPool *, const Tensor *, const Tensor *, Tensor *); + template MaceStatus BiasAdd::Compute(const OpContext *context, const Tensor *input, @@ -69,15 +83,101 @@ void BiasAdd::AddBias(const OpContext *context, } } +template +template +void BiasAdd::AddBiasNCHW(utils::ThreadPool *thread_pool, + const Tensor *input, + const Tensor *bias, + Tensor *output) { + const auto input_data = input->data(); + const auto bias_data = bias->data(); + auto output_data = output->mutable_data(); + + const index_t batch = input->dim(0); + const index_t channels = input->dim(1); + const index_t image_size = input->dim(2) * input->dim(3); + const index_t block_count = image_size / 4; + const index_t remain = image_size % 4; + 
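+  // The Compute2D loop below parallelizes over (batch, channel) pairs. For each
+  // channel image the bias value is broadcast once with vdupq_n_f32 and added to
+  // blocks of 4 pixels through the overloaded vld1q/vst1q from common_neon.h, so
+  // this template body serves both float and BFloat16; the `remain` tail pixels
+  // fall back to scalar adds. bias_index() appears to select a plain per-channel
+  // bias for Dim == 1 and a per-batch bias row for Dim == 2 (cf. the shape check
+  // in AddBiasNHWC below).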
thread_pool->Compute2D( + [=](index_t start0, index_t end0, index_t step0, index_t start1, + index_t end1, index_t step1) { + for (index_t b = start0; b < end0; b += step0) { + const index_t b_offset = b * channels; + for (index_t c = start1; c < end1; c += step1) { + const index_t offset = (b_offset + c) * image_size; + auto input_ptr = input_data + offset; + auto output_ptr = output_data + offset; + const float bias = bias_data[bias_index(b_offset, c)]; + float32x4_t vbias = vdupq_n_f32(bias); + + for (index_t i = 0; i < block_count; ++i) { + float32x4_t v = vld1q(input_ptr); + v = vaddq_f32(v, vbias); + vst1q(output_ptr, v); + + input_ptr += 4; + output_ptr += 4; + } + for (index_t i = 0; i < remain; ++i) { + (*output_ptr++) = (*input_ptr++) + bias; + } + } + } + }, + 0, batch, 1, 0, channels, 1); +} + + +template +template +void BiasAdd::AddBiasNHWC(utils::ThreadPool *thread_pool, + const Tensor *input, + const Tensor *bias, + Tensor *output) { + const auto input_ptr = input->data(); + const auto bias_ptr = bias->data(); + auto output_ptr = output->mutable_data(); + + const std::vector &shape = input->shape(); + const index_t channels = *shape.rbegin(); + const auto batch = shape[0]; + if (Dim == 2) { + MACE_CHECK(batch == bias->shape()[0]); + } + const index_t fused_hw = std::accumulate(shape.begin() + 1, shape.end() - 1, + 1, std::multiplies()); + thread_pool->Compute2D( + [=](index_t start0, index_t end0, index_t step0, index_t start1, + index_t end1, index_t step1) { + for (index_t i = start0; i < end0; i += step0) { + auto offset = i * fused_hw; + auto bias_offset = i * channels; + for (index_t j = start1; j < end1; j += step1) { + index_t pos = (offset + j) * channels; + for (index_t c = 0; c < channels; ++c, ++pos) { + output_ptr[pos] = + input_ptr[pos] + bias_ptr[bias_index(bias_offset, c)]; + } + } + } + }, + 0, batch, 1, 0, fused_hw, 1); +} + void RegisterBiasAddDelegator(OpDelegatorRegistry *registry) { MACE_REGISTER_DELEGATOR( registry, BiasAdd, DelegatorParam, MACE_DELEGATOR_KEY(BiasAdd, DeviceType::CPU, float, ImplType::NEON)); + #ifdef MACE_ENABLE_QUANTIZE MACE_REGISTER_DELEGATOR( registry, BiasAdd, DelegatorParam, MACE_DELEGATOR_KEY(BiasAdd, DeviceType::CPU, uint8_t, ImplType::NEON)); #endif // MACE_ENABLE_QUANTIZE + + MACE_REGISTER_BF16_DELEGATOR( + registry, BiasAdd, DelegatorParam, + MACE_DELEGATOR_KEY(BiasAdd, DeviceType::CPU, BFloat16, ImplType::NEON)); } } // namespace arm diff --git a/mace/ops/arm/base/common_neon.h b/mace/ops/arm/base/common_neon.h new file mode 100644 index 00000000..e7bd3180 --- /dev/null +++ b/mace/ops/arm/base/common_neon.h @@ -0,0 +1,348 @@ +// Copyright 2020 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
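+// Overview: this header gathers the NEON helpers shared by the templated kernels
+// under mace/ops/arm/base: lane-FMA and vector-shift helpers, plus thin overloads
+// of vld1/vld1q/vld2q/vld3q/vst1/vst1q/vst2q/vst3q (and an 8-lane float32x8_t
+// pair) for float. When MACE_ENABLE_BFLOAT16 is defined, matching overloads for
+// BFloat16/uint16_t widen to float32 lanes on load and truncate back to the top
+// 16 bits on store, so one kernel template can be instantiated for both types.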
+ +#ifndef MACE_OPS_ARM_BASE_COMMON_NEON_H_ +#define MACE_OPS_ARM_BASE_COMMON_NEON_H_ + +#include + +#include "mace/core/bfloat16.h" + +namespace mace { +namespace ops { +namespace arm { + +typedef struct float32x8_t { + float32x4_t val[2]; +} float32x8_t; + +#if !defined(__aarch64__) +inline float vaddvq_f32(float32x4_t v) { + float32x2_t _sum = vadd_f32(vget_low_f32(v), vget_high_f32(v)); + _sum = vpadd_f32(_sum, _sum); + return vget_lane_f32(_sum, 0); +} +#endif + +inline float32x4_t neon_vfma_lane_0(float32x4_t a, + float32x4_t b, + float32x4_t c) { +#ifdef __aarch64__ + return vfmaq_laneq_f32(a, b, c, 0); +#else + return vmlaq_lane_f32(a, b, vget_low_f32(c), 0); +#endif +} + +inline float32x4_t neon_vfma_lane_1(float32x4_t a, + float32x4_t b, + float32x4_t c) { +#ifdef __aarch64__ + return vfmaq_laneq_f32(a, b, c, 1); +#else + return vmlaq_lane_f32(a, b, vget_low_f32(c), 1); +#endif +} + +inline float32x4_t neon_vfma_lane_2(float32x4_t a, + float32x4_t b, + float32x4_t c) { +#ifdef __aarch64__ + return vfmaq_laneq_f32(a, b, c, 2); +#else + return vmlaq_lane_f32(a, b, vget_high_f32(c), 0); +#endif +} + +inline float32x4_t neon_vfma_lane_3(float32x4_t a, + float32x4_t b, + float32x4_t c) { +#ifdef __aarch64__ + return vfmaq_laneq_f32(a, b, c, 3); +#else + return vmlaq_lane_f32(a, b, vget_high_f32(c), 1); +#endif +} + +inline void neon_vec_left_shift_1(const float32x4_t &src, + float32x4_t *dst) { + (*dst)[0] = src[1]; + (*dst)[1] = src[2]; + (*dst)[2] = src[3]; +} + +inline void neon_vec_left_shift_2(const float32x4_t &src, + float32x4_t *dst) { + (*dst)[0] = src[2]; + (*dst)[1] = src[3]; +} + +inline void neon_vec_left_shift_3(const float32x4_t &src, + float32x4_t *dst) { + (*dst)[0] = src[3]; +} + +inline void neon_vec_right_shift_1(const float32x4_t &src, + float32x4_t *dst) { + (*dst)[1] = src[0]; + (*dst)[2] = src[1]; + (*dst)[3] = src[2]; +} + +inline void neon_vec_right_shift_2(const float32x4_t &src, + float32x4_t *dst) { + (*dst)[2] = src[0]; + (*dst)[3] = src[1]; +} + +inline void neon_vec_right_shift_3(const float32x4_t &src, + float32x4_t *dst) { + (*dst)[3] = src[0]; +} + +inline float32x2_t vld1(const float *ptr) { + return vld1_f32(ptr); +} + +inline void vst1(float *ptr, float32x2_t v) { + vst1_f32(ptr, v); +} + +inline float32x4_t vld1q(const float *ptr) { + return vld1q_f32(ptr); +} + +inline float32x4x2_t vld2q(const float *ptr) { + return vld2q_f32(ptr); +} + +inline float32x4x3_t vld3q(const float *ptr) { + return vld3q_f32(ptr); +} + +inline void vst1q(float *ptr, float32x4_t v) { + vst1q_f32(ptr, v); +} + +inline void vst2q(float *ptr, float32x4x2_t v) { + vst2q_f32(ptr, v); +} + +inline void vst3q(float *ptr, float32x4x3_t v) { + vst3q_f32(ptr, v); +} + +inline float32x8_t vld1o(float *ptr) { + return {vld1q_f32(ptr), vld1q_f32(ptr + 4)}; +} + +inline void vst1o(float *ptr, float32x8_t v) { + vst1q_f32(ptr, v.val[0]); + vst1q_f32(ptr + 4, v.val[1]); +} + +#if defined(MACE_ENABLE_BFLOAT16) + +// load of 2D vector +inline float32x2_t vld1_bf16(const BFloat16 *ptr) { + return (float32x2_t){ptr[0], ptr[1]}; // NOLINT(readability/braces) +} + +inline float32x2_t vld1_bf16(const uint16_t *ptr) { + return vld1_bf16(reinterpret_cast(ptr)); +} + +inline float32x2_t vld1(const BFloat16 *ptr) { + return vld1_bf16(ptr); +} + +inline float32x2_t vld1(const uint16_t *ptr) { + return vld1_bf16(reinterpret_cast(ptr)); +} + +// store of 2D vector +inline void vst1_bf16(BFloat16 *ptr, float32x2_t v) { + ptr[0] = v[0]; + ptr[1] = v[1]; +} + +inline void vst1_bf16(uint16_t *ptr, 
float32x2_t v) { + vst1_bf16(reinterpret_cast(ptr), v); +} + +inline void vst1(BFloat16 *ptr, float32x2_t v) { + vst1_bf16(ptr, v); +} + +inline void vst1(uint16_t *ptr, float32x2_t v) { + vst1_bf16(reinterpret_cast(ptr), v); +} + +// load of 4D vector +inline float32x4_t vld1q_bf16(const uint16_t *ptr) { + return vreinterpretq_f32_u32(vshll_n_u16(vld1_u16(ptr), 16)); +} + +inline float32x4_t vld1q_bf16(const BFloat16 *ptr) { + return vld1q_bf16(reinterpret_cast(ptr)); +} + +inline float32x4_t vld1q(const uint16_t *ptr) { + return vld1q_bf16(ptr); +} + +inline float32x4_t vld1q(const BFloat16 *ptr) { + return vld1q_bf16(reinterpret_cast(ptr)); +} + +// load of 2 4D vectors and perform de-interleaving +inline float32x4x2_t vld2q_bf16(const uint16_t *ptr) { + uint16x4x2_t u = vld2_u16(ptr); + return {vreinterpretq_f32_u32(vshll_n_u16(u.val[0], 16)), + vreinterpretq_f32_u32(vshll_n_u16(u.val[1], 16))}; +} + +inline float32x4x2_t vld2q_bf16(const BFloat16 *ptr) { + return vld2q_bf16(reinterpret_cast(ptr)); +} + +inline float32x4x2_t vld2q(const uint16_t *ptr) { + return vld2q_bf16(ptr); +} + +inline float32x4x2_t vld2q(const BFloat16 *ptr) { + return vld2q_bf16(reinterpret_cast(ptr)); +} + +// load of 3 4D vectors and perform de-interleaving +inline float32x4x3_t vld3q_bf16(const uint16_t *ptr) { + uint16x4x3_t u = vld3_u16(ptr); + return {vreinterpretq_f32_u32(vshll_n_u16(u.val[0], 16)), + vreinterpretq_f32_u32(vshll_n_u16(u.val[1], 16)), + vreinterpretq_f32_u32(vshll_n_u16(u.val[2], 16))}; +} + +inline float32x4x3_t vld3q_bf16(const BFloat16 *ptr) { + return vld3q_bf16(reinterpret_cast(ptr)); +} + +inline float32x4x3_t vld3q(const uint16_t *ptr) { + return vld3q_bf16(ptr); +} + +inline float32x4x3_t vld3q(const BFloat16 *ptr) { + return vld3q_bf16(reinterpret_cast(ptr)); +} + +// store of 4D vector +inline void vst1q_bf16(uint16_t *ptr, const float32x4_t v) { + vst1_u16(ptr, vshrn_n_u32(vreinterpretq_u32_f32(v), 16)); +} + +inline void vst1q_bf16(BFloat16 *ptr, const float32x4_t v) { + vst1q_bf16(reinterpret_cast(ptr), v); +} + +inline void vst1q(uint16_t *ptr, const float32x4_t v) { + vst1q_bf16(ptr, v); +} + +inline void vst1q(BFloat16 *ptr, const float32x4_t v) { + vst1q_bf16(reinterpret_cast(ptr), v); +} + +// store of 2 4D vectors and perform interleaving +inline void vst2q_bf16(uint16_t *ptr, const float32x4x2_t v) { + uint16x4x2_t u = {vshrn_n_u32(vreinterpretq_u32_f32(v.val[0]), 16), + vshrn_n_u32(vreinterpretq_u32_f32(v.val[1]), 16)}; + vst2_u16(ptr, u); +} + +inline void vst2q_bf16(BFloat16 *ptr, const float32x4x2_t v) { + vst2q_bf16(reinterpret_cast(ptr), v); +} + +inline void vst2q(uint16_t *ptr, const float32x4x2_t v) { + vst2q_bf16(ptr, v); +} + +inline void vst2q(BFloat16 *ptr, const float32x4x2_t v) { + vst2q_bf16(reinterpret_cast(ptr), v); +} + +// store of 3 4D vectors and perform interleaving +inline void vst3q_bf16(uint16_t *ptr, const float32x4x3_t v) { + uint16x4x3_t u = {vshrn_n_u32(vreinterpretq_u32_f32(v.val[0]), 16), + vshrn_n_u32(vreinterpretq_u32_f32(v.val[0]), 16), + vshrn_n_u32(vreinterpretq_u32_f32(v.val[0]), 16)}; + vst3_u16(ptr, u); +} + +inline void vst3q_bf16(BFloat16 *ptr, const float32x4x3_t v) { + vst3q_bf16(reinterpret_cast(ptr), v); +} + +inline void vst3q(uint16_t *ptr, const float32x4x3_t v) { + vst3q_bf16(ptr, v); +} + +inline void vst3q(BFloat16 *ptr, const float32x4x3_t v) { + vst3q_bf16(reinterpret_cast(ptr), v); +} + +// load of 8D vector +inline float32x8_t vld1o_bf16(const uint16_t *ptr) { + uint16x8_t u = vld1q_u16(ptr); + return 
{vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(u), 16)), + vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(u), 16))}; +} + +inline float32x8_t vld1o_bf16(const BFloat16 *ptr) { + return vld1o_bf16(reinterpret_cast(ptr)); +} + +inline float32x8_t vld1o(const uint16_t *ptr) { + return vld1o_bf16(ptr); +} + +inline float32x8_t vld1o(const BFloat16 *ptr) { + return vld1o_bf16(reinterpret_cast(ptr)); +} + +// store of 8D vector +inline void vst1o_bf16(uint16_t *ptr, const float32x8_t v) { + vst1q_u16(ptr, vcombine_u16( + vshrn_n_u32(vreinterpretq_u32_f32(v.val[0]), 16), + vshrn_n_u32(vreinterpretq_u32_f32(v.val[1]), 16))); +} + +inline void vst1o_bf16(BFloat16 *ptr, const float32x8_t v) { + vst1o_bf16(reinterpret_cast(ptr), v); +} + +inline void vst1o(uint16_t *ptr, const float32x8_t v) { + vst1o_bf16(ptr, v); +} + +inline void vst1o(BFloat16 *ptr, const float32x8_t v) { + vst1o_bf16(reinterpret_cast(ptr), v); +} + +#endif // MACE_ENABLE_BFLOAT16 + +} // namespace arm +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_ARM_BASE_COMMON_NEON_H_ diff --git a/mace/ops/arm/base/conv_2d_1x1.cc b/mace/ops/arm/base/conv_2d_1x1.cc index 7fa96e8c..40ca524e 100644 --- a/mace/ops/arm/base/conv_2d_1x1.cc +++ b/mace/ops/arm/base/conv_2d_1x1.cc @@ -96,6 +96,11 @@ void RegisterConv2dK1x1Delegator(OpDelegatorRegistry *registry) { registry, Conv2dK1x1, delegator::Conv2dParam, MACE_DELEGATOR_KEY_EX(Conv2d, DeviceType::CPU, float, ImplType::NEON, K1x1)); + + MACE_REGISTER_BF16_DELEGATOR( + registry, Conv2dK1x1, delegator::Conv2dParam, + MACE_DELEGATOR_KEY_EX(Conv2d, DeviceType::CPU, + BFloat16, ImplType::NEON, K1x1)); } } // namespace arm diff --git a/mace/ops/arm/base/conv_2d_3x3.cc b/mace/ops/arm/base/conv_2d_3x3.cc index f2c02b3a..cf232a77 100644 --- a/mace/ops/arm/base/conv_2d_3x3.cc +++ b/mace/ops/arm/base/conv_2d_3x3.cc @@ -27,6 +27,15 @@ void RegisterConv2dK3x3Delegator(OpDelegatorRegistry *registry) { registry, Conv2dK3x3S2, delegator::Conv2dParam, MACE_DELEGATOR_KEY_EX(Conv2d, DeviceType::CPU, float, ImplType::NEON, K3x3S2)); + + MACE_REGISTER_BF16_DELEGATOR( + registry, Conv2dK3x3S1, delegator::Conv2dParam, + MACE_DELEGATOR_KEY_EX(Conv2d, DeviceType::CPU, + BFloat16, ImplType::NEON, K3x3S1)); + MACE_REGISTER_BF16_DELEGATOR( + registry, Conv2dK3x3S2, delegator::Conv2dParam, + MACE_DELEGATOR_KEY_EX(Conv2d, DeviceType::CPU, + BFloat16, ImplType::NEON, K3x3S2)); } } // namespace arm diff --git a/mace/ops/arm/fp32/conv_2d_3x3_winograd.cc b/mace/ops/arm/base/conv_2d_3x3_winograd.cc similarity index 83% rename from mace/ops/arm/fp32/conv_2d_3x3_winograd.cc rename to mace/ops/arm/base/conv_2d_3x3_winograd.cc index 051d5587..a66d0d62 100644 --- a/mace/ops/arm/fp32/conv_2d_3x3_winograd.cc +++ b/mace/ops/arm/base/conv_2d_3x3_winograd.cc @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. 
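The BFloat16 wrappers in common_neon.h above all rely on the same idea: a bf16 value is the high 16 bits of an IEEE-754 float, so loads widen with a 16-bit left shift (vshll_n_u16(..., 16)) and stores narrow with a truncating 16-bit right shift (vshrn_n_u32(..., 16)). A minimal scalar sketch of that conversion, for illustration only (these helper names are not taken from the MACE sources):

#include <cstdint>
#include <cstring>

// Widen: place the 16 bf16 bits into the top half of a 32-bit word.
inline float bf16_to_f32(uint16_t b) {
  uint32_t u = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

// Narrow: keep only the top 16 bits (truncation, no rounding),
// matching vshrn_n_u32(..., 16) in the vector path.
inline uint16_t f32_to_bf16(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<uint16_t>(u >> 16);
}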
-#include "mace/ops/arm/fp32/conv_2d_3x3_winograd.h" +#include "mace/ops/arm/base/conv_2d_3x3_winograd.h" #include +#include "mace/ops/arm/base/common_neon.h" #include "mace/ops/common/conv_pool_2d_util.h" #include "mace/ops/delegator/conv_2d.h" #include "mace/utils/math.h" @@ -24,12 +25,12 @@ namespace mace { namespace ops { namespace arm { -namespace fp32 { -MaceStatus Conv2dK3x3Winograd::Compute(const OpContext *context, - const Tensor *input, - const Tensor *filter, - Tensor *output) { +template +MaceStatus Conv2dK3x3Winograd::Compute(const OpContext *context, + const Tensor *input, + const Tensor *filter, + Tensor *output) { const index_t batch = input->dim(0); const index_t in_channels = input->dim(1); const index_t in_height = input->dim(2); @@ -84,17 +85,17 @@ MaceStatus Conv2dK3x3Winograd::Compute(const OpContext *context, // pad input and transform input auto scratch_buffer = context->device()->scratch_buffer(); const index_t padded_in_size = is_in_padded ? PadAlignSize( - sizeof(float) * batch * in_channels * padded_in_height + sizeof(T) * batch * in_channels * padded_in_height * padded_in_width) : 0; const index_t padded_out_size = is_out_padded ? PadAlignSize( - sizeof(float) * batch * out_channels * padded_out_height + sizeof(T) * batch * out_channels * padded_out_height * padded_out_width) : 0; const index_t transformed_in_size = PadAlignSize( - sizeof(float) * batch * in_tile_area * in_channels * tile_count); + sizeof(T) * batch * in_tile_area * in_channels * tile_count); const index_t transformed_out_size = PadAlignSize( - sizeof(float) * batch * in_tile_area * out_channels * tile_count); + sizeof(T) * batch * in_tile_area * out_channels * tile_count); const index_t transformed_filter_size = - PadAlignSize(sizeof(float) * in_tile_area * out_channels * in_channels); + PadAlignSize(sizeof(T) * in_tile_area * out_channels * in_channels); const index_t gemm_pack_size = transformed_in_size + transformed_filter_size + transformed_filter_size; @@ -104,8 +105,8 @@ MaceStatus Conv2dK3x3Winograd::Compute(const OpContext *context, + transformed_out_size + gemm_pack_size); const Tensor *padded_in = input; - Tensor tmp_padded_in - (scratch_buffer->Scratch(padded_in_size), DataType::DT_FLOAT); + Tensor tmp_padded_in(scratch_buffer->Scratch(padded_in_size), + DataTypeToEnum::value); if (is_in_padded) { tmp_padded_in.Resize({batch, in_channels, padded_in_height, padded_in_width}); @@ -115,8 +116,8 @@ MaceStatus Conv2dK3x3Winograd::Compute(const OpContext *context, } Tensor *padded_out = output; - Tensor tmp_padded_out - (scratch_buffer->Scratch(padded_out_size), DataType::DT_FLOAT); + Tensor tmp_padded_out(scratch_buffer->Scratch(padded_out_size), + DataTypeToEnum::value); if (is_out_padded) { padded_out = &tmp_padded_out; padded_out->Resize({batch, out_channels, padded_out_height, @@ -125,17 +126,17 @@ MaceStatus Conv2dK3x3Winograd::Compute(const OpContext *context, auto transformed_in = scratch_buffer->Scratch(transformed_in_size); auto transformed_out = scratch_buffer->Scratch(transformed_out_size); - auto padded_in_data = padded_in->data(); - auto padded_out_data = padded_out->mutable_data(); - auto transformed_in_data = transformed_in.mutable_data(); - auto transformed_out_data = transformed_out.mutable_data(); - auto filter_data = filter->data(); + auto padded_in_data = padded_in->data(); + auto padded_out_data = padded_out->mutable_data(); + auto transformed_in_data = transformed_in.mutable_data(); + auto transformed_out_data = transformed_out.mutable_data(); + auto filter_data = 
filter->data(); if (!filter->is_weight() || out_tile_size != out_tile_size_) { out_tile_size_ = out_tile_size; transformed_filter_.reset(new Tensor); transformed_filter_->Resize({in_tile_area, out_channels, in_channels}); - auto transformed_filter_data = transformed_filter_->mutable_data(); + auto transformed_filter_data = transformed_filter_->mutable_data(); switch (out_tile_size) { case 2: TransformFilter4x4(context, @@ -181,9 +182,9 @@ MaceStatus Conv2dK3x3Winograd::Compute(const OpContext *context, const index_t scratch_buffer_offset = scratch_buffer->offset(); const index_t transformed_in_size_per_batch = - in_tile_area * in_channels * tile_count * sizeof(float); + in_tile_area * in_channels * tile_count * sizeof(T); const index_t transformed_out_size_per_batch = - in_tile_area * out_channels * tile_count * sizeof(float); + in_tile_area * out_channels * tile_count * sizeof(T); for (index_t b = 0; b < batch; ++b) { scratch_buffer->Rewind(scratch_buffer_offset); @@ -194,10 +195,11 @@ MaceStatus Conv2dK3x3Winograd::Compute(const OpContext *context, b * transformed_out_size_per_batch, transformed_out_size_per_batch); - Tensor transformed_in_this_batch(transformed_in_slice, DataType::DT_FLOAT); + Tensor transformed_in_this_batch(transformed_in_slice, + DataTypeToEnum::value); transformed_in_this_batch.Resize({in_tile_area, in_channels, tile_count}); - Tensor - transformed_out_this_batch(transformed_out_slice, DataType::DT_FLOAT); + Tensor transformed_out_this_batch(transformed_out_slice, + DataTypeToEnum::value); transformed_out_this_batch.Resize({in_tile_area, out_channels, tile_count}); gemm_.Compute(context, @@ -246,11 +248,12 @@ MaceStatus Conv2dK3x3Winograd::Compute(const OpContext *context, } // OCHW => TOC -void Conv2dK3x3Winograd::TransformFilter4x4(const OpContext *context, - const float *filter, - const index_t in_channels, - const index_t out_channels, - float *output) { +template +void Conv2dK3x3Winograd::TransformFilter4x4(const OpContext *context, + const T *filter, + const index_t in_channels, + const index_t out_channels, + T *output) { const index_t stride = out_channels * in_channels; utils::ThreadPool @@ -339,11 +342,12 @@ void Conv2dK3x3Winograd::TransformFilter4x4(const OpContext *context, ⎢ ⎥ ⎣ 0 0 1 ⎦ */ -void Conv2dK3x3Winograd::TransformFilter8x8(const OpContext *context, - const float *filter, - const index_t in_channels, - const index_t out_channels, - float *output) { +template +void Conv2dK3x3Winograd::TransformFilter8x8(const OpContext *context, + const T *filter, + const index_t in_channels, + const index_t out_channels, + T *output) { const index_t stride = out_channels * in_channels; const float G[8][3] = {{1.0f, 0.0f, 0.0f}, @@ -396,14 +400,15 @@ void Conv2dK3x3Winograd::TransformFilter8x8(const OpContext *context, } // NCHW => NTCB (T: in tile pixels, B: tile indices) -void Conv2dK3x3Winograd::TransformInput4x4(const OpContext *context, - const float *input, - const index_t batch, - const index_t in_height, - const index_t in_width, - const index_t in_channels, - const index_t tile_count, - float *output) { +template +void Conv2dK3x3Winograd::TransformInput4x4(const OpContext *context, + const T *input, + const index_t batch, + const index_t in_height, + const index_t in_width, + const index_t in_channels, + const index_t tile_count, + T *output) { const index_t stride = in_channels * tile_count; const index_t in_height_width = in_height * in_width; const index_t input_batch_size = in_height_width * in_channels; @@ -420,14 +425,12 @@ void 
Conv2dK3x3Winograd::TransformInput4x4(const OpContext *context, for (index_t h = 0; h < in_height - 2; h += 2) { for (index_t w = 0; w < in_width - 2; w += 2) { float d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, - d14, - d15; + d14, d15; float s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, - s14, - s15; + s14, s15; // load tile data - const float *input_ptr = input + n * input_batch_size + + const T *input_ptr = input + n * input_batch_size + c * in_height_width + h * in_width + w; d0 = input_ptr[0]; d1 = input_ptr[1]; @@ -468,7 +471,7 @@ void Conv2dK3x3Winograd::TransformInput4x4(const OpContext *context, s15 = (d5 - d13) - (d7 - d15); // store output - float *output_ptr = + T *output_ptr = output + n * output_batch_size + c * tile_count + tile_index; output_ptr[0] = s0; output_ptr[1 * stride] = s1; @@ -517,14 +520,15 @@ void Conv2dK3x3Winograd::TransformInput4x4(const OpContext *context, ⎢ ⎥ ⎣0 -1 0 21/4 0 -21/4 0 1⎦ */ -void Conv2dK3x3Winograd::TransformInput8x8(const OpContext *context, - const float *input, - const index_t batch, - const index_t in_height, - const index_t in_width, - const index_t in_channels, - const index_t tile_count, - float *output) { +template +void Conv2dK3x3Winograd::TransformInput8x8(const OpContext *context, + const T *input, + const index_t batch, + const index_t in_height, + const index_t in_width, + const index_t in_channels, + const index_t tile_count, + T *output) { const index_t stride = in_channels * tile_count; const index_t in_height_width = in_height * in_width; const index_t input_batch_size = in_height_width * in_channels; @@ -540,7 +544,7 @@ void Conv2dK3x3Winograd::TransformInput8x8(const OpContext *context, float s[8][8]; for (index_t h = 0; h < in_height - 2; h += 6) { for (index_t w = 0; w < in_width - 2; w += 6) { - const float *input_ptr = input + n * input_batch_size + + const T *input_ptr = input + n * input_batch_size + c * in_height_width + h * in_width + w; for (int i = 0; i < 8; ++i) { @@ -575,7 +579,7 @@ void Conv2dK3x3Winograd::TransformInput8x8(const OpContext *context, input_ptr += in_width; } - float *output_ptr = + T *output_ptr = output + n * output_batch_size + c * tile_count + tile_index; for (int i = 0; i < 8; ++i) { float d0, d1, d2, d3, d4, d5, d6, d7; @@ -616,14 +620,15 @@ void Conv2dK3x3Winograd::TransformInput8x8(const OpContext *context, } // NTOB => NToOB => NOHoWo -void Conv2dK3x3Winograd::TransformOutput4x4(const OpContext *context, - const float *input, - index_t batch, - index_t out_height, - index_t out_width, - index_t out_channels, - index_t tile_count, - float *output) { +template +void Conv2dK3x3Winograd::TransformOutput4x4(const OpContext *context, + const T *input, + index_t batch, + index_t out_height, + index_t out_width, + index_t out_channels, + index_t tile_count, + T *output) { const index_t stride = out_channels * tile_count; const index_t input_batch_size = 16 * stride; const index_t out_image_size = out_height * out_width; @@ -644,7 +649,7 @@ void Conv2dK3x3Winograd::TransformOutput4x4(const OpContext *context, float s0, s1, s2, s3, s4, s5, s6, s7; float v0, v1, v2, v3; - const float *input_ptr = + const T *input_ptr = input + n * input_batch_size + m * tile_count + tile_offset; d0 = input_ptr[0]; d1 = input_ptr[1 * stride]; @@ -680,7 +685,7 @@ void Conv2dK3x3Winograd::TransformOutput4x4(const OpContext *context, v2 = s2 - s4 - s6; v3 = s3 - s5 - s7; - float *output_ptr = output + n * output_batch_size + + T *output_ptr = output + n * output_batch_size + m * 
out_image_size + h * out_width + w; output_ptr[0] = v0; output_ptr[1] = v1; @@ -710,14 +715,15 @@ void Conv2dK3x3Winograd::TransformOutput4x4(const OpContext *context, ⎢ ⎥ ⎣0 1 -1 32 -32 1 -1 1⎦ */ -void Conv2dK3x3Winograd::TransformOutput8x8(const OpContext *context, - const float *input, - index_t batch, - index_t out_height, - index_t out_width, - index_t out_channels, - index_t tile_count, - float *output) { +template +void Conv2dK3x3Winograd::TransformOutput8x8(const OpContext *context, + const T *input, + index_t batch, + index_t out_height, + index_t out_width, + index_t out_channels, + index_t tile_count, + T *output) { const index_t stride = out_channels * tile_count; const index_t input_batch_size = 64 * stride; const index_t out_image_size = out_height * out_width; @@ -733,7 +739,7 @@ void Conv2dK3x3Winograd::TransformOutput8x8(const OpContext *context, float s[8][6]; for (index_t h = 0; h < out_height; h += 6) { for (index_t w = 0; w < out_width; w += 6) { - const float *input_ptr = + const T *input_ptr = input + n * input_batch_size + m * tile_count + tile_offset; for (int i = 0; i < 8; ++i) { float d0, d1, d2, d3, d4, d5, d6, d7; @@ -764,7 +770,7 @@ void Conv2dK3x3Winograd::TransformOutput8x8(const OpContext *context, input_ptr += 8 * stride; } - float *output_ptr = output + n * output_batch_size + + T *output_ptr = output + n * output_batch_size + m * out_image_size + h * out_width + w; for (int i = 0; i < 6; ++i) { @@ -803,12 +809,16 @@ void Conv2dK3x3Winograd::TransformOutput8x8(const OpContext *context, void RegisterConv2dK3x3WinogradDelegator(OpDelegatorRegistry *registry) { MACE_REGISTER_DELEGATOR( - registry, Conv2dK3x3Winograd, delegator::Conv2dParam, + registry, Conv2dK3x3Winograd, delegator::Conv2dParam, MACE_DELEGATOR_KEY_EX(Conv2d, DeviceType::CPU, float, ImplType::NEON, K3x3Winograd)); + + MACE_REGISTER_BF16_DELEGATOR( + registry, Conv2dK3x3Winograd, delegator::Conv2dParam, + MACE_DELEGATOR_KEY_EX(Conv2d, DeviceType::CPU, + BFloat16, ImplType::NEON, K3x3Winograd)); } -} // namespace fp32 } // namespace arm } // namespace ops } // namespace mace diff --git a/mace/ops/arm/fp32/conv_2d_3x3_winograd.h b/mace/ops/arm/base/conv_2d_3x3_winograd.h similarity index 78% rename from mace/ops/arm/fp32/conv_2d_3x3_winograd.h rename to mace/ops/arm/base/conv_2d_3x3_winograd.h index 513cc99d..f06d8d4d 100644 --- a/mace/ops/arm/fp32/conv_2d_3x3_winograd.h +++ b/mace/ops/arm/base/conv_2d_3x3_winograd.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
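For orientation, the transforms templated above follow the standard Winograd convolution F(m x m, 3x3), with m = 2 for the 4x4-tile path and m = 6 for the 8x8-tile path: TransformFilter computes G g G^T, TransformInput computes B^T d B per input tile, the batched GEMM over channels realizes the element-wise tile product with channel accumulation, and TransformOutput applies the inverse transform. As a standard identity (not quoted from the patch):

Y = A^{T} \left[ (G\, g\, G^{T}) \odot (B^{T} d\, B) \right] A

where d is an input tile, g a 3x3 filter, and the element-wise product \odot is carried out here as one GEMM per tile position, as the {in_tile_area, out_channels, in_channels} x {in_tile_area, in_channels, tile_count} shapes above suggest.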
-#ifndef MACE_OPS_ARM_FP32_CONV_2D_3X3_WINOGRAD_H_ -#define MACE_OPS_ARM_FP32_CONV_2D_3X3_WINOGRAD_H_ +#ifndef MACE_OPS_ARM_BASE_CONV_2D_3X3_WINOGRAD_H_ +#define MACE_OPS_ARM_BASE_CONV_2D_3X3_WINOGRAD_H_ #include #include @@ -27,12 +27,12 @@ namespace mace { namespace ops { namespace arm { -namespace fp32 { +template class Conv2dK3x3Winograd : public Conv2dBase { public: explicit Conv2dK3x3Winograd(const delegator::Conv2dParam ¶m) - : Conv2dBase(param, sizeof(float)), + : Conv2dBase(param, sizeof(T)), gemm_(delegator::GemmParam()), transformed_filter_(nullptr), out_tile_size_(0) {} @@ -47,61 +47,60 @@ class Conv2dK3x3Winograd : public Conv2dBase { private: void TransformFilter4x4(const OpContext *context, - const float *filter, + const T *filter, const index_t in_channels, const index_t out_channels, - float *output); + T *output); void TransformFilter8x8(const OpContext *context, - const float *filter, + const T *filter, const index_t in_channels, const index_t out_channels, - float *output); + T *output); void TransformInput4x4(const OpContext *context, - const float *input, + const T *input, const index_t batch, const index_t in_height, const index_t in_width, const index_t in_channels, const index_t tile_count, - float *output); + T *output); void TransformInput8x8(const OpContext *context, - const float *input, + const T *input, const index_t batch, const index_t in_height, const index_t in_width, const index_t in_channels, const index_t tile_count, - float *output); + T *output); void TransformOutput4x4(const OpContext *context, - const float *input, + const T *input, index_t batch, index_t out_height, index_t out_width, index_t out_channels, index_t tile_count, - float *output); + T *output); void TransformOutput8x8(const OpContext *context, - const float *input, + const T *input, index_t batch, index_t out_height, index_t out_width, index_t out_channels, index_t tile_count, - float *output); + T *output); - Gemm gemm_; + Gemm gemm_; std::unique_ptr transformed_filter_; index_t out_tile_size_; }; -} // namespace fp32 } // namespace arm } // namespace ops } // namespace mace -#endif // MACE_OPS_ARM_FP32_CONV_2D_3X3_WINOGRAD_H_ +#endif // MACE_OPS_ARM_BASE_CONV_2D_3X3_WINOGRAD_H_ diff --git a/mace/ops/arm/base/conv_2d_general.cc b/mace/ops/arm/base/conv_2d_general.cc index 04121b8c..52fa6aa9 100644 --- a/mace/ops/arm/base/conv_2d_general.cc +++ b/mace/ops/arm/base/conv_2d_general.cc @@ -16,6 +16,8 @@ #include +#include "mace/ops/arm/base/common_neon.h" + namespace mace { namespace ops { namespace arm { @@ -57,10 +59,155 @@ MaceStatus Conv2dGeneral::Compute(const OpContext *context, return MaceStatus::MACE_SUCCESS; } +template +MaceStatus Conv2dGeneral::DoCompute( + const ConvComputeParam &p, const T *filter_data, + const T *input_data, T *output_data, + const std::vector &filter_shape) { + const index_t filter_height = filter_shape[2]; + const index_t filter_width = filter_shape[3]; + const index_t filter_size = filter_height * filter_width; + + p.thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0, + index_t start1, index_t end1, index_t step1) { + for (index_t b = start0; b < end0; b += step0) { + for (index_t m = start1; m < end1; m += step1) { + const int stride_h = strides_[0]; + const int stride_w = strides_[1]; + const int dilation_h = dilations_[0]; + const int dilation_w = dilations_[1]; + if (m + 3 < p.out_channels) { + T *out_ptr0_base = + output_data + b * p.out_batch_size + m * p.out_image_size; + T *out_ptr1_base = out_ptr0_base + p.out_image_size; + T 
*out_ptr2_base = out_ptr1_base + p.out_image_size; + T *out_ptr3_base = out_ptr2_base + p.out_image_size; + for (index_t h = 0; h < p.out_height; ++h) { + const index_t ih = h * stride_h; + for (index_t w = 0; w + 3 < p.out_width; w += 4) { + const index_t iw = w * stride_w; + index_t out_offset = h * p.out_width + w; + float32x4_t vo0 = vdupq_n_f32(0.f); + float32x4_t vo1 = vdupq_n_f32(0.f); + float32x4_t vo2 = vdupq_n_f32(0.f); + float32x4_t vo3 = vdupq_n_f32(0.f); + const T *in_ptr_base = input_data + b * p.in_batch_size; + const T *filter_ptr0 = + filter_data + m * p.in_channels * filter_size; + const T *filter_ptr1 = filter_ptr0 + p.in_channels * filter_size; + const T *filter_ptr2 = filter_ptr1 + p.in_channels * filter_size; + const T *filter_ptr3 = filter_ptr2 + p.in_channels * filter_size; + for (index_t c = 0; c < p.in_channels; ++c) { + index_t in_offset = ih * p.in_width + iw; + // calc by row + for (index_t kh = 0; kh < filter_height; ++kh) { + for (index_t kw = 0; kw < filter_width; ++kw) { + const T i0 = in_ptr_base[in_offset + kw * dilation_w]; + const T i1 = + in_ptr_base[in_offset + stride_w + kw * dilation_w]; + const T i2 = + in_ptr_base[in_offset + 2 * stride_w + kw * dilation_w]; + const T i3 = + in_ptr_base[in_offset + 3 * stride_w + kw * dilation_w]; + const T f0 = filter_ptr0[kw]; + const T f1 = filter_ptr1[kw]; + const T f2 = filter_ptr2[kw]; + const T f3 = filter_ptr3[kw]; + // outch 0 + vo0[0] += i0 * f0; + vo0[1] += i1 * f0; + vo0[2] += i2 * f0; + vo0[3] += i3 * f0; + // outch 1 + vo1[0] += i0 * f1; + vo1[1] += i1 * f1; + vo1[2] += i2 * f1; + vo1[3] += i3 * f1; + // outch 2 + vo2[0] += i0 * f2; + vo2[1] += i1 * f2; + vo2[2] += i2 * f2; + vo2[3] += i3 * f2; + // outch 3 + vo3[0] += i0 * f3; + vo3[1] += i1 * f3; + vo3[2] += i2 * f3; + vo3[3] += i3 * f3; + } // kw + + in_offset += dilation_h * p.in_width; + filter_ptr0 += filter_width; + filter_ptr1 += filter_width; + filter_ptr2 += filter_width; + filter_ptr3 += filter_width; + } // kh + in_ptr_base += p.in_image_size; + } // c + vst1q(out_ptr0_base + out_offset, vo0); + vst1q(out_ptr1_base + out_offset, vo1); + vst1q(out_ptr2_base + out_offset, vo2); + vst1q(out_ptr3_base + out_offset, vo3); + } // w + } // h + } else { + for (index_t mm = m; mm < p.out_channels; ++mm) { + T *out_ptr0_base = + output_data + b * p.out_batch_size + mm * p.out_image_size; + for (index_t h = 0; h < p.out_height; ++h) { + for (index_t w = 0; w + 3 < p.out_width; w += 4) { + // input offset + const index_t ih = h * stride_h; + const index_t iw = w * stride_w; + // output offset + const index_t out_offset = h * p.out_width + w; + // output (1 outch x 1 height x 4 width): vo_outch_height + float32x4_t vo0 = vdupq_n_f32(0.f); + const T *in_ptr_base = input_data + b * p.in_batch_size; + const T *filter_ptr0 = + filter_data + mm * p.in_channels * filter_size; + for (index_t c = 0; c < p.in_channels; ++c) { + index_t in_offset = ih * p.in_width + iw; + for (index_t kh = 0; kh < filter_height; ++kh) { + for (index_t kw = 0; kw < filter_width; ++kw) { + T i0 = in_ptr_base[in_offset + kw * dilation_w]; + T i1 = in_ptr_base[in_offset + stride_w + + kw * dilation_w]; + T i2 = in_ptr_base[in_offset + 2 * stride_w + + kw * dilation_w]; + T i3 = in_ptr_base[in_offset + 3 * stride_w + + kw * dilation_w]; + T f0 = filter_ptr0[kw]; + // outch 0 + vo0[0] += i0 * f0; + vo0[1] += i1 * f0; + vo0[2] += i2 * f0; + vo0[3] += i3 * f0; + } // kw + in_offset += dilation_h * p.in_width; + filter_ptr0 += filter_width; + } // kh + in_ptr_base += p.in_image_size; + 
} // c + vst1q(out_ptr0_base + out_offset, vo0); + } // w + } // h + } // mm + } // if + } // m + } // b + }, 0, p.batch, 1, 0, p.out_channels, 4); + + return MaceStatus::MACE_SUCCESS; +} + void RegisterConv2dGeneralDelegator(OpDelegatorRegistry *registry) { MACE_REGISTER_DELEGATOR( registry, Conv2dGeneral, delegator::Conv2dParam, MACE_DELEGATOR_KEY(Conv2d, DeviceType::CPU, float, ImplType::NEON)); + + MACE_REGISTER_BF16_DELEGATOR( + registry, Conv2dGeneral, delegator::Conv2dParam, + MACE_DELEGATOR_KEY(Conv2d, DeviceType::CPU, BFloat16, ImplType::NEON)); } } // namespace arm diff --git a/mace/ops/arm/base/depthwise_conv_2d_3x3.cc b/mace/ops/arm/base/depthwise_conv_2d_3x3.cc index f9423959..e17245e6 100644 --- a/mace/ops/arm/base/depthwise_conv_2d_3x3.cc +++ b/mace/ops/arm/base/depthwise_conv_2d_3x3.cc @@ -14,10 +14,436 @@ #include "mace/ops/arm/base/depthwise_conv_2d_3x3.h" +#include "mace/ops/arm/base/common_neon.h" + namespace mace { namespace ops { namespace arm { +namespace { +template +void DepthwiseConv2d3x3Pixel(const T *in_base, + const T *filter, + const index_t out_h, + const index_t out_w, + const index_t in_h_start, + const index_t in_w_start, + const index_t out_width, + const index_t in_height, + const index_t in_width, + T *out_base) { + const index_t filter_width = 3; + float sum = 0.0f; + + index_t in_h = in_h_start; + const T *in = in_base + in_h * in_width; + const T *filter_ptr = filter; + if (in_h >= 0 && in_h < in_height) { + index_t in_w = in_w_start; + if (in_w >= 0 && in_w < in_width) { + sum += in[in_w] * filter_ptr[0]; + } + in_w++; + if (in_w >= 0 && in_w < in_width) { + sum += in[in_w] * filter_ptr[1]; + } + in_w++; + if (in_w >= 0 && in_w < in_width) { + sum += in[in_w] * filter_ptr[2]; + } + } + in_h++; + in += in_width; + filter_ptr += filter_width; + if (in_h >= 0 && in_h < in_height) { + index_t in_w = in_w_start; + if (in_w >= 0 && in_w < in_width) { + sum += in[in_w] * filter_ptr[0]; + } + in_w++; + if (in_w >= 0 && in_w < in_width) { + sum += in[in_w] * filter_ptr[1]; + } + in_w++; + if (in_w >= 0 && in_w < in_width) { + sum += in[in_w] * filter_ptr[2]; + } + } + in_h++; + in += in_width; + filter_ptr += filter_width; + if (in_h >= 0 && in_h < in_height) { + index_t in_w = in_w_start; + if (in_w >= 0 && in_w < in_width) { + sum += in[in_w] * filter_ptr[0]; + } + in_w++; + if (in_w >= 0 && in_w < in_width) { + sum += in[in_w] * filter_ptr[1]; + } + in_w++; + if (in_w >= 0 && in_w < in_width) { + sum += in[in_w] * filter_ptr[2]; + } + } + out_base[out_h * out_width + out_w] = sum; +} +} // namespace + +template +MaceStatus DepthwiseConv2dK3x3S1::DoCompute( + const DepthwiseConvComputeParam &p, const T *filter_data, + const T *input_data, T *output_data) { + p.thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0, + index_t start1, index_t end1, index_t step1) { + for (index_t b = start0; b < end0; b += step0) { + for (index_t m = start1; m < end1; m += step1) { + const index_t c = m / p.multiplier; + const index_t multi_index = m % p.multiplier; + auto filter_ptr = filter_data + multi_index * p.in_channels * 9 + c * 9; + auto in_base = input_data + b * p.in_batch_size + c * p.in_image_size; + auto out_base = output_data + b * p.out_batch_size + + m * p.out_image_size; + index_t h, w; + + // top + for (h = 0; h < p.valid_h_start; ++h) { + for (w = 0; w < p.out_width; ++w) { + DepthwiseConv2d3x3Pixel(in_base, + filter_ptr, + h, + w, + h - p.pad_top, + w - p.pad_left, + p.out_width, + p.in_height, + p.in_width, + out_base); + } + } + + // 
load filter (1 outch x 3 height x 3 width): vf_outch_height + float32x4_t vf00, vf01, vf02; + vf00 = vld1q(filter_ptr); + vf01 = vld1q(filter_ptr + 3); + vf02 = vld1q(filter_ptr + 5); + + for (h = p.valid_h_start; h + 1 < p.valid_h_stop; h += 2) { + // left + for (w = 0; w < p.valid_w_start; ++w) { + DepthwiseConv2d3x3Pixel(in_base, + filter_ptr, + h, + w, + h - p.pad_top, + w - p.pad_left, + p.out_width, + p.in_height, + p.in_width, + out_base); + DepthwiseConv2d3x3Pixel(in_base, + filter_ptr, + h + 1, + w, + h + 1 - p.pad_top, + w - p.pad_left, + p.out_width, + p.in_height, + p.in_width, + out_base); + } + + for (w = p.valid_w_start; w + 3 < p.valid_w_stop; w += 4) { + // input (4 height x 3 slide): vi_height_slide + float32x4_t vi00, vi01, vi02, vi0n; + float32x4_t vi10, vi11, vi12, vi1n; + float32x4_t vi20, vi21, vi22, vi2n; + float32x4_t vi30, vi31, vi32, vi3n; + + // output (1 outch x 2 height x 4 width): vo_outch_height + float32x4_t vo00, vo01; + + // load input + index_t in_h = h - p.pad_top; + index_t in_w = w - p.pad_left; + index_t in_offset = in_h * p.in_width + in_w; + vi00 = vld1q(in_base + in_offset); + vi0n = vld1q(in_base + in_offset + 4); + vi10 = vld1q(in_base + in_offset + p.in_width); + vi1n = vld1q(in_base + in_offset + p.in_width + 4); + vi20 = vld1q(in_base + in_offset + 2 * p.in_width); + vi2n = vld1q(in_base + in_offset + 2 * p.in_width + 4); + vi30 = vld1q(in_base + in_offset + 3 * p.in_width); + vi3n = vld1q(in_base + in_offset + 3 * p.in_width + 4); + + vi01 = vextq_f32(vi00, vi0n, 1); + vi02 = vextq_f32(vi00, vi0n, 2); + vi11 = vextq_f32(vi10, vi1n, 1); + vi12 = vextq_f32(vi10, vi1n, 2); + vi21 = vextq_f32(vi20, vi2n, 1); + vi22 = vextq_f32(vi20, vi2n, 2); + vi31 = vextq_f32(vi30, vi3n, 1); + vi32 = vextq_f32(vi30, vi3n, 2); + + // load ouptut + index_t out_offset = h * p.out_width + w; + vo00 = vld1q(out_base + out_offset); + vo01 = vld1q(out_base + out_offset + p.out_width); + +#if defined(__aarch64__) + // outch 0, height 0 + vo00 = vfmaq_laneq_f32(vo00, vi00, vf00, 0); + vo00 = vfmaq_laneq_f32(vo00, vi01, vf00, 1); + vo00 = vfmaq_laneq_f32(vo00, vi02, vf00, 2); + vo00 = vfmaq_laneq_f32(vo00, vi10, vf01, 0); + vo00 = vfmaq_laneq_f32(vo00, vi11, vf01, 1); + vo00 = vfmaq_laneq_f32(vo00, vi12, vf01, 2); + vo00 = vfmaq_laneq_f32(vo00, vi20, vf02, 1); + vo00 = vfmaq_laneq_f32(vo00, vi21, vf02, 2); + vo00 = vfmaq_laneq_f32(vo00, vi22, vf02, 3); + + // outch 0, height 1 + vo01 = vfmaq_laneq_f32(vo01, vi10, vf00, 0); + vo01 = vfmaq_laneq_f32(vo01, vi11, vf00, 1); + vo01 = vfmaq_laneq_f32(vo01, vi12, vf00, 2); + vo01 = vfmaq_laneq_f32(vo01, vi20, vf01, 0); + vo01 = vfmaq_laneq_f32(vo01, vi21, vf01, 1); + vo01 = vfmaq_laneq_f32(vo01, vi22, vf01, 2); + vo01 = vfmaq_laneq_f32(vo01, vi30, vf02, 1); + vo01 = vfmaq_laneq_f32(vo01, vi31, vf02, 2); + vo01 = vfmaq_laneq_f32(vo01, vi32, vf02, 3); +#else + // outch 0, height 0 + vo00 = vmlaq_lane_f32(vo00, vi00, vget_low_f32(vf00), 0); + vo00 = vmlaq_lane_f32(vo00, vi01, vget_low_f32(vf00), 1); + vo00 = vmlaq_lane_f32(vo00, vi02, vget_high_f32(vf00), 0); + vo00 = vmlaq_lane_f32(vo00, vi10, vget_low_f32(vf01), 0); + vo00 = vmlaq_lane_f32(vo00, vi11, vget_low_f32(vf01), 1); + vo00 = vmlaq_lane_f32(vo00, vi12, vget_high_f32(vf01), 0); + vo00 = vmlaq_lane_f32(vo00, vi20, vget_low_f32(vf02), 1); + vo00 = vmlaq_lane_f32(vo00, vi21, vget_high_f32(vf02), 0); + vo00 = vmlaq_lane_f32(vo00, vi22, vget_high_f32(vf02), 1); + + // outch 0, height 1 + vo01 = vmlaq_lane_f32(vo01, vi10, vget_low_f32(vf00), 0); + vo01 = vmlaq_lane_f32(vo01, 
vi11, vget_low_f32(vf00), 1); + vo01 = vmlaq_lane_f32(vo01, vi12, vget_high_f32(vf00), 0); + vo01 = vmlaq_lane_f32(vo01, vi20, vget_low_f32(vf01), 0); + vo01 = vmlaq_lane_f32(vo01, vi21, vget_low_f32(vf01), 1); + vo01 = vmlaq_lane_f32(vo01, vi22, vget_high_f32(vf01), 0); + vo01 = vmlaq_lane_f32(vo01, vi30, vget_low_f32(vf02), 1); + vo01 = vmlaq_lane_f32(vo01, vi31, vget_high_f32(vf02), 0); + vo01 = vmlaq_lane_f32(vo01, vi32, vget_high_f32(vf02), 1); +#endif + vst1q(out_base + out_offset, vo00); + vst1q(out_base + out_offset + p.out_width, vo01); + } // w + + // right + for (; w < p.out_width; ++w) { + DepthwiseConv2d3x3Pixel(in_base, + filter_ptr, + h, + w, + h - p.pad_top, + w - p.pad_left, + p.out_width, + p.in_height, + p.in_width, + out_base); + DepthwiseConv2d3x3Pixel(in_base, + filter_ptr, + h + 1, + w, + h + 1 - p.pad_top, + w - p.pad_left, + p.out_width, + p.in_height, + p.in_width, + out_base); + } + } // h + + // bottom + for (; h < p.out_height; ++h) { + for (w = 0; w < p.out_width; ++w) { + DepthwiseConv2d3x3Pixel(in_base, + filter_ptr, + h, + w, + h - p.pad_top, + w - p.pad_left, + p.out_width, + p.in_height, + p.in_width, + out_base); + } + } + } // m + } // b + }, 0, p.batch, 1, 0, p.out_channels, 1); // threadpool + + return MaceStatus::MACE_SUCCESS; +} + +template +MaceStatus DepthwiseConv2dK3x3S2::DoCompute( + const DepthwiseConvComputeParam &p, const T *filter_data, + const T *input_data, T *output_data) { + p.thread_pool.Compute2D( + [=](index_t start0, index_t end0, index_t step0, index_t start1, + index_t end1, index_t step1) { + for (index_t b = start0; b < end0; b += step0) { + for (index_t m = start1; m < end1; m += step1) { + index_t c = m / p.multiplier; + index_t multi_index = m % p.multiplier; + auto filter_ptr = filter_data + multi_index * p.in_channels * 9 + + c * 9; + auto in_base = input_data + b * p.in_batch_size + + c * p.in_image_size; + auto out_base = output_data + b * p.out_batch_size + + m * p.out_image_size; + index_t h, w; + + // top + for (h = 0; h < p.valid_h_start; ++h) { + for (w = 0; w < p.out_width; ++w) { + DepthwiseConv2d3x3Pixel(in_base, + filter_ptr, + h, + w, + h * 2 - p.pad_top, + w * 2 - p.pad_left, + p.out_width, + p.in_height, + p.in_width, + out_base); + } + } + + // load filter (1 outch x 3 height x 3 width): vf_outch_height + float32x4_t vf00, vf01, vf02; + vf00 = vld1q(filter_ptr); + vf01 = vld1q(filter_ptr + 3); + vf02 = vld1q(filter_ptr + 5); + + for (h = p.valid_h_start; h < p.valid_h_stop; ++h) { + // left + for (w = 0; w < p.valid_w_start; ++w) { + DepthwiseConv2d3x3Pixel(in_base, + filter_ptr, + h, + w, + h * 2 - p.pad_top, + w * 2 - p.pad_left, + p.out_width, + p.in_height, + p.in_width, + out_base); + } + + for (w = p.valid_w_start; w + 3 < p.valid_w_stop; w += 4) { + float32x4x2_t vi0, vi1, vi2; + float32x4_t vi0n, vi1n, vi2n; + + // input (3 height x 3 slide): vi_height_slide + float32x4_t vi00, vi01, vi02; + float32x4_t vi10, vi11, vi12; + float32x4_t vi20, vi21, vi22; + + // output (1 outch x 1 height x 4 width): vo + float32x4_t vo; + + // load input + index_t in_h = h * 2 - p.pad_top; + index_t in_w = w * 2 - p.pad_left; + index_t in_offset = in_h * p.in_width + in_w; + vi0 = vld2q(in_base + in_offset); // [0.2.4.6, 1.3.5.7] + vi1 = vld2q(in_base + in_offset + p.in_width); + vi2 = vld2q(in_base + in_offset + 2 * p.in_width); + + vi0n = vld1q(in_base + in_offset + 8); // [8.9.10.11] + vi1n = vld1q(in_base + in_offset + p.in_width + 8); + vi2n = vld1q(in_base + in_offset + 2 * p.in_width + 8); + + // load ouptut + 
index_t out_offset = h * p.out_width + w; + vo = vld1q(out_base + out_offset); + + vi00 = vi0.val[0]; // [0.2.4.6] + vi01 = vi0.val[1]; // [1.3.5.7] + vi02 = vextq_f32(vi00, vi0n, 1); // [2.4.6.8] + vi10 = vi1.val[0]; + vi11 = vi1.val[1]; + vi12 = vextq_f32(vi10, vi1n, 1); + vi20 = vi2.val[0]; + vi21 = vi2.val[1]; + vi22 = vextq_f32(vi20, vi2n, 1); + +#if defined(__aarch64__) + // outch 0, height 0 + vo = vfmaq_laneq_f32(vo, vi00, vf00, 0); + vo = vfmaq_laneq_f32(vo, vi01, vf00, 1); + vo = vfmaq_laneq_f32(vo, vi02, vf00, 2); + vo = vfmaq_laneq_f32(vo, vi10, vf01, 0); + vo = vfmaq_laneq_f32(vo, vi11, vf01, 1); + vo = vfmaq_laneq_f32(vo, vi12, vf01, 2); + vo = vfmaq_laneq_f32(vo, vi20, vf02, 1); + vo = vfmaq_laneq_f32(vo, vi21, vf02, 2); + vo = vfmaq_laneq_f32(vo, vi22, vf02, 3); +#else + // outch 0, height 0 + vo = vmlaq_lane_f32(vo, vi00, vget_low_f32(vf00), 0); + vo = vmlaq_lane_f32(vo, vi01, vget_low_f32(vf00), 1); + vo = vmlaq_lane_f32(vo, vi02, vget_high_f32(vf00), 0); + vo = vmlaq_lane_f32(vo, vi10, vget_low_f32(vf01), 0); + vo = vmlaq_lane_f32(vo, vi11, vget_low_f32(vf01), 1); + vo = vmlaq_lane_f32(vo, vi12, vget_high_f32(vf01), 0); + vo = vmlaq_lane_f32(vo, vi20, vget_low_f32(vf02), 1); + vo = vmlaq_lane_f32(vo, vi21, vget_high_f32(vf02), 0); + vo = vmlaq_lane_f32(vo, vi22, vget_high_f32(vf02), 1); +#endif + vst1q(out_base + out_offset, vo); + } // w + + // right + for (; w < p.out_width; ++w) { + DepthwiseConv2d3x3Pixel(in_base, + filter_ptr, + h, + w, + h * 2 - p.pad_top, + w * 2 - p.pad_left, + p.out_width, + p.in_height, + p.in_width, + out_base); + } + } // h + + // bottom + for (; h < p.out_height; ++h) { + for (w = 0; w < p.out_width; ++w) { + DepthwiseConv2d3x3Pixel(in_base, + filter_ptr, + h, + w, + h * 2 - p.pad_top, + w * 2 - p.pad_left, + p.out_width, + p.in_height, + p.in_width, + out_base); + } + } + } // m + } // b + }, + 0, p.batch, 1, 0, p.out_channels, 1); + + return MaceStatus::MACE_SUCCESS; +} + void RegisterDepthwiseConv2dK3x3Delegator(OpDelegatorRegistry *registry) { MACE_REGISTER_DELEGATOR( registry, DepthwiseConv2dK3x3S1, delegator::DepthwiseConv2dParam, @@ -27,6 +453,17 @@ void RegisterDepthwiseConv2dK3x3Delegator(OpDelegatorRegistry *registry) { registry, DepthwiseConv2dK3x3S2, delegator::DepthwiseConv2dParam, MACE_DELEGATOR_KEY_EX(DepthwiseConv2d, DeviceType::CPU, float, ImplType::NEON, K3x3S2)); + + MACE_REGISTER_BF16_DELEGATOR( + registry, DepthwiseConv2dK3x3S1, + delegator::DepthwiseConv2dParam, + MACE_DELEGATOR_KEY_EX(DepthwiseConv2d, DeviceType::CPU, + BFloat16, ImplType::NEON, K3x3S1)); + MACE_REGISTER_BF16_DELEGATOR( + registry, DepthwiseConv2dK3x3S2, + delegator::DepthwiseConv2dParam, + MACE_DELEGATOR_KEY_EX(DepthwiseConv2d, DeviceType::CPU, + BFloat16, ImplType::NEON, K3x3S2)); } } // namespace arm diff --git a/mace/ops/arm/base/gemm.cc b/mace/ops/arm/base/gemm.cc index 437f767e..6274311a 100644 --- a/mace/ops/arm/base/gemm.cc +++ b/mace/ops/arm/base/gemm.cc @@ -14,14 +14,693 @@ #include "mace/ops/arm/base/gemm.h" +#include +#include + +#include "mace/ops/arm/base/common_neon.h" + namespace mace { namespace ops { namespace arm { +template +void Gemm::Pack4x4(const MatrixMap &matrix, + MatrixMajor dst_major, T *packed_matrix) { + const index_t rows = matrix.rows(); + const index_t cols = matrix.cols(); + + // use the same terminology as GemmLowp: + // depth is depth, width is the opposite dim other than depth + // lhs + index_t width = rows; + index_t depth = cols; + index_t width_stride = matrix.rows_stride(); + index_t depth_stride = 
matrix.cols_stride(); + if (dst_major == RowMajor) { + // rhs + std::swap(width, depth); + std::swap(width_stride, depth_stride); + } + const T *data = matrix.data(); + T *packed_ptr = packed_matrix; + + const index_t block_size = 4; + const index_t depth_padded = RoundUp(depth, static_cast(4)); + + if (depth_padded > depth) { + memset(packed_ptr + depth * block_size, + 0, + sizeof(T) * (depth_padded - depth) * block_size); + } + + if (dst_major == matrix.matrix_major()) { + if (width < block_size) { + const index_t width_remain = block_size - width; + for (index_t d = 0; d < depth; ++d) { + memcpy(packed_ptr, data, sizeof(T) * width); + memset(packed_ptr + width, 0, sizeof(T) * width_remain); + data += depth_stride; + packed_ptr += block_size; + } + } else { + for (index_t d = 0; d < depth; ++d) { + float32x4_t vi = vld1q(data); + vst1q(packed_ptr, vi); + data += depth_stride; + packed_ptr += block_size; + } + } + } else { + if (width < block_size) { + const index_t width_remain = block_size - width; + for (index_t d = 0; d < depth; ++d) { + for (index_t w = 0; w < width; ++w) { + packed_ptr[w] = data[w * width_stride + d]; + } // w + memset(packed_ptr + width, 0, sizeof(T) * width_remain); + packed_ptr += block_size; + } // d + } else { + const T *data0 = data; + const T *data1 = data + width_stride; + const T *data2 = data1 + width_stride; + const T *data3 = data2 + width_stride; + + const index_t depth_block = depth / 4; + const index_t depth_remain = depth - depth_block * 4; + for (index_t depth_block_idx = 0; depth_block_idx < depth_block; + ++depth_block_idx) { + float32x4_t v0 = vld1q(data0); + float32x4_t v1 = vld1q(data1); + float32x4_t v2 = vld1q(data2); + float32x4_t v3 = vld1q(data3); + float32x4x2_t v02_intertwined = vzipq_f32(v0, v2); + float32x4x2_t v13_intertwined = vzipq_f32(v1, v3); + float32x4x2_t v0123_intertwined = + vzipq_f32(v02_intertwined.val[0], v13_intertwined.val[0]); + float32x4x2_t v0123n_intertwined = + vzipq_f32(v02_intertwined.val[1], v13_intertwined.val[1]); + + vst1q(packed_ptr, v0123_intertwined.val[0]); + packed_ptr += 4; + + vst1q(packed_ptr, v0123_intertwined.val[1]); + packed_ptr += 4; + + vst1q(packed_ptr, v0123n_intertwined.val[0]); + packed_ptr += 4; + + vst1q(packed_ptr, v0123n_intertwined.val[1]); + packed_ptr += 4; + + data0 += 4; + data1 += 4; + data2 += 4; + data3 += 4; + } + for (index_t d = 0; d < depth_remain; ++d) { + float32x4_t vi = {*data0, *data1, *data2, *data3}; + vst1q(packed_ptr, vi); + packed_ptr += 4; + + ++data0; + ++data1; + ++data2; + ++data3; + } // d + } + } +} + +template +void Gemm::Pack8x4(const MatrixMap &matrix, + MatrixMajor dst_major, T *packed_matrix) { + const index_t rows = matrix.rows(); + const index_t cols = matrix.cols(); + + // use the same terminology as GemmLowp: + // depth is depth, width is the opposite dim other than depth + // lhs + index_t width = rows; + index_t depth = cols; + index_t width_stride = matrix.rows_stride(); + index_t depth_stride = matrix.cols_stride(); + if (dst_major == RowMajor) { + // rhs + std::swap(width, depth); + std::swap(width_stride, depth_stride); + } + const T *data = matrix.data(); + T *packed_ptr = packed_matrix; + + const index_t block_size = 8; + const index_t depth_padded = RoundUp(depth, static_cast(4)); + + if (depth_padded > depth) { + memset(packed_ptr + depth * block_size, + 0, + sizeof(T) * (depth_padded - depth) * block_size); + } + + if (dst_major == matrix.matrix_major()) { + if (width < block_size) { + const index_t width_remain = block_size - width; + for 
(index_t d = 0; d < depth; ++d) { + memcpy(packed_ptr, data, sizeof(T) * width); + memset(packed_ptr + width, 0, sizeof(T) * width_remain); + data += depth_stride; + packed_ptr += block_size; + } + } else { + for (index_t d = 0; d < depth; ++d) { + float32x4_t vi = vld1q(data); + vst1q(packed_ptr, vi); + float32x4_t vin = vld1q(data + 4); + vst1q(packed_ptr + 4, vin); + data += depth_stride; + packed_ptr += block_size; + } + } + } else { + if (width < block_size) { + const index_t width_remain = block_size - width; + for (index_t d = 0; d < depth; ++d) { + for (index_t w = 0; w < width; ++w) { + packed_ptr[w] = data[w * width_stride + d]; + } // w + memset(packed_ptr + width, 0, sizeof(T) * width_remain); + packed_ptr += block_size; + } // d + } else { + const T *data0 = data; + const T *data1 = data + width_stride; + const T *data2 = data1 + width_stride; + const T *data3 = data2 + width_stride; + const T *data4 = data3 + width_stride; + const T *data5 = data4 + width_stride; + const T *data6 = data5 + width_stride; + const T *data7 = data6 + width_stride; + + const index_t depth_block = depth / 4; + const index_t depth_remain = depth - depth_block * 4; + for (index_t depth_block_idx = 0; depth_block_idx < depth_block; + ++depth_block_idx) { + float32x4_t v0 = vld1q(data0); + float32x4_t v1 = vld1q(data1); + float32x4_t v2 = vld1q(data2); + float32x4_t v3 = vld1q(data3); + float32x4x2_t v02_intertwined = vzipq_f32(v0, v2); + float32x4x2_t v13_intertwined = vzipq_f32(v1, v3); + float32x4x2_t v0123_intertwined = + vzipq_f32(v02_intertwined.val[0], v13_intertwined.val[0]); + float32x4x2_t v0123n_intertwined = + vzipq_f32(v02_intertwined.val[1], v13_intertwined.val[1]); + + float32x4_t v4 = vld1q(data4); + float32x4_t v5 = vld1q(data5); + float32x4_t v6 = vld1q(data6); + float32x4_t v7 = vld1q(data7); + float32x4x2_t v46_intertwined = vzipq_f32(v4, v6); + float32x4x2_t v57_intertwined = vzipq_f32(v5, v7); + float32x4x2_t v4567_intertwined = + vzipq_f32(v46_intertwined.val[0], v57_intertwined.val[0]); + float32x4x2_t v4567n_intertwined = + vzipq_f32(v46_intertwined.val[1], v57_intertwined.val[1]); + + vst1q(packed_ptr, v0123_intertwined.val[0]); + packed_ptr += 4; + + vst1q(packed_ptr, v4567_intertwined.val[0]); + packed_ptr += 4; + + vst1q(packed_ptr, v0123_intertwined.val[1]); + packed_ptr += 4; + + vst1q(packed_ptr, v4567_intertwined.val[1]); + packed_ptr += 4; + + vst1q(packed_ptr, v0123n_intertwined.val[0]); + packed_ptr += 4; + + vst1q(packed_ptr, v4567n_intertwined.val[0]); + packed_ptr += 4; + + vst1q(packed_ptr, v0123n_intertwined.val[1]); + packed_ptr += 4; + + vst1q(packed_ptr, v4567n_intertwined.val[1]); + packed_ptr += 4; + + data0 += 4; + data1 += 4; + data2 += 4; + data3 += 4; + data4 += 4; + data5 += 4; + data6 += 4; + data7 += 4; + } + for (index_t d = 0; d < depth_remain; ++d) { + float32x4_t vi = {*data0, *data1, *data2, *data3}; + vst1q(packed_ptr, vi); + packed_ptr += 4; + + float32x4_t vin = {*data4, *data5, *data6, *data7}; + vst1q(packed_ptr, vin); + packed_ptr += 4; + + ++data0; + ++data1; + ++data2; + ++data3; + ++data4; + ++data5; + ++data6; + ++data7; + } // d + } + } +} + +template +void Gemm::Unpack4x8(const T *packed_output, MatrixMap *output) { + const index_t rows = output->rows(); + const index_t cols = output->cols(); + index_t row_stride = output->rows_stride(); + index_t col_stride = output->cols_stride(); + + T *output_ptr = output->data(); + const T *packed_ptr = packed_output; + + const index_t block_size = 8; + + // packed_output always has row-major + 
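As context for the pack/unpack routines in this file and the ComputeBlock kernels they feed: the packed LHS tile stores its row-block entries contiguously for each depth step, the packed RHS tile stores its eight column entries per depth step, and the result tile is a contiguous row-major block. A scalar reference of that contraction is sketched below under those layout assumptions; it is illustrative only, not code added by this patch, and T is assumed convertible to and from float (e.g. float or BFloat16).

    // Scalar reference for one packed tile: packed_lhs holds kRowBlock row
    // entries per depth step, packed_rhs holds 8 column entries per depth
    // step, and packed_out is a contiguous kRowBlock x 8 row-major tile.
    template <typename T, int kRowBlock>  // kRowBlock: 8 on aarch64, 4 on armv7
    void ComputeBlockRef(const T *packed_lhs, const T *packed_rhs,
                         int depth_padded, T *packed_out) {
      for (int r = 0; r < kRowBlock; ++r) {
        for (int c = 0; c < 8; ++c) {
          float acc = 0.f;  // accumulate in fp32, as the NEON/asm kernels do
          for (int d = 0; d < depth_padded; ++d) {
            acc += static_cast<float>(packed_lhs[d * kRowBlock + r]) *
                   static_cast<float>(packed_rhs[d * 8 + c]);
          }
          packed_out[r * 8 + c] = static_cast<T>(acc);  // narrow on store
        }
      }
    }

The zero padding written during packing is what makes it safe to run the inner loop over the full padded depth.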
if (output->matrix_major() == RowMajor) { + if (cols < block_size) { + for (index_t r = 0; r < rows; ++r) { + memcpy(output_ptr, packed_ptr, sizeof(T) * cols); + output_ptr += row_stride; + packed_ptr += block_size; + } + } else { + for (index_t r = 0; r < rows; ++r) { + float32x4_t vi = vld1q(packed_ptr); + vst1q(output_ptr, vi); + float32x4_t vin = vld1q(packed_ptr + 4); + vst1q(output_ptr + 4, vin); + + output_ptr += row_stride; + packed_ptr += block_size; + } + } + } else { + // ColMajor + if (rows < block_size) { + for (index_t c = 0; c < cols; ++c) { + for (index_t r = 0; r < rows; ++r) { + output_ptr[c * col_stride + r] = packed_ptr[r * block_size + c]; + } // r + } // c + } else { + const T *data0 = packed_ptr; + const T *data1 = data0 + block_size; + const T *data2 = data1 + block_size; + const T *data3 = data2 + block_size; + + index_t col_block = cols / 4; + index_t col_remain = cols - col_block * 4; + for (index_t col_block_idx = 0; col_block_idx < col_block; + ++col_block_idx) { + float32x4_t v0 = vld1q(data0); + float32x4_t v1 = vld1q(data1); + float32x4_t v2 = vld1q(data2); + float32x4_t v3 = vld1q(data3); + float32x4x2_t v02_intertwined = vzipq_f32(v0, v2); + float32x4x2_t v13_intertwined = vzipq_f32(v1, v3); + float32x4x2_t v0123_intertwined = + vzipq_f32(v02_intertwined.val[0], v13_intertwined.val[0]); + float32x4x2_t v0123n_intertwined = + vzipq_f32(v02_intertwined.val[1], v13_intertwined.val[1]); + + vst1q(output_ptr, v0123_intertwined.val[0]); + output_ptr += col_stride; + + vst1q(output_ptr, v0123_intertwined.val[1]); + output_ptr += col_stride; + + vst1q(output_ptr, v0123n_intertwined.val[0]); + output_ptr += col_stride; + + vst1q(output_ptr, v0123n_intertwined.val[1]); + output_ptr += col_stride; + + data0 += 4; + data1 += 4; + data2 += 4; + data3 += 4; + } + for (index_t c = 0; c < col_remain; ++c) { + float32x4_t vi = {*data0, *data1, *data2, *data3}; + vst1q(output_ptr, vi); + output_ptr += col_stride; + + ++data0; + ++data1; + ++data2; + ++data3; + } // d + } + } +} + +template +void Gemm::Unpack8x8(const T *packed_output, MatrixMap *output) { + const index_t rows = output->rows(); + const index_t cols = output->cols(); + index_t row_stride = output->rows_stride(); + index_t col_stride = output->cols_stride(); + + T *output_ptr = output->data(); + const T *packed_ptr = packed_output; + + const index_t block_size = 8; + + // packed_output always has row-major + if (output->matrix_major() == RowMajor) { + if (cols < block_size) { + for (index_t r = 0; r < rows; ++r) { + memcpy(output_ptr, packed_ptr, sizeof(T) * cols); + output_ptr += row_stride; + packed_ptr += block_size; + } + } else { + for (index_t r = 0; r < rows; ++r) { + float32x4_t vi = vld1q(packed_ptr); + vst1q(output_ptr, vi); + float32x4_t vin = vld1q(packed_ptr + 4); + vst1q(output_ptr + 4, vin); + + output_ptr += row_stride; + packed_ptr += block_size; + } + } + } else { + // ColMajor + if (rows < block_size) { + for (index_t c = 0; c < cols; ++c) { + for (index_t r = 0; r < rows; ++r) { + output_ptr[c * col_stride + r] = packed_ptr[r * block_size + c]; + } // r + } // c + } else { + const T *data0 = packed_ptr; + const T *data1 = data0 + block_size; + const T *data2 = data1 + block_size; + const T *data3 = data2 + block_size; + const T *data4 = data3 + block_size; + const T *data5 = data4 + block_size; + const T *data6 = data5 + block_size; + const T *data7 = data6 + block_size; + + index_t col_block = cols / 4; + index_t col_remain = cols - col_block * 4; + for (index_t col_block_idx = 0; 
col_block_idx < col_block; + ++col_block_idx) { + float32x4_t v0 = vld1q(data0); + float32x4_t v1 = vld1q(data1); + float32x4_t v2 = vld1q(data2); + float32x4_t v3 = vld1q(data3); + float32x4x2_t v02_intertwined = vzipq_f32(v0, v2); + float32x4x2_t v13_intertwined = vzipq_f32(v1, v3); + float32x4x2_t v0123_intertwined = + vzipq_f32(v02_intertwined.val[0], v13_intertwined.val[0]); + float32x4x2_t v0123n_intertwined = + vzipq_f32(v02_intertwined.val[1], v13_intertwined.val[1]); + + float32x4_t v4 = vld1q(data4); + float32x4_t v5 = vld1q(data5); + float32x4_t v6 = vld1q(data6); + float32x4_t v7 = vld1q(data7); + float32x4x2_t v46_intertwined = vzipq_f32(v4, v6); + float32x4x2_t v57_intertwined = vzipq_f32(v5, v7); + float32x4x2_t v4567_intertwined = + vzipq_f32(v46_intertwined.val[0], v57_intertwined.val[0]); + float32x4x2_t v4567n_intertwined = + vzipq_f32(v46_intertwined.val[1], v57_intertwined.val[1]); + + vst1q(output_ptr, v0123_intertwined.val[0]); + vst1q(output_ptr + 4, v4567_intertwined.val[0]); + output_ptr += col_stride; + + vst1q(output_ptr, v0123_intertwined.val[1]); + vst1q(output_ptr + 4, v4567_intertwined.val[1]); + output_ptr += col_stride; + + vst1q(output_ptr, v0123n_intertwined.val[0]); + vst1q(output_ptr + 4, v4567n_intertwined.val[0]); + output_ptr += col_stride; + + vst1q(output_ptr, v0123n_intertwined.val[1]); + vst1q(output_ptr + 4, v4567n_intertwined.val[1]); + output_ptr += col_stride; + + data0 += 4; + data1 += 4; + data2 += 4; + data3 += 4; + data4 += 4; + data5 += 4; + data6 += 4; + data7 += 4; + } + for (index_t c = 0; c < col_remain; ++c) { + float32x4_t vi = {*data0, *data1, *data2, *data3}; + vst1q(output_ptr, vi); + float32x4_t vin = {*data4, *data5, *data6, *data7}; + vst1q(output_ptr + 4, vin); + output_ptr += col_stride; + + ++data0; + ++data1; + ++data2; + ++data3; + ++data4; + ++data5; + ++data6; + ++data7; + } // d + } + } +} + +template +void Gemm::PackLhs(const MatrixMap &lhs, T *packed_lhs) { +#ifdef __aarch64__ + Pack8x4(lhs, ColMajor, packed_lhs); +#else + Pack4x4(lhs, ColMajor, packed_lhs); +#endif +} + +template +void Gemm::PackRhs(const MatrixMap &rhs, T *packed_rhs) { + Pack8x4(rhs, RowMajor, packed_rhs); +} + +template +void Gemm::UnpackOutput(const T *packed_output, MatrixMap *output) { +#ifdef __aarch64__ + Unpack8x8(packed_output, output); +#else + Unpack4x8(packed_output, output); +#endif +} + +template +MaceStatus Gemm::Compute( + const OpContext *context, const Tensor *lhs, const Tensor *rhs, + const index_t batch, const index_t rows, const index_t cols, + const index_t depth, const MatrixMajor lhs_major, + const MatrixMajor rhs_major, const MatrixMajor output_major, + const bool lhs_batched, const bool rhs_batched, Tensor *output) { + MACE_CHECK(output->size() == batch * rows * cols, + "Need resize output tensor before call gemm."); + Tensor::MappingGuard lhs_guard(lhs); + Tensor::MappingGuard rhs_guard(rhs); + Tensor::MappingGuard output_guard(output); + const T *lhs_data = lhs->data(); + const T *rhs_data = rhs->data(); + T *output_data = output->mutable_data(); + +#ifdef __aarch64__ + const index_t row_block_size = 8; +#else + const index_t row_block_size = 4; +#endif + const index_t col_block_size = 8; + const index_t depth_block_size = 4; + const index_t row_block_count = RoundUpDiv(rows, row_block_size); + const index_t col_block_count = RoundUpDiv(cols, col_block_size); + const index_t rows_padded = RoundUp(rows, row_block_size); + const index_t cols_padded = RoundUp(cols, col_block_size); + const index_t depth_padded = 
RoundUp(depth, depth_block_size); + + ScratchBuffer *scratch = context->device()->scratch_buffer(); + + index_t packed_lhs_size = + PadAlignSize(sizeof(T) * rows_padded * depth_padded); + index_t packed_rhs_size = + PadAlignSize(sizeof(T) * depth_padded * cols_padded); + index_t packed_output_size = + PadAlignSize(sizeof(T) * rows_padded * cols_padded); + // resize to the total size of lhs & rhs & output anyway, + // in case we do not cache const tensor for saving memory + MACE_RETURN_IF_ERROR(scratch->GrowSize( + packed_lhs_size + packed_rhs_size + packed_output_size)); + T *packed_lhs_data = + scratch->Scratch(packed_lhs_size).mutable_data(); + T *packed_rhs_data = + scratch->Scratch(packed_rhs_size).mutable_data(); + T *packed_output_data = + scratch->Scratch(packed_output_size).mutable_data(); + + int cache_side = kNoCache; + if (cached_ == kCacheLhs) { + packed_lhs_data = pack_cache_.mutable_data(); + } else if (cached_ == kCacheRhs) { + packed_rhs_data = pack_cache_.mutable_data(); + } else if (should_cache_pack_) { + if (lhs->is_weight() && (!lhs_batched || batch == 1)) { + cache_side = kCacheLhs; + pack_cache_.Resize(packed_lhs_size); + packed_lhs_data = pack_cache_.mutable_data(); + } else if (rhs->is_weight() && (!rhs_batched || batch == 1)) { + cache_side = kCacheRhs; + pack_cache_.Resize(packed_rhs_size); + packed_rhs_data = pack_cache_.mutable_data(); + } + } + + utils::ThreadPool + &thread_pool = context->device()->cpu_runtime()->thread_pool(); + + for (index_t b = 0; b < batch; ++b) { + MatrixMap + lhs_matrix + (lhs_data + static_cast(lhs_batched) * b * rows * depth, + lhs_major, + rows, + depth); + MatrixMap + rhs_matrix + (rhs_data + static_cast(rhs_batched) * b * depth * cols, + rhs_major, + depth, + cols); + MatrixMap output_matrix + (output_data + b * rows * cols, output_major, rows, cols); + + // pack lhs + if (cached_ != kCacheLhs) { + thread_pool.Compute1D([=, &lhs_matrix](index_t start, + index_t end, + index_t step) { + for (index_t row_block_idx = start; row_block_idx < end; + row_block_idx += step) { + const index_t start_row = row_block_idx * row_block_size; + const index_t + row_block_len = std::min(row_block_size, rows - start_row); + T *packed_lhs_data_block = + packed_lhs_data + row_block_idx * row_block_size * depth_padded; + PackLhs(lhs_matrix.block(start_row, 0, row_block_len, depth), + packed_lhs_data_block); + } + }, 0, row_block_count, 1); + + if (cache_side == kCacheLhs) { + cached_ = kCacheLhs; + if (lhs->UnderlyingBuffer()->OnHost()) { + AdviseFree(reinterpret_cast(const_cast(lhs->data< + T>())), + lhs->raw_size()); + } + } + } + + // pack rhs + if (cached_ != kCacheRhs) { + thread_pool.Compute1D([=, &rhs_matrix](index_t start, + index_t end, + index_t step) { + for (index_t col_block_idx = start; col_block_idx < end; + col_block_idx += step) { + const index_t start_col = col_block_idx * col_block_size; + const index_t + col_block_len = std::min(col_block_size, cols - start_col); + T *packed_rhs_data_block = + packed_rhs_data + col_block_idx * col_block_size * depth_padded; + PackRhs(rhs_matrix.block(0, start_col, depth, col_block_len), + packed_rhs_data_block); + } + }, 0, col_block_count, 1); + + if (cache_side == kCacheRhs) { + cached_ = kCacheRhs; + if (rhs->UnderlyingBuffer()->OnHost()) { + AdviseFree(reinterpret_cast(const_cast(rhs->data< + T>())), + rhs->raw_size()); + } + } + } + + // multiply lhs and rhs + thread_pool.Compute1D([=, &output_matrix](index_t start, + index_t end, + index_t step) { + for (index_t row_block_idx = start; 
row_block_idx < end; + row_block_idx += step) { + const index_t start_row = row_block_idx * row_block_size; + const index_t + row_block_len = std::min(row_block_size, rows - start_row); + const T *packed_lhs_data_block = + packed_lhs_data + row_block_idx * row_block_size * depth_padded; + + for (index_t col_block_idx = 0; col_block_idx < col_block_count; + ++col_block_idx) { + const index_t start_col = col_block_idx * col_block_size; + const index_t + col_block_len = std::min(col_block_size, cols - start_col); + const T *packed_rhs_data_block = + packed_rhs_data + col_block_idx * col_block_size * depth_padded; + T *packed_output_data_block = + packed_output_data + row_block_idx * row_block_size * cols_padded + + col_block_idx * col_block_size; + ComputeBlock(packed_lhs_data_block, + packed_rhs_data_block, + depth_padded, + packed_output_data_block); + MatrixMap output_block = output_matrix.block(start_row, + start_col, + row_block_len, + col_block_len); + UnpackOutput(packed_output_data_block, &output_block); + } // col_block_idx + } // row_block_idx + }, 0, row_block_count, 1); + } // b + + return MaceStatus::MACE_SUCCESS; +} + void RegisterGemmDelegator(OpDelegatorRegistry *registry) { MACE_REGISTER_DELEGATOR( registry, Gemm, delegator::GemmParam, MACE_DELEGATOR_KEY(Gemm, DeviceType::CPU, float, ImplType::NEON)); + + MACE_REGISTER_BF16_DELEGATOR( + registry, Gemm, delegator::GemmParam, + MACE_DELEGATOR_KEY(Gemm, DeviceType::CPU, BFloat16, ImplType::NEON)); } } // namespace arm diff --git a/mace/ops/arm/base/gemm.h b/mace/ops/arm/base/gemm.h index ec6cc318..a5fda739 100644 --- a/mace/ops/arm/base/gemm.h +++ b/mace/ops/arm/base/gemm.h @@ -110,46 +110,16 @@ class Gemm : public delegator::Gemm { void UnpackOutput(const T *packed_output, MatrixMap *output); - template - void Unpack(const T *packed_output, - MatrixMap *output) { - const index_t rows = output->rows(); - const index_t cols = output->cols(); - for (index_t r = 0; r < rows; ++r) { - for (index_t c = 0; c < cols; ++c) { - *output->data(r, c) = packed_output[r * ColBlockSize + c]; - } - } - } - template - void Pack(const MatrixMap &matrix, - MatrixMajor dst_major, - T *packed_matrix) { - const index_t rows = matrix.rows(); - const index_t cols = matrix.cols(); - index_t depth = cols; - if (dst_major == RowMajor) { - // rhs - depth = rows; - } - const index_t depth_padded = RoundUp(depth, static_cast(4)); - memset(static_cast(packed_matrix), 0, - sizeof(T) * WidthBlockSize * depth_padded); - if (dst_major == ColMajor) { - for (index_t c = 0; c < cols; ++c) { - for (index_t r = 0; r < rows; ++r) { - packed_matrix[c * WidthBlockSize + r] = matrix(r, c); - } - } - } else { - for (index_t r = 0; r < rows; ++r) { - for (index_t c = 0; c < cols; ++c) { - packed_matrix[r * WidthBlockSize + c] = matrix(r, c); - } - } - } - } + void Unpack4x8(const T *packed_output, MatrixMap *output); + void Unpack8x8(const T *packed_output, MatrixMap *output); + + void Pack4x4(const MatrixMap &matrix, + MatrixMajor dst_major, + T *packed_matrix); + void Pack8x4(const MatrixMap &matrix, + MatrixMajor dst_major, + T *packed_matrix); private: Buffer pack_cache_; diff --git a/mace/ops/arm/bf16/conv_2d_3x3.cc b/mace/ops/arm/bf16/conv_2d_3x3.cc new file mode 100644 index 00000000..1b4d34ab --- /dev/null +++ b/mace/ops/arm/bf16/conv_2d_3x3.cc @@ -0,0 +1,459 @@ +// Copyright 2020 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/arm/base/conv_2d_3x3.h" + +#include +#include +#include + +#include "mace/ops/arm/base/common_neon.h" +#include "mace/ops/delegator/conv_2d.h" + +namespace mace { +namespace ops { +namespace arm { + +template <> +MaceStatus Conv2dK3x3S1::DoCompute( + const ConvComputeParam &p, const BFloat16 *filter_data, + const BFloat16 *input_data, BFloat16 *output_data) { + p.thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0, + index_t start1, index_t end1, index_t step1) { + for (index_t b = start0; b < end0; b += step0) { + for (index_t m = start1; m < end1; m += step1) { + if (m + 1 < p.out_channels) { + auto in_ptr0_base = input_data + b * p.in_batch_size; + auto in_ptr1_base = in_ptr0_base + p.in_width; + auto in_ptr2_base = in_ptr1_base + p.in_width; + auto in_ptr3_base = in_ptr2_base + p.in_width; + auto out_ptr0 = output_data + b * p.out_batch_size + + m * p.out_image_size; + auto out_ptr1 = out_ptr0 + p.out_image_size; + for (index_t h = 0; h + 1 < p.out_height; h += 2) { + for (index_t w = 0; w + 3 < p.out_width; w += 4) { + auto in_ptr0 = in_ptr0_base; + auto in_ptr1 = in_ptr1_base; + auto in_ptr2 = in_ptr2_base; + auto in_ptr3 = in_ptr3_base; + auto filter_ptr0 = filter_data + m * p.in_channels * 9; + auto filter_ptr1 = filter_ptr0 + p.in_channels * 9; + float32x4_t vo00 = vdupq_n_f32(0.f); + float32x4_t vo01 = vdupq_n_f32(0.f); + float32x4_t vo10 = vdupq_n_f32(0.f); + float32x4_t vo11 = vdupq_n_f32(0.f); + for (index_t c = 0; c < p.in_channels; ++c) { + // input (4 height x 3 slide): vi_height_slide + float32x4_t vi00, vi01, vi02, vi0n; + float32x4_t vi10, vi11, vi12, vi1n; + float32x4_t vi20, vi21, vi22, vi2n; + float32x4_t vi30, vi31, vi32, vi3n; + + // load input + vi00 = vld1q_bf16(in_ptr0); + vi0n = vld1q_bf16(in_ptr0 + 4); + vi10 = vld1q_bf16(in_ptr1); + vi1n = vld1q_bf16(in_ptr1 + 4); + vi20 = vld1q_bf16(in_ptr2); + vi2n = vld1q_bf16(in_ptr2 + 4); + vi30 = vld1q_bf16(in_ptr3); + vi3n = vld1q_bf16(in_ptr3 + 4); + + vi01 = vextq_f32(vi00, vi0n, 1); + vi02 = vextq_f32(vi00, vi0n, 2); + vi11 = vextq_f32(vi10, vi1n, 1); + vi12 = vextq_f32(vi10, vi1n, 2); + vi21 = vextq_f32(vi20, vi2n, 1); + vi22 = vextq_f32(vi20, vi2n, 2); + vi31 = vextq_f32(vi30, vi3n, 1); + vi32 = vextq_f32(vi30, vi3n, 2); + +#if defined(__aarch64__) + // load filter (2 outch x 3 height x 3 width): + // vf_outch_height + float32x4_t vf00, vf01, vf02; + float32x4_t vf10, vf11, vf12; + vf00 = vld1q_bf16(filter_ptr0); + vf01 = vld1q_bf16(filter_ptr0 + 3); + vf02 = vld1q_bf16(filter_ptr0 + 6); + vf10 = vld1q_bf16(filter_ptr1); + vf11 = vld1q_bf16(filter_ptr1 + 3); + vf12 = vld1q_bf16(filter_ptr1 + 6); + + // outch 0, height 0 + vo00 = vfmaq_laneq_f32(vo00, vi00, vf00, 0); // reg count: 18 + vo00 = vfmaq_laneq_f32(vo00, vi01, vf00, 1); + vo00 = vfmaq_laneq_f32(vo00, vi02, vf00, 2); + vo00 = vfmaq_laneq_f32(vo00, vi10, vf01, 0); + vo00 = vfmaq_laneq_f32(vo00, vi11, vf01, 1); + vo00 = vfmaq_laneq_f32(vo00, vi12, vf01, 2); + vo00 = vfmaq_laneq_f32(vo00, vi20, vf02, 0); + vo00 = vfmaq_laneq_f32(vo00, vi21, vf02, 1); + vo00 = vfmaq_laneq_f32(vo00, vi22, vf02, 2); + + // 
outch 0, height 1 + vo01 = vfmaq_laneq_f32(vo01, vi10, vf00, 0); + vo01 = vfmaq_laneq_f32(vo01, vi11, vf00, 1); + vo01 = vfmaq_laneq_f32(vo01, vi12, vf00, 2); + vo01 = vfmaq_laneq_f32(vo01, vi20, vf01, 0); + vo01 = vfmaq_laneq_f32(vo01, vi21, vf01, 1); + vo01 = vfmaq_laneq_f32(vo01, vi22, vf01, 2); + vo01 = vfmaq_laneq_f32(vo01, vi30, vf02, 0); + vo01 = vfmaq_laneq_f32(vo01, vi31, vf02, 1); + vo01 = vfmaq_laneq_f32(vo01, vi32, vf02, 2); + + // outch 1, height 0 + vo10 = vfmaq_laneq_f32(vo10, vi00, vf10, 0); + vo10 = vfmaq_laneq_f32(vo10, vi01, vf10, 1); + vo10 = vfmaq_laneq_f32(vo10, vi02, vf10, 2); + vo10 = vfmaq_laneq_f32(vo10, vi10, vf11, 0); + vo10 = vfmaq_laneq_f32(vo10, vi11, vf11, 1); + vo10 = vfmaq_laneq_f32(vo10, vi12, vf11, 2); + vo10 = vfmaq_laneq_f32(vo10, vi20, vf12, 0); + vo10 = vfmaq_laneq_f32(vo10, vi21, vf12, 1); + vo10 = vfmaq_laneq_f32(vo10, vi22, vf12, 2); + + // outch 1, height 1 + vo11 = vfmaq_laneq_f32(vo11, vi10, vf10, 0); + vo11 = vfmaq_laneq_f32(vo11, vi11, vf10, 1); + vo11 = vfmaq_laneq_f32(vo11, vi12, vf10, 2); + vo11 = vfmaq_laneq_f32(vo11, vi20, vf11, 0); + vo11 = vfmaq_laneq_f32(vo11, vi21, vf11, 1); + vo11 = vfmaq_laneq_f32(vo11, vi22, vf11, 2); + vo11 = vfmaq_laneq_f32(vo11, vi30, vf12, 0); + vo11 = vfmaq_laneq_f32(vo11, vi31, vf12, 1); + vo11 = vfmaq_laneq_f32(vo11, vi32, vf12, 2); +#else + float32x2_t vf001, vf023, vf045, vf067, vf089; + float32x2_t vf101, vf123, vf145, vf167, vf189; + vf001 = vld1_bf16(filter_ptr0); + vf023 = vld1_bf16(filter_ptr0 + 2); + vf045 = vld1_bf16(filter_ptr0 + 4); + vf067 = vld1_bf16(filter_ptr0 + 6); + vf089 = vld1_bf16(filter_ptr0 + 8); + + vf101 = vld1_bf16(filter_ptr1); + vf123 = vld1_bf16(filter_ptr1 + 2); + vf145 = vld1_bf16(filter_ptr1 + 4); + vf167 = vld1_bf16(filter_ptr1 + 6); + vf189 = vld1_bf16(filter_ptr1 + 8); + + // outch 0, height 0 + vo00 = vmlaq_lane_f32(vo00, vi00, vf001, 0); + vo00 = vmlaq_lane_f32(vo00, vi01, vf001, 1); + vo00 = vmlaq_lane_f32(vo00, vi02, vf023, 0); + vo00 = vmlaq_lane_f32(vo00, vi10, vf023, 1); + vo00 = vmlaq_lane_f32(vo00, vi11, vf045, 0); + vo00 = vmlaq_lane_f32(vo00, vi12, vf045, 1); + vo00 = vmlaq_lane_f32(vo00, vi20, vf067, 0); + vo00 = vmlaq_lane_f32(vo00, vi21, vf067, 1); + vo00 = vmlaq_lane_f32(vo00, vi22, vf089, 0); + + // outch 0, height 1 + vo01 = vmlaq_lane_f32(vo01, vi10, vf001, 0); + vo01 = vmlaq_lane_f32(vo01, vi11, vf001, 1); + vo01 = vmlaq_lane_f32(vo01, vi12, vf023, 0); + vo01 = vmlaq_lane_f32(vo01, vi20, vf023, 1); + vo01 = vmlaq_lane_f32(vo01, vi21, vf045, 0); + vo01 = vmlaq_lane_f32(vo01, vi22, vf045, 1); + vo01 = vmlaq_lane_f32(vo01, vi30, vf067, 0); + vo01 = vmlaq_lane_f32(vo01, vi31, vf067, 1); + vo01 = vmlaq_lane_f32(vo01, vi32, vf089, 0); + + // outch 1, height 0 + vo10 = vmlaq_lane_f32(vo10, vi00, vf101, 0); + vo10 = vmlaq_lane_f32(vo10, vi01, vf101, 1); + vo10 = vmlaq_lane_f32(vo10, vi02, vf123, 0); + vo10 = vmlaq_lane_f32(vo10, vi10, vf123, 1); + vo10 = vmlaq_lane_f32(vo10, vi11, vf145, 0); + vo10 = vmlaq_lane_f32(vo10, vi12, vf145, 1); + vo10 = vmlaq_lane_f32(vo10, vi20, vf167, 0); + vo10 = vmlaq_lane_f32(vo10, vi21, vf167, 1); + vo10 = vmlaq_lane_f32(vo10, vi22, vf189, 0); + + // outch 1, height 1 + vo11 = vmlaq_lane_f32(vo11, vi10, vf101, 0); + vo11 = vmlaq_lane_f32(vo11, vi11, vf101, 1); + vo11 = vmlaq_lane_f32(vo11, vi12, vf123, 0); + vo11 = vmlaq_lane_f32(vo11, vi20, vf123, 1); + vo11 = vmlaq_lane_f32(vo11, vi21, vf145, 0); + vo11 = vmlaq_lane_f32(vo11, vi22, vf145, 1); + vo11 = vmlaq_lane_f32(vo11, vi30, vf167, 0); + vo11 = vmlaq_lane_f32(vo11, vi31, 
vf167, 1); + vo11 = vmlaq_lane_f32(vo11, vi32, vf189, 0); +#endif + in_ptr0 += p.in_image_size; + in_ptr1 += p.in_image_size; + in_ptr2 += p.in_image_size; + in_ptr3 += p.in_image_size; + filter_ptr0 += 9; + filter_ptr1 += 9; + } + vst1q_bf16(out_ptr0, vo00); + vst1q_bf16(out_ptr0 + p.out_width, vo01); + vst1q_bf16(out_ptr1, vo10); + vst1q_bf16(out_ptr1 + p.out_width, vo11); + + in_ptr0_base += 4; + in_ptr1_base += 4; + in_ptr2_base += 4; + in_ptr3_base += 4; + + out_ptr0 += 4; + out_ptr1 += 4; + } + in_ptr0_base += 2 + p.in_width; + in_ptr1_base += 2 + p.in_width; + in_ptr2_base += 2 + p.in_width; + in_ptr3_base += 2 + p.in_width; + + out_ptr0 += p.out_width; + out_ptr1 += p.out_width; + } + } else { + for (index_t mm = m; mm < p.out_channels; ++mm) { + auto out_ptr0 = output_data + b * p.out_batch_size + + mm * p.out_image_size; + auto in_ptr0_base = input_data + b * p.in_batch_size; + auto in_ptr1_base = in_ptr0_base + p.in_width; + auto in_ptr2_base = in_ptr1_base + p.in_width; + auto in_ptr3_base = in_ptr2_base + p.in_width; + for (index_t h = 0; h + 1 < p.out_height; h += 2) { + for (index_t w = 0; w + 3 < p.out_width; w += 4) { + auto in_ptr0 = in_ptr0_base; + auto in_ptr1 = in_ptr1_base; + auto in_ptr2 = in_ptr2_base; + auto in_ptr3 = in_ptr3_base; + auto filter_ptr0 = filter_data + mm * p.in_channels * 9; + float32x4_t vo00 = vdupq_n_f32(0.f); + float32x4_t vo01 = vdupq_n_f32(0.f); + for (index_t c = 0; c < p.in_channels; ++c) { + // input (4 height x 3 slide): vi_height_slide + float32x4_t vi00, vi01, vi02, vi0n; + float32x4_t vi10, vi11, vi12, vi1n; + float32x4_t vi20, vi21, vi22, vi2n; + float32x4_t vi30, vi31, vi32, vi3n; + + // load input + vi00 = vld1q_bf16(in_ptr0); + vi0n = vld1q_bf16(in_ptr0 + 4); + vi10 = vld1q_bf16(in_ptr1); + vi1n = vld1q_bf16(in_ptr1 + 4); + vi20 = vld1q_bf16(in_ptr2); + vi2n = vld1q_bf16(in_ptr2 + 4); + vi30 = vld1q_bf16(in_ptr3); + vi3n = vld1q_bf16(in_ptr3 + 4); + + vi01 = vextq_f32(vi00, vi0n, 1); + vi02 = vextq_f32(vi00, vi0n, 2); + vi11 = vextq_f32(vi10, vi1n, 1); + vi12 = vextq_f32(vi10, vi1n, 2); + vi21 = vextq_f32(vi20, vi2n, 1); + vi22 = vextq_f32(vi20, vi2n, 2); + vi31 = vextq_f32(vi30, vi3n, 1); + vi32 = vextq_f32(vi30, vi3n, 2); + +#if defined(__aarch64__) + // load filter (1 outch x 3 height x 3 width): vf_outch_height + float32x4_t vf00, vf01, vf02; + vf00 = vld1q_bf16(filter_ptr0); + vf01 = vld1q_bf16(filter_ptr0 + 3); + vf02 = vld1q_bf16(filter_ptr0 + 5); + + // outch 0, height 0 + vo00 = vfmaq_laneq_f32(vo00, vi00, vf00, 0); + vo00 = vfmaq_laneq_f32(vo00, vi01, vf00, 1); + vo00 = vfmaq_laneq_f32(vo00, vi02, vf00, 2); + vo00 = vfmaq_laneq_f32(vo00, vi10, vf01, 0); + vo00 = vfmaq_laneq_f32(vo00, vi11, vf01, 1); + vo00 = vfmaq_laneq_f32(vo00, vi12, vf01, 2); + vo00 = vfmaq_laneq_f32(vo00, vi20, vf02, 1); + vo00 = vfmaq_laneq_f32(vo00, vi21, vf02, 2); + vo00 = vfmaq_laneq_f32(vo00, vi22, vf02, 3); + + // outch 0, height 1 + vo01 = vfmaq_laneq_f32(vo01, vi10, vf00, 0); + vo01 = vfmaq_laneq_f32(vo01, vi11, vf00, 1); + vo01 = vfmaq_laneq_f32(vo01, vi12, vf00, 2); + vo01 = vfmaq_laneq_f32(vo01, vi20, vf01, 0); + vo01 = vfmaq_laneq_f32(vo01, vi21, vf01, 1); + vo01 = vfmaq_laneq_f32(vo01, vi22, vf01, 2); + vo01 = vfmaq_laneq_f32(vo01, vi30, vf02, 1); + vo01 = vfmaq_laneq_f32(vo01, vi31, vf02, 2); + vo01 = vfmaq_laneq_f32(vo01, vi32, vf02, 3); +#else + // load filter (1 outch x 3 height x 3 width): vf_outch_height + float32x2_t vf01, vf23, vf45, vf67, vf78; + vf01 = vld1_bf16(filter_ptr0); + vf23 = vld1_bf16(filter_ptr0 + 2); + vf45 = 
vld1_bf16(filter_ptr0 + 4); + vf67 = vld1_bf16(filter_ptr0 + 6); + vf78 = vld1_bf16(filter_ptr0 + 7); + + // outch 0, height 0 + vo00 = vmlaq_lane_f32(vo00, vi00, vf01, 0); + vo00 = vmlaq_lane_f32(vo00, vi01, vf01, 1); + vo00 = vmlaq_lane_f32(vo00, vi02, vf23, 0); + vo00 = vmlaq_lane_f32(vo00, vi10, vf23, 1); + vo00 = vmlaq_lane_f32(vo00, vi11, vf45, 0); + vo00 = vmlaq_lane_f32(vo00, vi12, vf45, 1); + vo00 = vmlaq_lane_f32(vo00, vi20, vf67, 0); + vo00 = vmlaq_lane_f32(vo00, vi21, vf67, 1); + vo00 = vmlaq_lane_f32(vo00, vi22, vf78, 1); + + // outch 0, height 1 + vo01 = vmlaq_lane_f32(vo01, vi10, vf01, 0); + vo01 = vmlaq_lane_f32(vo01, vi11, vf01, 1); + vo01 = vmlaq_lane_f32(vo01, vi12, vf23, 0); + vo01 = vmlaq_lane_f32(vo01, vi20, vf23, 1); + vo01 = vmlaq_lane_f32(vo01, vi21, vf45, 0); + vo01 = vmlaq_lane_f32(vo01, vi22, vf45, 1); + vo01 = vmlaq_lane_f32(vo01, vi30, vf67, 0); + vo01 = vmlaq_lane_f32(vo01, vi31, vf67, 1); + vo01 = vmlaq_lane_f32(vo01, vi32, vf78, 1); + +#endif + in_ptr0 += p.in_image_size; + in_ptr1 += p.in_image_size; + in_ptr2 += p.in_image_size; + in_ptr3 += p.in_image_size; + filter_ptr0 += 9; + } // c + + vst1q_bf16(out_ptr0, vo00); + vst1q_bf16(out_ptr0 + p.out_width, vo01); + + in_ptr0_base += 4; + in_ptr1_base += 4; + in_ptr2_base += 4; + in_ptr3_base += 4; + + out_ptr0 += 4; + } // w + + in_ptr0_base += 2 + p.in_width; + in_ptr1_base += 2 + p.in_width; + in_ptr2_base += 2 + p.in_width; + in_ptr3_base += 2 + p.in_width; + + out_ptr0 += p.out_width; + } // h + } // mm + } // if + } // m + } // b + }, 0, p.batch, 1, 0, p.out_channels, 2); + + return MaceStatus::MACE_SUCCESS; +} + +template <> +MaceStatus Conv2dK3x3S2::DoCompute( + const ConvComputeParam &p, const BFloat16 *filter_data, + const BFloat16 *input_data, BFloat16 *output_data) { + p.thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0, + index_t start1, index_t end1, index_t step1) { + for (index_t b = start0; b < end0; b += step0) { + for (index_t m = start1; m < end1; m += step1) { + auto out_base = output_data + b * p.out_batch_size + + m * p.out_image_size; + for (index_t h = 0; h < p.out_height; ++h) { + for (index_t w = 0; w + 3 < p.out_width; w += 4) { + // offset + const index_t in_h = h * 2; + const index_t in_w = w * 2; + const index_t in_offset = in_h * p.in_width + in_w; + const index_t out_offset = h * p.out_width + w; + // output (1 outch x 1 height x 4 width): vo + float32x4_t vo = vdupq_n_f32(0.f); + auto in_base = input_data + b * p.in_batch_size; + auto f_ptr = filter_data + m * p.in_channels * 9; + for (index_t c = 0; c < p.in_channels; ++c) { + // input (3 height x 3 slide): vi_height_slide + float32x4x2_t vi0, vi1, vi2; + float32x4_t vi0n, vi1n, vi2n; + float32x4_t vi00, vi01, vi02; + float32x4_t vi10, vi11, vi12; + float32x4_t vi20, vi21, vi22; + + // load input + vi0 = vld2q_bf16(in_base + in_offset); + vi1 = vld2q_bf16(in_base + in_offset + p.in_width); + vi2 = vld2q_bf16(in_base + in_offset + 2 * p.in_width); + + vi0n = vld1q_bf16(in_base + in_offset + 8); + vi1n = vld1q_bf16(in_base + in_offset + p.in_width + 8); + vi2n = vld1q_bf16(in_base + in_offset + 2 * p.in_width + 8); + + vi00 = vi0.val[0]; // [0.2.4.6] + vi01 = vi0.val[1]; // [1.3.5.7] + vi02 = vextq_f32(vi00, vi0n, 1); // [2.4.6.8] + vi10 = vi1.val[0]; + vi11 = vi1.val[1]; + vi12 = vextq_f32(vi10, vi1n, 1); + vi20 = vi2.val[0]; + vi21 = vi2.val[1]; + vi22 = vextq_f32(vi20, vi2n, 1); + +#if defined(__aarch64__) // arm v8 + // load filter (1 outch x 3 height x 3 width): vf_outch_height + float32x4_t vf00 = 
vld1q_bf16(f_ptr); + float32x4_t vf01 = vld1q_bf16(f_ptr + 3); + float32x4_t vf02 = vld1q_bf16(f_ptr + 5); + + // outch 0, height 0 + vo = vfmaq_laneq_f32(vo, vi00, vf00, 0); + vo = vfmaq_laneq_f32(vo, vi01, vf00, 1); + vo = vfmaq_laneq_f32(vo, vi02, vf00, 2); + vo = vfmaq_laneq_f32(vo, vi10, vf01, 0); + vo = vfmaq_laneq_f32(vo, vi11, vf01, 1); + vo = vfmaq_laneq_f32(vo, vi12, vf01, 2); + vo = vfmaq_laneq_f32(vo, vi20, vf02, 1); + vo = vfmaq_laneq_f32(vo, vi21, vf02, 2); + vo = vfmaq_laneq_f32(vo, vi22, vf02, 3); +#else // arm v7 + // load filter (1 outch x 3 height x 3 width): vf_outch_height + float32x2_t vf01 = vld1_bf16(f_ptr); + float32x2_t vf23 = vld1_bf16(f_ptr + 2); + float32x2_t vf45 = vld1_bf16(f_ptr + 4); + float32x2_t vf67 = vld1_bf16(f_ptr + 6); + float32x2_t vf78 = vld1_bf16(f_ptr + 7); + + // outch 0, height 0 + vo = vmlaq_lane_f32(vo, vi00, vf01, 0); + vo = vmlaq_lane_f32(vo, vi01, vf01, 1); + vo = vmlaq_lane_f32(vo, vi02, vf23, 0); + vo = vmlaq_lane_f32(vo, vi10, vf23, 1); + vo = vmlaq_lane_f32(vo, vi11, vf45, 0); + vo = vmlaq_lane_f32(vo, vi12, vf45, 1); + vo = vmlaq_lane_f32(vo, vi20, vf67, 0); + vo = vmlaq_lane_f32(vo, vi21, vf67, 1); + vo = vmlaq_lane_f32(vo, vi22, vf78, 1); +#endif + in_base += p.in_image_size; + f_ptr += 9; + } + vst1q_bf16(out_base + out_offset, vo); + } + } + } // m + } // b + }, 0, p.batch, 1, 0, p.out_channels, 1); + + return MaceStatus::MACE_SUCCESS; +} + +} // namespace arm +} // namespace ops +} // namespace mace diff --git a/mace/ops/arm/bf16/gemm.cc b/mace/ops/arm/bf16/gemm.cc new file mode 100644 index 00000000..c8bfe58e --- /dev/null +++ b/mace/ops/arm/bf16/gemm.cc @@ -0,0 +1,535 @@ +// Copyright 2020 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/arm/base/gemm.h" + +#include +#include +#include + +#include "mace/port/env.h" + +namespace mace { +namespace ops { +namespace arm { + +template <> +void Gemm::ComputeBlock(const BFloat16 *packed_lhs_data, + const BFloat16 *packed_rhs_data, + const index_t depth_padded, + BFloat16 *packed_output_data) { + const BFloat16 *lhs_ptr = packed_lhs_data; + const BFloat16 *rhs_ptr = packed_rhs_data; + + const index_t depth_block_count = depth_padded / 4; + +#ifdef __aarch64__ + // Register layout: (8x4) x (4,8) + // + // +--------+--------+ + // | v8 ... | v9 ... | + // Rhs +--------+--------+ + // | v10... | v11... | + // +--------+--------+ + // | v12... | v13... | + // +--------+--------+ + // | v14... | v15... | + // +--------+--------+ + // + // Lhs + // + // +----+----+----+----+ - - +--------+--------+ + // | v0 | v2 | v4 | v6 | | v16... | v17... | + // | . | | | | | v18... | v19... | + // | . | | | | | v20... | v21... | + // | . | | | | | v22... | v23... | + // +----+----|----+----+ +--------+--------+ + // | v1 | v3 | v5 | v7 | | v24... | v25... | + // | . | | | | | v26... | v27... | + // | . | | | | | v28... | v29... | + // | . | | | | | v30... | v31... 
| + // +----+----|----+----+ +--------+--------+ + // + // Accumulator + // + + if (depth_block_count > 0) { + index_t r_depth_block_count = depth_block_count; + // just make compiler happy + MACE_UNUSED(r_depth_block_count); + + asm volatile( + "dup v16.4s, wzr \n" + "dup v17.4s, wzr \n" + "dup v18.4s, wzr \n" + "dup v19.4s, wzr \n" + "dup v20.4s, wzr \n" + "dup v21.4s, wzr \n" + "dup v22.4s, wzr \n" + "dup v23.4s, wzr \n" + "dup v24.4s, wzr \n" + "dup v25.4s, wzr \n" + "dup v26.4s, wzr \n" + "dup v27.4s, wzr \n" + "dup v28.4s, wzr \n" + "dup v29.4s, wzr \n" + "dup v30.4s, wzr \n" + "dup v31.4s, wzr \n" + + // prelogue + "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%[lhs_ptr]], #32 \n" + "shll v0.4s, v0.4h, #16 \n" + "shll v1.4s, v1.4h, #16 \n" + "shll v2.4s, v2.4h, #16 \n" + "shll v3.4s, v3.4h, #16 \n" + + "ld1 {v4.4h, v5.4h, v6.4h, v7.4h}, [%[lhs_ptr]], #32 \n" + "shll v4.4s, v4.4h, #16 \n" + "shll v5.4s, v5.4h, #16 \n" + "shll v6.4s, v6.4h, #16 \n" + "shll v7.4s, v7.4h, #16 \n" + + "ld1 {v8.4h, v9.4h, v10.4h, v11.4h}, [%[rhs_ptr]], #32 \n" + "shll v8.4s, v8.4h, #16 \n" + "shll v9.4s, v9.4h, #16 \n" + "shll v10.4s, v10.4h, #16 \n" + "shll v11.4s, v11.4h, #16 \n" + + "ld1 {v12.4h, v13.4h, v14.4h, v15.4h}, [%[rhs_ptr]], #32 \n" + "shll v12.4s, v12.4h, #16 \n" + "shll v13.4s, v13.4h, #16 \n" + "shll v14.4s, v14.4h, #16 \n" + "shll v15.4s, v15.4h, #16 \n" + + "subs %[r_depth_block_count], %[r_depth_block_count], #1 \n" + "beq 1f\n" + + "0: \n" + "fmla v16.4s, v8.4s, v0.s[0] \n" + "fmla v17.4s, v9.4s, v0.s[0] \n" + "fmla v18.4s, v8.4s, v0.s[1] \n" + "fmla v19.4s, v9.4s, v0.s[1] \n" + "fmla v20.4s, v8.4s, v0.s[2] \n" + "fmla v21.4s, v9.4s, v0.s[2] \n" + "fmla v22.4s, v8.4s, v0.s[3] \n" + "fmla v23.4s, v9.4s, v0.s[3] \n" + + "ld1 {v0.4h}, [%[lhs_ptr]], #8 \n" + "shll v0.4s, v0.4h, #16 \n" + + "fmla v24.4s, v8.4s, v1.s[0] \n" + "fmla v25.4s, v9.4s, v1.s[0] \n" + "fmla v26.4s, v8.4s, v1.s[1] \n" + "fmla v27.4s, v9.4s, v1.s[1] \n" + "fmla v28.4s, v8.4s, v1.s[2] \n" + "fmla v29.4s, v9.4s, v1.s[2] \n" + "fmla v30.4s, v8.4s, v1.s[3] \n" + "fmla v31.4s, v9.4s, v1.s[3] \n" + + "ld1 {v1.4h}, [%[lhs_ptr]], #8 \n" + "shll v1.4s, v1.4h, #16 \n" + "ld1 {v8.4h, v9.4h}, [%[rhs_ptr]], #16 \n" + "shll v8.4s, v8.4h, #16 \n" + "shll v9.4s, v9.4h, #16 \n" + + "fmla v16.4s, v10.4s, v2.s[0] \n" + "fmla v17.4s, v11.4s, v2.s[0] \n" + "fmla v18.4s, v10.4s, v2.s[1] \n" + "fmla v19.4s, v11.4s, v2.s[1] \n" + "fmla v20.4s, v10.4s, v2.s[2] \n" + "fmla v21.4s, v11.4s, v2.s[2] \n" + "fmla v22.4s, v10.4s, v2.s[3] \n" + "fmla v23.4s, v11.4s, v2.s[3] \n" + + "ld1 {v2.4h}, [%[lhs_ptr]], #8 \n" + "shll v2.4s, v2.4h, #16 \n" + + "fmla v24.4s, v10.4s, v3.s[0] \n" + "fmla v25.4s, v11.4s, v3.s[0] \n" + "fmla v26.4s, v10.4s, v3.s[1] \n" + "fmla v27.4s, v11.4s, v3.s[1] \n" + "fmla v28.4s, v10.4s, v3.s[2] \n" + "fmla v29.4s, v11.4s, v3.s[2] \n" + "fmla v30.4s, v10.4s, v3.s[3] \n" + "fmla v31.4s, v11.4s, v3.s[3] \n" + + "ld1 {v3.4h}, [%[lhs_ptr]], #8 \n" + "shll v3.4s, v3.4h, #16 \n" + "ld1 {v10.4h, v11.4h}, [%[rhs_ptr]], #16 \n" + "shll v10.4s, v10.4h, #16 \n" + "shll v11.4s, v11.4h, #16 \n" + + "fmla v16.4s, v12.4s, v4.s[0] \n" + "fmla v17.4s, v13.4s, v4.s[0] \n" + "fmla v18.4s, v12.4s, v4.s[1] \n" + "fmla v19.4s, v13.4s, v4.s[1] \n" + "fmla v20.4s, v12.4s, v4.s[2] \n" + "fmla v21.4s, v13.4s, v4.s[2] \n" + "fmla v22.4s, v12.4s, v4.s[3] \n" + "fmla v23.4s, v13.4s, v4.s[3] \n" + + "ld1 {v4.4h}, [%[lhs_ptr]], #8 \n" + "shll v4.4s, v4.4h, #16 \n" + + "fmla v24.4s, v12.4s, v5.s[0] \n" + "fmla v25.4s, v13.4s, v5.s[0] \n" + "fmla v26.4s, v12.4s, 
v5.s[1] \n" + "fmla v27.4s, v13.4s, v5.s[1] \n" + "fmla v28.4s, v12.4s, v5.s[2] \n" + "fmla v29.4s, v13.4s, v5.s[2] \n" + "fmla v30.4s, v12.4s, v5.s[3] \n" + "fmla v31.4s, v13.4s, v5.s[3] \n" + + "ld1 {v5.4h}, [%[lhs_ptr]], #8 \n" + "shll v5.4s, v5.4h, #16 \n" + "ld1 {v12.4h, v13.4h}, [%[rhs_ptr]], #16 \n" + "shll v12.4s, v12.4h, #16 \n" + "shll v13.4s, v13.4h, #16 \n" + + "fmla v16.4s, v14.4s, v6.s[0] \n" + "fmla v17.4s, v15.4s, v6.s[0] \n" + "fmla v18.4s, v14.4s, v6.s[1] \n" + "fmla v19.4s, v15.4s, v6.s[1] \n" + "fmla v20.4s, v14.4s, v6.s[2] \n" + "fmla v21.4s, v15.4s, v6.s[2] \n" + "fmla v22.4s, v14.4s, v6.s[3] \n" + "fmla v23.4s, v15.4s, v6.s[3] \n" + + "ld1 {v6.4h}, [%[lhs_ptr]], #8 \n" + "shll v6.4s, v6.4h, #16 \n" + + "subs %[r_depth_block_count], %[r_depth_block_count], #1 \n" + + "fmla v24.4s, v14.4s, v7.s[0] \n" + "fmla v25.4s, v15.4s, v7.s[0] \n" + "fmla v26.4s, v14.4s, v7.s[1] \n" + "fmla v27.4s, v15.4s, v7.s[1] \n" + "fmla v28.4s, v14.4s, v7.s[2] \n" + "fmla v29.4s, v15.4s, v7.s[2] \n" + "fmla v30.4s, v14.4s, v7.s[3] \n" + "fmla v31.4s, v15.4s, v7.s[3] \n" + + "ld1 {v7.4h}, [%[lhs_ptr]], #8 \n" + "shll v7.4s, v7.4h, #16 \n" + "ld1 {v14.4h, v15.4h}, [%[rhs_ptr]], #16 \n" + "shll v14.4s, v14.4h, #16 \n" + "shll v15.4s, v15.4h, #16 \n" + + "bne 0b \n" + + // prologue + "1:\n" + "fmla v16.4s, v8.4s, v0.s[0] \n" + "fmla v17.4s, v9.4s, v0.s[0] \n" + "fmla v18.4s, v8.4s, v0.s[1] \n" + "fmla v19.4s, v9.4s, v0.s[1] \n" + "fmla v20.4s, v8.4s, v0.s[2] \n" + "fmla v21.4s, v9.4s, v0.s[2] \n" + "fmla v22.4s, v8.4s, v0.s[3] \n" + "fmla v23.4s, v9.4s, v0.s[3] \n" + + "fmla v24.4s, v8.4s, v1.s[0] \n" + "fmla v25.4s, v9.4s, v1.s[0] \n" + "fmla v26.4s, v8.4s, v1.s[1] \n" + "fmla v27.4s, v9.4s, v1.s[1] \n" + "fmla v28.4s, v8.4s, v1.s[2] \n" + "fmla v29.4s, v9.4s, v1.s[2] \n" + "fmla v30.4s, v8.4s, v1.s[3] \n" + "fmla v31.4s, v9.4s, v1.s[3] \n" + + "fmla v16.4s, v10.4s, v2.s[0] \n" + "fmla v17.4s, v11.4s, v2.s[0] \n" + "fmla v18.4s, v10.4s, v2.s[1] \n" + "fmla v19.4s, v11.4s, v2.s[1] \n" + "fmla v20.4s, v10.4s, v2.s[2] \n" + "fmla v21.4s, v11.4s, v2.s[2] \n" + "fmla v22.4s, v10.4s, v2.s[3] \n" + "fmla v23.4s, v11.4s, v2.s[3] \n" + + "fmla v24.4s, v10.4s, v3.s[0] \n" + "fmla v25.4s, v11.4s, v3.s[0] \n" + "fmla v26.4s, v10.4s, v3.s[1] \n" + "fmla v27.4s, v11.4s, v3.s[1] \n" + "fmla v28.4s, v10.4s, v3.s[2] \n" + "fmla v29.4s, v11.4s, v3.s[2] \n" + "fmla v30.4s, v10.4s, v3.s[3] \n" + "fmla v31.4s, v11.4s, v3.s[3] \n" + + "fmla v16.4s, v12.4s, v4.s[0] \n" + "fmla v17.4s, v13.4s, v4.s[0] \n" + "fmla v18.4s, v12.4s, v4.s[1] \n" + "fmla v19.4s, v13.4s, v4.s[1] \n" + "fmla v20.4s, v12.4s, v4.s[2] \n" + "fmla v21.4s, v13.4s, v4.s[2] \n" + "fmla v22.4s, v12.4s, v4.s[3] \n" + "fmla v23.4s, v13.4s, v4.s[3] \n" + + "fmla v24.4s, v12.4s, v5.s[0] \n" + "fmla v25.4s, v13.4s, v5.s[0] \n" + "fmla v26.4s, v12.4s, v5.s[1] \n" + "fmla v27.4s, v13.4s, v5.s[1] \n" + "fmla v28.4s, v12.4s, v5.s[2] \n" + "fmla v29.4s, v13.4s, v5.s[2] \n" + "fmla v30.4s, v12.4s, v5.s[3] \n" + "fmla v31.4s, v13.4s, v5.s[3] \n" + + "fmla v16.4s, v14.4s, v6.s[0] \n" + "fmla v17.4s, v15.4s, v6.s[0] \n" + "fmla v18.4s, v14.4s, v6.s[1] \n" + "fmla v19.4s, v15.4s, v6.s[1] \n" + "fmla v20.4s, v14.4s, v6.s[2] \n" + "fmla v21.4s, v15.4s, v6.s[2] \n" + "fmla v22.4s, v14.4s, v6.s[3] \n" + "fmla v23.4s, v15.4s, v6.s[3] \n" + + "fmla v24.4s, v14.4s, v7.s[0] \n" + "fmla v25.4s, v15.4s, v7.s[0] \n" + "fmla v26.4s, v14.4s, v7.s[1] \n" + "fmla v27.4s, v15.4s, v7.s[1] \n" + "fmla v28.4s, v14.4s, v7.s[2] \n" + "fmla v29.4s, v15.4s, v7.s[2] \n" + "fmla v30.4s, 
v14.4s, v7.s[3] \n" + "fmla v31.4s, v15.4s, v7.s[3] \n" + + "shrn v16.4h, v16.4s, #16 \n" + "shrn v17.4h, v17.4s, #16 \n" + "shrn v18.4h, v18.4s, #16 \n" + "shrn v19.4h, v19.4s, #16 \n" + "st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [%[packed_output_data]], #32 \n" + + "shrn v20.4h, v20.4s, #16 \n" + "shrn v21.4h, v21.4s, #16 \n" + "shrn v22.4h, v22.4s, #16 \n" + "shrn v23.4h, v23.4s, #16 \n" + "st1 {v20.4h, v21.4h, v22.4h, v23.4h}, [%[packed_output_data]], #32 \n" + + "shrn v24.4h, v24.4s, #16 \n" + "shrn v25.4h, v25.4s, #16 \n" + "shrn v26.4h, v26.4s, #16 \n" + "shrn v27.4h, v27.4s, #16 \n" + "st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [%[packed_output_data]], #32 \n" + + "shrn v28.4h, v28.4s, #16 \n" + "shrn v29.4h, v29.4s, #16 \n" + "shrn v30.4h, v30.4s, #16 \n" + "shrn v31.4h, v31.4s, #16 \n" + "st1 {v28.4h, v29.4h, v30.4h, v31.4h}, [%[packed_output_data]], #32 \n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [packed_output_data] "+r"(packed_output_data), + [r_depth_block_count] "+r"(r_depth_block_count) + : // inputs + : // clabbers + "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", + "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", + "v29", "v30", "v31"); + } +#else // armeabi-v7a + + // Register layout: (4x4) x (4,8) + // + // +--------+--------+ + // | q4 ... | q5 ... | + // Rhs +--------+--------+ + // | q6 ... | q7 ... | + // +--------+--------+ + // | q4 ... | q5 ... | + // +--------+--------+ + // | q6 ... | q7 ... | + // +--------+--------+ + // + // Lhs + // + // +----+----+----+----+ - - +--------+--------+ + // | q0 | q1 | q2 | q3 | | q8... | q9... | + // | . | | | | | q10... | q11... | + // | . | | | | | q12... | q13... | + // | . | | | | | q14... | q15... | + // +----+----+----+----+ +--------+--------+ + // + // Accumulator + // + + if (depth_block_count > 0) { + index_t r_depth_block_count = depth_block_count; + // just make compiler happy + MACE_UNUSED(r_depth_block_count); + + asm volatile( + "mov r0, #0\n" + "vdup.f32 q8, r0 \n" + "vdup.f32 q9, r0 \n" + "vdup.f32 q10, r0 \n" + "vdup.f32 q11, r0 \n" + "vdup.f32 q12, r0 \n" + "vdup.f32 q13, r0 \n" + "vdup.f32 q14, r0 \n" + "vdup.f32 q15, r0 \n" + + // prelogue + "vld1.u16 {d0-d3}, [%[lhs_ptr]]! \n" + "vshll.u16 q3, d3, #16 \n" + "vshll.u16 q2, d2, #16 \n" + "vshll.u16 q1, d1, #16 \n" + "vshll.u16 q0, d0, #16 \n" + + "vld1.u16 {d8-d11}, [%[rhs_ptr]]! \n" + "vshll.u16 q7, d11, #16 \n" + "vshll.u16 q6, d10, #16 \n" + "vshll.u16 q5, d9, #16 \n" + "vshll.u16 q4, d8, #16 \n" + + "subs %[r_depth_block_count], %[r_depth_block_count], #1 \n" + "beq 1f\n" + + "0: \n" + + "vmla.f32 q8, q4, d0[0] \n" + "vmla.f32 q9, q5, d0[0] \n" + "vmla.f32 q10, q4, d0[1] \n" + "vmla.f32 q11, q5, d0[1] \n" + "vmla.f32 q12, q4, d1[0] \n" + "vmla.f32 q13, q5, d1[0] \n" + "vmla.f32 q14, q4, d1[1] \n" + "vmla.f32 q15, q5, d1[1] \n" + + "vld1.u16 {d0}, [%[lhs_ptr]]! \n" + "vld1.u16 {d8-d9}, [%[rhs_ptr]]! \n" + "vshll.u16 q0, d0, #16 \n" + "vshll.u16 q5, d9, #16 \n" + "vshll.u16 q4, d8, #16 \n" + + "vmla.f32 q8, q6, d2[0] \n" + "vmla.f32 q9, q7, d2[0] \n" + "vmla.f32 q10, q6, d2[1] \n" + "vmla.f32 q11, q7, d2[1] \n" + "vmla.f32 q12, q6, d3[0] \n" + "vmla.f32 q13, q7, d3[0] \n" + "vmla.f32 q14, q6, d3[1] \n" + "vmla.f32 q15, q7, d3[1] \n" + + "vld1.u16 {d2}, [%[lhs_ptr]]! \n" + "vld1.u16 {d12-d13}, [%[rhs_ptr]]! 
\n" + "vshll.u16 q1, d2, #16 \n" + "vshll.u16 q7, d13, #16 \n" + "vshll.u16 q6, d12, #16 \n" + + "vmla.f32 q8, q4, d4[0] \n" + "vmla.f32 q9, q5, d4[0] \n" + "vmla.f32 q10, q4, d4[1] \n" + "vmla.f32 q11, q5, d4[1] \n" + "vmla.f32 q12, q4, d5[0] \n" + "vmla.f32 q13, q5, d5[0] \n" + "vmla.f32 q14, q4, d5[1] \n" + "vmla.f32 q15, q5, d5[1] \n" + + "vld1.u16 {d4}, [%[lhs_ptr]]! \n" + "vld1.u16 {d8-d9}, [%[rhs_ptr]]! \n" + "vshll.u16 q2, d4, #16 \n" + "vshll.u16 q5, d9, #16 \n" + "vshll.u16 q4, d8, #16 \n" + + "subs %[r_depth_block_count], %[r_depth_block_count], #1 \n" + + "vmla.f32 q8, q6, d6[0] \n" + "vmla.f32 q9, q7, d6[0] \n" + "vmla.f32 q10, q6, d6[1] \n" + "vmla.f32 q11, q7, d6[1] \n" + "vmla.f32 q12, q6, d7[0] \n" + "vmla.f32 q13, q7, d7[0] \n" + "vmla.f32 q14, q6, d7[1] \n" + "vmla.f32 q15, q7, d7[1] \n" + + "vld1.u16 {d6}, [%[lhs_ptr]]! \n" + "vld1.u16 {d12-d13}, [%[rhs_ptr]]! \n" + "vshll.u16 q3, d6, #16 \n" + "vshll.u16 q7, d13, #16 \n" + "vshll.u16 q6, d12, #16 \n" + + "bne 0b \n" + + // prologue + "1:\n" + "vmla.f32 q8, q4, d0[0] \n" + "vmla.f32 q9, q5, d0[0] \n" + "vmla.f32 q10, q4, d0[1] \n" + "vmla.f32 q11, q5, d0[1] \n" + "vmla.f32 q12, q4, d1[0] \n" + "vmla.f32 q13, q5, d1[0] \n" + "vmla.f32 q14, q4, d1[1] \n" + "vmla.f32 q15, q5, d1[1] \n" + + "vld1.u16 {d8-d9}, [%[rhs_ptr]]! \n" + "vshll.u16 q5, d9, #16 \n" + "vshll.u16 q4, d8, #16 \n" + + "vmla.f32 q8, q6, d2[0] \n" + "vmla.f32 q9, q7, d2[0] \n" + "vmla.f32 q10, q6, d2[1] \n" + "vmla.f32 q11, q7, d2[1] \n" + "vmla.f32 q12, q6, d3[0] \n" + "vmla.f32 q13, q7, d3[0] \n" + "vmla.f32 q14, q6, d3[1] \n" + "vmla.f32 q15, q7, d3[1] \n" + + "vld1.u16 {d12-d13}, [%[rhs_ptr]]! \n" + "vshll.u16 q7, d13, #16 \n" + "vshll.u16 q6, d12, #16 \n" + + "vmla.f32 q8, q4, d4[0] \n" + "vmla.f32 q9, q5, d4[0] \n" + "vmla.f32 q10, q4, d4[1] \n" + "vmla.f32 q11, q5, d4[1] \n" + "vmla.f32 q12, q4, d5[0] \n" + "vmla.f32 q13, q5, d5[0] \n" + "vmla.f32 q14, q4, d5[1] \n" + "vmla.f32 q15, q5, d5[1] \n" + + "vmla.f32 q8, q6, d6[0] \n" + "vmla.f32 q9, q7, d6[0] \n" + "vmla.f32 q10, q6, d6[1] \n" + "vmla.f32 q11, q7, d6[1] \n" + "vmla.f32 q12, q6, d7[0] \n" + "vmla.f32 q13, q7, d7[0] \n" + "vmla.f32 q14, q6, d7[1] \n" + "vmla.f32 q15, q7, d7[1] \n" + + "vshrn.u32 d16, q8, #16 \n" + "vshrn.u32 d17, q9, #16 \n" + "vst1.u16 {d16-d17}, [%[packed_output_data]]! \n" + "vshrn.u32 d20, q10, #16 \n" + "vshrn.u32 d21, q11, #16 \n" + "vst1.u16 {d20-d21}, [%[packed_output_data]]! \n" + "vshrn.u32 d24, q12, #16 \n" + "vshrn.u32 d25, q13, #16 \n" + "vst1.u16 {d24-d25}, [%[packed_output_data]]! \n" + "vshrn.u32 d28, q14, #16 \n" + "vshrn.u32 d29, q15, #16 \n" + "vst1.u16 {d28-d29}, [%[packed_output_data]]! \n" + : // outputs + [lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr), + [packed_output_data] "+r"(packed_output_data), + [r_depth_block_count] "+r"(r_depth_block_count) + : // inputs + : // clabbers + "cc", "memory", "r0", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", + "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); + } +#endif +} + +} // namespace arm +} // namespace ops +} // namespace mace diff --git a/mace/ops/arm/fp32/activation.cc b/mace/ops/arm/fp32/activation.cc deleted file mode 100644 index 4f3daa0f..00000000 --- a/mace/ops/arm/fp32/activation.cc +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright 2019 The MACE Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include - -#include "mace/ops/arm/base/activation.h" - -namespace mace { -namespace ops { -namespace arm { - -template<> -void Activation::ActivateRelu(utils::ThreadPool *thread_pool, - const Tensor *input, - Tensor *output) { - const auto input_data = input->data(); - auto output_data = output->mutable_data(); - const index_t input_size = input->size(); - const float32x4_t vzero = vdupq_n_f32(0.f); - const index_t block_count = input_size / 4; - - thread_pool->Compute1D( - [=](index_t start, index_t end, index_t step) { - auto input_ptr = input_data + start * 4; - auto output_ptr = output_data + start * 4; - - for (index_t i = start; i < end; i += step) { - float32x4_t v = vld1q_f32(input_ptr); - v = vmaxq_f32(v, vzero); - vst1q_f32(output_ptr, v); - - input_ptr += 4; - output_ptr += 4; - } - }, - 0, block_count, 1); - - // remain - for (index_t i = block_count * 4; i < input_size; ++i) { - output_data[i] = std::max(0.f, input_data[i]); - } -} - -template<> -void Activation::ActivateRelux(utils::ThreadPool *thread_pool, - const Tensor *input, - Tensor *output) { - const auto input_data = input->data(); - auto output_data = output->mutable_data(); - const index_t input_size = input->size(); - const float32x4_t vzero = vdupq_n_f32(0.f); - const float32x4_t vlimit = vdupq_n_f32(limit_); - const index_t block_count = input_size / 4; - - thread_pool->Compute1D( - [=](index_t start, index_t end, index_t step) { - auto input_ptr = input_data + start * 4; - auto output_ptr = output_data + start * 4; - - for (index_t i = start; i < end; i += step) { - float32x4_t v = vld1q_f32(input_ptr); - v = vmaxq_f32(v, vzero); - v = vminq_f32(v, vlimit); - vst1q_f32(output_ptr, v); - - input_ptr += 4; - output_ptr += 4; - } - }, - 0, block_count, 1); - - // remain - for (index_t i = block_count * 4; i < input_size; ++i) { - output_data[i] = std::max(0.f, std::min(limit_, input_data[i])); - } -} - -template<> -void Activation::ActivateLeakyRelu(utils::ThreadPool *thread_pool, - const Tensor *input, - Tensor *output) { - const auto input_data = input->data(); - auto output_data = output->mutable_data(); - const index_t input_size = input->size(); - const float32x4_t vzero = vdupq_n_f32(0.f); - const float32x4_t valpha = vdupq_n_f32(leakyrelu_coefficient_); - const index_t block_count = input_size / 4; - - thread_pool->Compute1D( - [=](index_t start, index_t end, index_t step) { - auto input_ptr = input_data + start * 4; - auto output_ptr = output_data + start * 4; - - for (index_t i = start; i < end; i += step) { - float32x4_t v = vld1q_f32(input_ptr); - float32x4_t u = vminq_f32(v, vzero); - v = vmaxq_f32(v, vzero); - v = vmlaq_f32(v, valpha, u); - vst1q_f32(output_ptr, v); - - input_ptr += 4; - output_ptr += 4; - } - }, - 0, block_count, 1); - - // remain - for (index_t i = block_count * 4; i < input_size; ++i) { - output_data[i] = std::max(input_data[i], 0.f) + - std::min(input_data[i], 0.f) * leakyrelu_coefficient_; - } -} - -template<> -void Activation::ActivateTanh(utils::ThreadPool *thread_pool, - const Tensor *input, - Tensor *output) { - const auto input_data = 
input->data(); - auto output_data = output->mutable_data(); - const index_t input_size = input->size(); - - thread_pool->Compute1D( - [=](index_t start, index_t end, index_t step) { - for (index_t i = start; i < end; i += step) { - output_data[i] = std::tanh(input_data[i]); - } - }, - 0, input_size, 1); -} - -template<> -void Activation::ActivateSigmoid(utils::ThreadPool *thread_pool, - const Tensor *input, - Tensor *output) { - const auto input_data = input->data(); - auto output_data = output->mutable_data(); - const index_t input_size = input->size(); - - thread_pool->Compute1D( - [=](index_t start, index_t end, index_t step) { - for (index_t i = start; i < end; i += step) { - output_data[i] = 1 / (1 + std::exp(-(input_data[i]))); - } - }, - 0, input_size, 1); -} - -} // namespace arm -} // namespace ops -} // namespace mace diff --git a/mace/ops/arm/fp32/bias_add.cc b/mace/ops/arm/fp32/bias_add.cc deleted file mode 100644 index 2c0d8326..00000000 --- a/mace/ops/arm/fp32/bias_add.cc +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2019 The MACE Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mace/ops/arm/base/bias_add.h" - -#include - -namespace mace { -namespace ops { -namespace arm { - -template <> -template -void BiasAdd::AddBiasNCHW(utils::ThreadPool *thread_pool, - const Tensor *input, - const Tensor *bias, - Tensor *output) { - auto input_data = input->data(); - auto bias_data = bias->data(); - auto output_data = output->mutable_data(); - - const index_t batch = input->dim(0); - const index_t channels = input->dim(1); - const index_t image_size = input->dim(2) * input->dim(3); - const index_t block_count = image_size / 4; - const index_t remain = image_size % 4; - thread_pool->Compute2D( - [=](index_t start0, index_t end0, index_t step0, index_t start1, - index_t end1, index_t step1) { - for (index_t b = start0; b < end0; b += step0) { - const index_t b_offset = b * channels; - for (index_t c = start1; c < end1; c += step1) { - const index_t offset = (b_offset + c) * image_size; - auto input_ptr = input_data + offset; - auto output_ptr = output_data + offset; - const float bias = bias_data[bias_index(b_offset, c)]; - float32x4_t vbias = vdupq_n_f32(bias); - - for (index_t i = 0; i < block_count; ++i) { - float32x4_t v = vld1q_f32(input_ptr); - v = vaddq_f32(v, vbias); - vst1q_f32(output_ptr, v); - - input_ptr += 4; - output_ptr += 4; - } - for (index_t i = 0; i < remain; ++i) { - (*output_ptr++) = (*input_ptr++) + bias; - } - } - } - }, - 0, batch, 1, 0, channels, 1); -} - -template <> -template -void BiasAdd::AddBiasNHWC(utils::ThreadPool *thread_pool, - const Tensor *input, - const Tensor *bias, - Tensor *output) { - const float *input_ptr = input->data(); - const float *bias_ptr = bias->data(); - float *output_ptr = output->mutable_data(); - - const std::vector &shape = input->shape(); - const index_t channels = *shape.rbegin(); - const auto batch = shape[0]; - if (Dim == 2) { - MACE_CHECK(batch == bias->shape()[0]); - } - const 
index_t fused_hw = std::accumulate(shape.begin() + 1, shape.end() - 1, - 1, std::multiplies()); - thread_pool->Compute2D( - [=](index_t start0, index_t end0, index_t step0, index_t start1, - index_t end1, index_t step1) { - for (index_t i = start0; i < end0; i += step0) { - auto offset = i * fused_hw; - auto bias_offset = i * channels; - for (index_t j = start1; j < end1; j += step1) { - index_t pos = (offset + j) * channels; - for (index_t c = 0; c < channels; ++c, ++pos) { - output_ptr[pos] = - input_ptr[pos] + bias_ptr[bias_index(bias_offset, c)]; - } - } - } - }, - 0, batch, 1, 0, fused_hw, 1); -} - -template void BiasAdd::AddBiasNCHW<1>(utils::ThreadPool *thread_pool, - const Tensor *input, - const Tensor *bias, - Tensor *output); -template void BiasAdd::AddBiasNCHW<2>(utils::ThreadPool *thread_pool, - const Tensor *input, - const Tensor *bias, - Tensor *output); - -template void BiasAdd::AddBiasNHWC<1>(utils::ThreadPool *thread_pool, - const Tensor *input, - const Tensor *bias, - Tensor *output); -template void BiasAdd::AddBiasNHWC<2>(utils::ThreadPool *thread_pool, - const Tensor *input, - const Tensor *bias, - Tensor *output); - -} // namespace arm -} // namespace ops -} // namespace mace diff --git a/mace/ops/arm/fp32/common_neon.h b/mace/ops/arm/fp32/common_neon.h deleted file mode 100644 index 502ffc39..00000000 --- a/mace/ops/arm/fp32/common_neon.h +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2019 The MACE Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef MACE_OPS_ARM_FP32_COMMON_NEON_H_ -#define MACE_OPS_ARM_FP32_COMMON_NEON_H_ - -#if defined(MACE_ENABLE_NEON) -#include - -namespace mace { -namespace ops { -namespace arm { - -inline float32x4_t neon_vfma_lane_0(float32x4_t a, - float32x4_t b, - float32x4_t c) { -#ifdef __aarch64__ - return vfmaq_laneq_f32(a, b, c, 0); -#else - return vmlaq_lane_f32(a, b, vget_low_f32(c), 0); -#endif -} - -inline float32x4_t neon_vfma_lane_1(float32x4_t a, - float32x4_t b, - float32x4_t c) { -#ifdef __aarch64__ - return vfmaq_laneq_f32(a, b, c, 1); -#else - return vmlaq_lane_f32(a, b, vget_low_f32(c), 1); -#endif -} - -inline float32x4_t neon_vfma_lane_2(float32x4_t a, - float32x4_t b, - float32x4_t c) { -#ifdef __aarch64__ - return vfmaq_laneq_f32(a, b, c, 2); -#else - return vmlaq_lane_f32(a, b, vget_high_f32(c), 0); -#endif -} - -inline float32x4_t neon_vfma_lane_3(float32x4_t a, - float32x4_t b, - float32x4_t c) { -#ifdef __aarch64__ - return vfmaq_laneq_f32(a, b, c, 3); -#else - return vmlaq_lane_f32(a, b, vget_high_f32(c), 1); -#endif -} - -} // namespace arm -} // namespace ops -} // namespace mace -#endif // MACE_ENABLE_NEON - -#endif // MACE_OPS_ARM_FP32_COMMON_NEON_H_ diff --git a/mace/ops/arm/fp32/conv_2d_general.cc b/mace/ops/arm/fp32/conv_2d_general.cc deleted file mode 100644 index 6f6a1ff5..00000000 --- a/mace/ops/arm/fp32/conv_2d_general.cc +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright 2019 The MACE Authors. All Rights Reserved. 
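[Note on the deleted common_neon.h above] The neon_vfma_lane_0..3 helpers removed here are re-homed by this patch in mace/ops/arm/base/common_neon.h; they hide an ISA difference: AArch64 offers vfmaq_laneq_f32, which can index any lane of a 128-bit vector, while 32-bit NEON only indexes a 64-bit operand, so the helper combines vget_low_f32/vget_high_f32 with vmlaq_lane_f32. A rough scalar model of what every variant computes, reference code only and not part of the patch:

// result[i] = a[i] + b[i] * c[lane], i.e. multiply-accumulate against one
// broadcast lane of c (the AArch32 vmla path is not fused, but the math is
// the same).
static void neon_vfma_lane_ref(float result[4], const float a[4],
                               const float b[4], const float c[4], int lane) {
  for (int i = 0; i < 4; ++i) {
    result[i] = a[i] + b[i] * c[lane];
  }
}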
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include "mace/ops/arm/base/conv_2d_general.h" -#include "mace/ops/delegator/conv_2d.h" - -namespace mace { -namespace ops { -namespace arm { - -template<> -MaceStatus Conv2dGeneral::DoCompute( - const ConvComputeParam &p, const float *filter_data, - const float *input_data, float *output_data, - const std::vector &filter_shape) { - const index_t filter_height = filter_shape[2]; - const index_t filter_width = filter_shape[3]; - const index_t filter_size = filter_height * filter_width; - - p.thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0, - index_t start1, index_t end1, index_t step1) { - for (index_t b = start0; b < end0; b += step0) { - for (index_t m = start1; m < end1; m += step1) { - const int stride_h = strides_[0]; - const int stride_w = strides_[1]; - const int dilation_h = dilations_[0]; - const int dilation_w = dilations_[1]; - if (m + 3 < p.out_channels) { - float *out_ptr0_base = - output_data + b * p.out_batch_size + m * p.out_image_size; - float *out_ptr1_base = out_ptr0_base + p.out_image_size; - float *out_ptr2_base = out_ptr1_base + p.out_image_size; - float *out_ptr3_base = out_ptr2_base + p.out_image_size; - for (index_t c = 0; c < p.in_channels; ++c) { - const float *in_ptr_base = - input_data + b * p.in_batch_size + c * p.in_image_size; - const float *filter_ptr0 = - filter_data + m * p.in_channels * filter_size + c * filter_size; - const float *filter_ptr1 = - filter_ptr0 + p.in_channels * filter_size; - const float *filter_ptr2 = - filter_ptr1 + p.in_channels * filter_size; - const float *filter_ptr3 = - filter_ptr2 + p.in_channels * filter_size; - for (index_t h = 0; h < p.out_height; ++h) { - for (index_t w = 0; w + 3 < p.out_width; w += 4) { - // input offset - index_t ih = h * stride_h; - index_t iw = w * stride_w; - index_t in_offset = ih * p.in_width + iw; - // output (4 outch x 1 height x 4 width): vo_outch_height - float vo0[4], vo1[4], vo2[4], vo3[4]; - // load output - index_t out_offset = h * p.out_width + w; - for (index_t ow = 0; ow < 4; ++ow) { - vo0[ow] = out_ptr0_base[out_offset + ow]; - vo1[ow] = out_ptr1_base[out_offset + ow]; - vo2[ow] = out_ptr2_base[out_offset + ow]; - vo3[ow] = out_ptr3_base[out_offset + ow]; - } - // calc by row - for (index_t kh = 0; kh < filter_height; ++kh) { - for (index_t kw = 0; kw < filter_width; ++kw) { - // outch 0 - vo0[0] += in_ptr_base[in_offset - + kw * dilation_w] * filter_ptr0[kw]; - vo0[1] += in_ptr_base[in_offset + stride_w - + kw * dilation_w] * filter_ptr0[kw]; - vo0[2] += in_ptr_base[in_offset + 2 * stride_w - + kw * dilation_w] * filter_ptr0[kw]; - vo0[3] += in_ptr_base[in_offset + 3 * stride_w - + kw * dilation_w] * filter_ptr0[kw]; - // outch 1 - vo1[0] += in_ptr_base[in_offset - + kw * dilation_w] * filter_ptr1[kw]; - vo1[1] += in_ptr_base[in_offset + stride_w - + kw * dilation_w] * filter_ptr1[kw]; - vo1[2] += in_ptr_base[in_offset + 2 * stride_w - + kw * dilation_w] * filter_ptr1[kw]; - vo1[3] += 
in_ptr_base[in_offset + 3 * stride_w - + kw * dilation_w] * filter_ptr1[kw]; - // outch 2 - vo2[0] += in_ptr_base[in_offset - + kw * dilation_w] * filter_ptr2[kw]; - vo2[1] += in_ptr_base[in_offset + stride_w - + kw * dilation_w] * filter_ptr2[kw]; - vo2[2] += in_ptr_base[in_offset + 2 * stride_w - + kw * dilation_w] * filter_ptr2[kw]; - vo2[3] += in_ptr_base[in_offset + 3 * stride_w - + kw * dilation_w] * filter_ptr2[kw]; - // outch 3 - vo3[0] += in_ptr_base[in_offset - + kw * dilation_w] * filter_ptr3[kw]; - vo3[1] += in_ptr_base[in_offset + stride_w - + kw * dilation_w] * filter_ptr3[kw]; - vo3[2] += in_ptr_base[in_offset + 2 * stride_w - + kw * dilation_w] * filter_ptr3[kw]; - vo3[3] += in_ptr_base[in_offset + 3 * stride_w - + kw * dilation_w] * filter_ptr3[kw]; - } // kw - - in_offset += dilation_h * p.in_width; - filter_ptr0 += filter_width; - filter_ptr1 += filter_width; - filter_ptr2 += filter_width; - filter_ptr3 += filter_width; - } // kh - - for (index_t ow = 0; ow < 4; ++ow) { - out_ptr0_base[out_offset + ow] = vo0[ow]; - out_ptr1_base[out_offset + ow] = vo1[ow]; - out_ptr2_base[out_offset + ow] = vo2[ow]; - out_ptr3_base[out_offset + ow] = vo3[ow]; - } - - filter_ptr0 -= filter_size; - filter_ptr1 -= filter_size; - filter_ptr2 -= filter_size; - filter_ptr3 -= filter_size; - } // w - } // h - } // c - } else { - for (index_t mm = m; mm < p.out_channels; ++mm) { - float *out_ptr0_base = - output_data + b * p.out_batch_size + mm * p.out_image_size; - for (index_t c = 0; c < p.in_channels; ++c) { - const float *in_ptr_base = - input_data + b * p.in_batch_size + c * p.in_image_size; - const float *filter_ptr0 = - filter_data + mm * p.in_channels * filter_size - + c * filter_size; - - for (index_t h = 0; h < p.out_height; ++h) { - for (index_t w = 0; w + 3 < p.out_width; w += 4) { - // input offset - index_t ih = h * stride_h; - index_t iw = w * stride_w; - index_t in_offset = ih * p.in_width + iw; - // output (1 outch x 1 height x 4 width): vo_outch_height - float vo0[4]; - // load output - index_t out_offset = h * p.out_width + w; - for (index_t ow = 0; ow < 4; ++ow) { - vo0[ow] = out_ptr0_base[out_offset + ow]; - } - - // calc by row - for (index_t kh = 0; kh < filter_height; ++kh) { - for (index_t kw = 0; kw < filter_width; ++kw) { - // outch 0 - vo0[0] += in_ptr_base[in_offset - + kw * dilation_w] * filter_ptr0[kw]; - vo0[1] += in_ptr_base[in_offset + stride_w - + kw * dilation_w] * filter_ptr0[kw]; - vo0[2] += in_ptr_base[in_offset + 2 * stride_w - + kw * dilation_w] * filter_ptr0[kw]; - vo0[3] += in_ptr_base[in_offset + 3 * stride_w - + kw * dilation_w] * filter_ptr0[kw]; - } // kw - - in_offset += dilation_h * p.in_width; - filter_ptr0 += filter_width; - } // kh - - for (index_t ow = 0; ow < 4; ++ow) { - out_ptr0_base[out_offset + ow] = vo0[ow]; - } - filter_ptr0 -= filter_size; - } // w - } // h - } // c - } // mm - } // if - } // m - } // b - }, 0, p.batch, 1, 0, p.out_channels, 4); - - return MaceStatus::MACE_SUCCESS; -} - -} // namespace arm -} // namespace ops -} // namespace mace diff --git a/mace/ops/arm/fp32/deconv_2d_2x2.cc b/mace/ops/arm/fp32/deconv_2d_2x2.cc index 2a6ca40d..46f41367 100644 --- a/mace/ops/arm/fp32/deconv_2d_2x2.cc +++ b/mace/ops/arm/fp32/deconv_2d_2x2.cc @@ -14,8 +14,8 @@ #include +#include "mace/ops/arm/base/common_neon.h" #include "mace/ops/arm/base/deconv_2d_2x2.h" -#include "mace/ops/arm/fp32/common_neon.h" namespace mace { namespace ops { diff --git a/mace/ops/arm/fp32/deconv_2d_3x3.cc b/mace/ops/arm/fp32/deconv_2d_3x3.cc index 
4c00f07d..cce6f02d 100644 --- a/mace/ops/arm/fp32/deconv_2d_3x3.cc +++ b/mace/ops/arm/fp32/deconv_2d_3x3.cc @@ -14,8 +14,8 @@ #include +#include "mace/ops/arm/base/common_neon.h" #include "mace/ops/arm/base/deconv_2d_3x3.h" -#include "mace/ops/arm/fp32/common_neon.h" namespace mace { namespace ops { diff --git a/mace/ops/arm/fp32/deconv_2d_4x4.cc b/mace/ops/arm/fp32/deconv_2d_4x4.cc index 2dbe4d3e..434016af 100644 --- a/mace/ops/arm/fp32/deconv_2d_4x4.cc +++ b/mace/ops/arm/fp32/deconv_2d_4x4.cc @@ -14,8 +14,8 @@ #include +#include "mace/ops/arm/base/common_neon.h" #include "mace/ops/arm/base/deconv_2d_4x4.h" -#include "mace/ops/arm/fp32/common_neon.h" namespace mace { namespace ops { diff --git a/mace/ops/arm/fp32/depthwise_conv_2d_3x3.cc b/mace/ops/arm/fp32/depthwise_conv_2d_3x3.cc deleted file mode 100644 index fa850e56..00000000 --- a/mace/ops/arm/fp32/depthwise_conv_2d_3x3.cc +++ /dev/null @@ -1,428 +0,0 @@ -// Copyright 2019 The MACE Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include "mace/ops/arm/base/depthwise_conv_2d_3x3.h" - -namespace mace { -namespace ops { -namespace arm { - -namespace { -void DepthwiseConv2dPixel(const float *in_base, - const float *filter, - const index_t out_h, - const index_t out_w, - const index_t in_h_start, - const index_t in_w_start, - const index_t out_width, - const index_t in_height, - const index_t in_width, - int filter_height, - int filter_width, - float *out_base) { - float sum = 0; - for (int i = 0; i < filter_height; ++i) { - for (int j = 0; j < filter_width; ++j) { - index_t in_h = in_h_start + i; - index_t in_w = in_w_start + j; - if (in_h >= 0 && in_h < in_height && in_w >= 0 && in_w < in_width) { - sum += in_base[in_h * in_width + in_w] * filter[i * filter_width + j]; - } - } - } - out_base[out_h * out_width + out_w] = sum; -} -} // namespace - -template<> -MaceStatus DepthwiseConv2dK3x3S1::DoCompute( - const DepthwiseConvComputeParam &p, const float *filter_data, - const float *input_data, float *output_data) { - p.thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0, - index_t start1, index_t end1, index_t step1) { - for (index_t b = start0; b < end0; b += step0) { - for (index_t m = start1; m < end1; m += step1) { - const index_t c = m / p.multiplier; - const index_t multi_index = m % p.multiplier; - const float - *in_base = input_data + b * p.in_batch_size + c * p.in_image_size; - const float - *filter_ptr = filter_data + multi_index * p.in_channels * 9 + c * 9; - float *out_base = - output_data + b * p.out_batch_size + m * p.out_image_size; - index_t h, w; - - // top - for (h = 0; h < p.valid_h_start; ++h) { - for (w = 0; w < p.out_width; ++w) { - DepthwiseConv2dPixel(in_base, - filter_ptr, - h, - w, - h - p.pad_top, - w - p.pad_left, - p.out_width, - p.in_height, - p.in_width, - 3, - 3, - out_base); - } - } - - // load filter (1 outch x 3 height x 3 width): vf_outch_height - float32x4_t vf00, vf01, vf02; - vf00 = vld1q_f32(filter_ptr); - vf01 = 
vld1q_f32(filter_ptr + 3); - vf02 = vld1q_f32(filter_ptr + 5); - - for (h = p.valid_h_start; h + 1 < p.valid_h_stop; h += 2) { - // left - for (w = 0; w < p.valid_w_start; ++w) { - DepthwiseConv2dPixel(in_base, - filter_ptr, - h, - w, - h - p.pad_top, - w - p.pad_left, - p.out_width, - p.in_height, - p.in_width, - 3, - 3, - out_base); - DepthwiseConv2dPixel(in_base, - filter_ptr, - h + 1, - w, - h + 1 - p.pad_top, - w - p.pad_left, - p.out_width, - p.in_height, - p.in_width, - 3, - 3, - out_base); - } - - for (w = p.valid_w_start; w + 3 < p.valid_w_stop; w += 4) { - // input (4 height x 3 slide): vi_height_slide - float32x4_t vi00, vi01, vi02, vi0n; - float32x4_t vi10, vi11, vi12, vi1n; - float32x4_t vi20, vi21, vi22, vi2n; - float32x4_t vi30, vi31, vi32, vi3n; - - // output (1 outch x 2 height x 4 width): vo_outch_height - float32x4_t vo00, vo01; - - // load input - index_t in_h = h - p.pad_top; - index_t in_w = w - p.pad_left; - index_t in_offset = in_h * p.in_width + in_w; - vi00 = vld1q_f32(in_base + in_offset); - vi0n = vld1q_f32(in_base + in_offset + 4); - vi10 = vld1q_f32(in_base + in_offset + p.in_width); - vi1n = vld1q_f32(in_base + in_offset + p.in_width + 4); - vi20 = vld1q_f32(in_base + in_offset + 2 * p.in_width); - vi2n = vld1q_f32(in_base + in_offset + 2 * p.in_width + 4); - vi30 = vld1q_f32(in_base + in_offset + 3 * p.in_width); - vi3n = vld1q_f32(in_base + in_offset + 3 * p.in_width + 4); - - vi01 = vextq_f32(vi00, vi0n, 1); - vi02 = vextq_f32(vi00, vi0n, 2); - vi11 = vextq_f32(vi10, vi1n, 1); - vi12 = vextq_f32(vi10, vi1n, 2); - vi21 = vextq_f32(vi20, vi2n, 1); - vi22 = vextq_f32(vi20, vi2n, 2); - vi31 = vextq_f32(vi30, vi3n, 1); - vi32 = vextq_f32(vi30, vi3n, 2); - - // load ouptut - index_t out_offset = h * p.out_width + w; - vo00 = vld1q_f32(out_base + out_offset); - vo01 = vld1q_f32(out_base + out_offset + p.out_width); - -#if defined(__aarch64__) - // outch 0, height 0 - vo00 = vfmaq_laneq_f32(vo00, vi00, vf00, 0); - vo00 = vfmaq_laneq_f32(vo00, vi01, vf00, 1); - vo00 = vfmaq_laneq_f32(vo00, vi02, vf00, 2); - vo00 = vfmaq_laneq_f32(vo00, vi10, vf01, 0); - vo00 = vfmaq_laneq_f32(vo00, vi11, vf01, 1); - vo00 = vfmaq_laneq_f32(vo00, vi12, vf01, 2); - vo00 = vfmaq_laneq_f32(vo00, vi20, vf02, 1); - vo00 = vfmaq_laneq_f32(vo00, vi21, vf02, 2); - vo00 = vfmaq_laneq_f32(vo00, vi22, vf02, 3); - - // outch 0, height 1 - vo01 = vfmaq_laneq_f32(vo01, vi10, vf00, 0); - vo01 = vfmaq_laneq_f32(vo01, vi11, vf00, 1); - vo01 = vfmaq_laneq_f32(vo01, vi12, vf00, 2); - vo01 = vfmaq_laneq_f32(vo01, vi20, vf01, 0); - vo01 = vfmaq_laneq_f32(vo01, vi21, vf01, 1); - vo01 = vfmaq_laneq_f32(vo01, vi22, vf01, 2); - vo01 = vfmaq_laneq_f32(vo01, vi30, vf02, 1); - vo01 = vfmaq_laneq_f32(vo01, vi31, vf02, 2); - vo01 = vfmaq_laneq_f32(vo01, vi32, vf02, 3); -#else - // outch 0, height 0 - vo00 = vmlaq_lane_f32(vo00, vi00, vget_low_f32(vf00), 0); - vo00 = vmlaq_lane_f32(vo00, vi01, vget_low_f32(vf00), 1); - vo00 = vmlaq_lane_f32(vo00, vi02, vget_high_f32(vf00), 0); - vo00 = vmlaq_lane_f32(vo00, vi10, vget_low_f32(vf01), 0); - vo00 = vmlaq_lane_f32(vo00, vi11, vget_low_f32(vf01), 1); - vo00 = vmlaq_lane_f32(vo00, vi12, vget_high_f32(vf01), 0); - vo00 = vmlaq_lane_f32(vo00, vi20, vget_low_f32(vf02), 1); - vo00 = vmlaq_lane_f32(vo00, vi21, vget_high_f32(vf02), 0); - vo00 = vmlaq_lane_f32(vo00, vi22, vget_high_f32(vf02), 1); - - // outch 0, height 1 - vo01 = vmlaq_lane_f32(vo01, vi10, vget_low_f32(vf00), 0); - vo01 = vmlaq_lane_f32(vo01, vi11, vget_low_f32(vf00), 1); - vo01 = vmlaq_lane_f32(vo01, vi12, 
vget_high_f32(vf00), 0); - vo01 = vmlaq_lane_f32(vo01, vi20, vget_low_f32(vf01), 0); - vo01 = vmlaq_lane_f32(vo01, vi21, vget_low_f32(vf01), 1); - vo01 = vmlaq_lane_f32(vo01, vi22, vget_high_f32(vf01), 0); - vo01 = vmlaq_lane_f32(vo01, vi30, vget_low_f32(vf02), 1); - vo01 = vmlaq_lane_f32(vo01, vi31, vget_high_f32(vf02), 0); - vo01 = vmlaq_lane_f32(vo01, vi32, vget_high_f32(vf02), 1); -#endif - vst1q_f32(out_base + out_offset, vo00); - vst1q_f32(out_base + out_offset + p.out_width, vo01); - } // w - - // right - for (; w < p.out_width; ++w) { - DepthwiseConv2dPixel(in_base, - filter_ptr, - h, - w, - h - p.pad_top, - w - p.pad_left, - p.out_width, - p.in_height, - p.in_width, - 3, - 3, - out_base); - DepthwiseConv2dPixel(in_base, - filter_ptr, - h + 1, - w, - h + 1 - p.pad_top, - w - p.pad_left, - p.out_width, - p.in_height, - p.in_width, - 3, - 3, - out_base); - } - } // h - - - // bottom - for (; h < p.out_height; ++h) { - for (w = 0; w < p.out_width; ++w) { - DepthwiseConv2dPixel(in_base, - filter_ptr, - h, - w, - h - p.pad_top, - w - p.pad_left, - p.out_width, - p.in_height, - p.in_width, - 3, - 3, - out_base); - } - } - } // m - } // b - }, 0, p.batch, 1, 0, p.out_channels, 1); // threadpool - - return MaceStatus::MACE_SUCCESS; -} - -template<> -MaceStatus DepthwiseConv2dK3x3S2::DoCompute( - const DepthwiseConvComputeParam &p, const float *filter_data, - const float *input_data, float *output_data) { - p.thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0, - index_t start1, index_t end1, index_t step1) { - for (index_t b = start0; b < end0; b += step0) { - for (index_t m = start1; m < end1; m += step1) { - index_t c = m / p.multiplier; - index_t multi_index = m % p.multiplier; - const float - *in_base = input_data + b * p.in_batch_size + c * p.in_image_size; - const float - *filter_ptr = filter_data + multi_index * p.in_channels * 9 + c * 9; - float *out_base = - output_data + b * p.out_batch_size + m * p.out_image_size; - index_t h, w; - - // top - for (h = 0; h < p.valid_h_start; ++h) { - for (w = 0; w < p.out_width; ++w) { - DepthwiseConv2dPixel(in_base, - filter_ptr, - h, - w, - h * 2 - p.pad_top, - w * 2 - p.pad_left, - p.out_width, - p.in_height, - p.in_width, - 3, - 3, - out_base); - } - } - - // load filter (1 outch x 3 height x 3 width): vf_outch_height - float32x4_t vf00, vf01, vf02; - vf00 = vld1q_f32(filter_ptr); - vf01 = vld1q_f32(filter_ptr + 3); - vf02 = vld1q_f32(filter_ptr + 5); - - for (h = p.valid_h_start; h < p.valid_h_stop; ++h) { - // left - for (w = 0; w < p.valid_w_start; ++w) { - DepthwiseConv2dPixel(in_base, - filter_ptr, - h, - w, - h * 2 - p.pad_top, - w * 2 - p.pad_left, - p.out_width, - p.in_height, - p.in_width, - 3, - 3, - out_base); - } - - for (w = p.valid_w_start; w + 3 < p.valid_w_stop; w += 4) { - float32x4x2_t vi0, vi1, vi2; - float32x4_t vi0n, vi1n, vi2n; - - // input (3 height x 3 slide): vi_height_slide - float32x4_t vi00, vi01, vi02; - float32x4_t vi10, vi11, vi12; - float32x4_t vi20, vi21, vi22; - - // output (1 outch x 1 height x 4 width): vo - float32x4_t vo; - - // load input - index_t in_h = h * 2 - p.pad_top; - index_t in_w = w * 2 - p.pad_left; - index_t in_offset = in_h * p.in_width + in_w; - vi0 = vld2q_f32(in_base + in_offset); // [0.2.4.6, 1.3.5.7] - vi1 = vld2q_f32(in_base + in_offset + p.in_width); - vi2 = vld2q_f32(in_base + in_offset + 2 * p.in_width); - - vi0n = vld1q_f32(in_base + in_offset + 8); // [8.9.10.11] - vi1n = vld1q_f32(in_base + in_offset + p.in_width + 8); - vi2n = vld1q_f32(in_base + in_offset + 2 
* p.in_width + 8); - - // load ouptut - index_t out_offset = h * p.out_width + w; - vo = vld1q_f32(out_base + out_offset); - - vi00 = vi0.val[0]; // [0.2.4.6] - vi01 = vi0.val[1]; // [1.3.5.7] - vi02 = vextq_f32(vi00, vi0n, 1); // [2.4.6.8] - vi10 = vi1.val[0]; - vi11 = vi1.val[1]; - vi12 = vextq_f32(vi10, vi1n, 1); - vi20 = vi2.val[0]; - vi21 = vi2.val[1]; - vi22 = vextq_f32(vi20, vi2n, 1); - -#if defined(__aarch64__) - // outch 0, height 0 - vo = vfmaq_laneq_f32(vo, vi00, vf00, 0); - vo = vfmaq_laneq_f32(vo, vi01, vf00, 1); - vo = vfmaq_laneq_f32(vo, vi02, vf00, 2); - vo = vfmaq_laneq_f32(vo, vi10, vf01, 0); - vo = vfmaq_laneq_f32(vo, vi11, vf01, 1); - vo = vfmaq_laneq_f32(vo, vi12, vf01, 2); - vo = vfmaq_laneq_f32(vo, vi20, vf02, 1); - vo = vfmaq_laneq_f32(vo, vi21, vf02, 2); - vo = vfmaq_laneq_f32(vo, vi22, vf02, 3); -#else - // outch 0, height 0 - vo = vmlaq_lane_f32(vo, vi00, vget_low_f32(vf00), 0); - vo = vmlaq_lane_f32(vo, vi01, vget_low_f32(vf00), 1); - vo = vmlaq_lane_f32(vo, vi02, vget_high_f32(vf00), 0); - vo = vmlaq_lane_f32(vo, vi10, vget_low_f32(vf01), 0); - vo = vmlaq_lane_f32(vo, vi11, vget_low_f32(vf01), 1); - vo = vmlaq_lane_f32(vo, vi12, vget_high_f32(vf01), 0); - vo = vmlaq_lane_f32(vo, vi20, vget_low_f32(vf02), 1); - vo = vmlaq_lane_f32(vo, vi21, vget_high_f32(vf02), 0); - vo = vmlaq_lane_f32(vo, vi22, vget_high_f32(vf02), 1); -#endif - vst1q_f32(out_base + out_offset, vo); - } // w - - // right - for (; w < p.out_width; ++w) { - DepthwiseConv2dPixel(in_base, - filter_ptr, - h, - w, - h * 2 - p.pad_top, - w * 2 - p.pad_left, - p.out_width, - p.in_height, - p.in_width, - 3, - 3, - out_base); - } - } // h - - - // bottom - for (; h < p.out_height; ++h) { - for (w = 0; w < p.out_width; ++w) { - DepthwiseConv2dPixel(in_base, - filter_ptr, - h, - w, - h * 2 - p.pad_top, - w * 2 - p.pad_left, - p.out_width, - p.in_height, - p.in_width, - 3, - 3, - out_base); - } - } - } // m - } // b - }, 0, p.batch, 1, 0, p.out_channels, 1); - - return MaceStatus::MACE_SUCCESS; -} - -} // namespace arm -} // namespace ops -} // namespace mace diff --git a/mace/ops/arm/fp32/depthwise_deconv_2d_3x3.cc b/mace/ops/arm/fp32/depthwise_deconv_2d_3x3.cc index 99e9c9eb..71b7934f 100644 --- a/mace/ops/arm/fp32/depthwise_deconv_2d_3x3.cc +++ b/mace/ops/arm/fp32/depthwise_deconv_2d_3x3.cc @@ -14,8 +14,8 @@ #include +#include "mace/ops/arm/base/common_neon.h" #include "mace/ops/arm/base/depthwise_deconv_2d_3x3.h" -#include "mace/ops/arm/fp32/common_neon.h" namespace mace { namespace ops { diff --git a/mace/ops/arm/fp32/depthwise_deconv_2d_4x4.cc b/mace/ops/arm/fp32/depthwise_deconv_2d_4x4.cc index 529b728f..e7e98afa 100644 --- a/mace/ops/arm/fp32/depthwise_deconv_2d_4x4.cc +++ b/mace/ops/arm/fp32/depthwise_deconv_2d_4x4.cc @@ -14,8 +14,8 @@ #include +#include "mace/ops/arm/base/common_neon.h" #include "mace/ops/arm/base/depthwise_deconv_2d_4x4.h" -#include "mace/ops/arm/fp32/common_neon.h" namespace mace { namespace ops { diff --git a/mace/ops/arm/fp32/gemm.cc b/mace/ops/arm/fp32/gemm.cc index 123e3aae..4444c306 100644 --- a/mace/ops/arm/fp32/gemm.cc +++ b/mace/ops/arm/fp32/gemm.cc @@ -23,510 +23,6 @@ namespace mace { namespace ops { namespace arm { -template<> -template<> -void Gemm::Pack<4, 4>(const MatrixMap &matrix, - MatrixMajor dst_major, - float *packed_matrix) { - const index_t rows = matrix.rows(); - const index_t cols = matrix.cols(); - - // use the same terminology as GemmLowp: - // depth is depth, width is the opposite dim other than depth - // lhs - index_t width = rows; - index_t depth = 
cols; - index_t width_stride = matrix.rows_stride(); - index_t depth_stride = matrix.cols_stride(); - if (dst_major == RowMajor) { - // rhs - std::swap(width, depth); - std::swap(width_stride, depth_stride); - } - const float *data = matrix.data(); - float *packed_ptr = packed_matrix; - - const index_t block_size = 4; - const index_t depth_padded = RoundUp(depth, static_cast(4)); - - if (depth_padded > depth) { - memset(packed_ptr + depth * block_size, - 0, - sizeof(float) * (depth_padded - depth) * block_size); - } - - if (dst_major == matrix.matrix_major()) { - if (width < block_size) { - const index_t width_remain = block_size - width; - for (index_t d = 0; d < depth; ++d) { - memcpy(packed_ptr, data, sizeof(float) * width); - memset(packed_ptr + width, 0, sizeof(float) * width_remain); - data += depth_stride; - packed_ptr += block_size; - } - } else { - for (index_t d = 0; d < depth; ++d) { - float32x4_t vi = vld1q_f32(data); - vst1q_f32(packed_ptr, vi); - data += depth_stride; - packed_ptr += block_size; - } - } - } else { - if (width < block_size) { - const index_t width_remain = block_size - width; - for (index_t d = 0; d < depth; ++d) { - for (index_t w = 0; w < width; ++w) { - packed_ptr[w] = data[w * width_stride + d]; - } // w - memset(packed_ptr + width, 0, sizeof(float) * width_remain); - packed_ptr += block_size; - } // d - } else { - const float *data0 = data; - const float *data1 = data + width_stride; - const float *data2 = data1 + width_stride; - const float *data3 = data2 + width_stride; - - const index_t depth_block = depth / 4; - const index_t depth_remain = depth - depth_block * 4; - for (index_t depth_block_idx = 0; depth_block_idx < depth_block; - ++depth_block_idx) { - float32x4_t v0 = vld1q_f32(data0); - float32x4_t v1 = vld1q_f32(data1); - float32x4_t v2 = vld1q_f32(data2); - float32x4_t v3 = vld1q_f32(data3); - float32x4x2_t v02_intertwined = vzipq_f32(v0, v2); - float32x4x2_t v13_intertwined = vzipq_f32(v1, v3); - float32x4x2_t v0123_intertwined = - vzipq_f32(v02_intertwined.val[0], v13_intertwined.val[0]); - float32x4x2_t v0123n_intertwined = - vzipq_f32(v02_intertwined.val[1], v13_intertwined.val[1]); - - vst1q_f32(packed_ptr, v0123_intertwined.val[0]); - packed_ptr += 4; - - vst1q_f32(packed_ptr, v0123_intertwined.val[1]); - packed_ptr += 4; - - vst1q_f32(packed_ptr, v0123n_intertwined.val[0]); - packed_ptr += 4; - - vst1q_f32(packed_ptr, v0123n_intertwined.val[1]); - packed_ptr += 4; - - data0 += 4; - data1 += 4; - data2 += 4; - data3 += 4; - } - for (index_t d = 0; d < depth_remain; ++d) { - float32x4_t vi = {*data0, *data1, *data2, *data3}; - vst1q_f32(packed_ptr, vi); - packed_ptr += 4; - - ++data0; - ++data1; - ++data2; - ++data3; - } // d - } - } -} - -template<> -template<> -void Gemm::Pack<8, 4>(const MatrixMap &matrix, - MatrixMajor dst_major, - float *packed_matrix) { - const index_t rows = matrix.rows(); - const index_t cols = matrix.cols(); - - // use the same terminology as GemmLowp: - // depth is depth, width is the opposite dim other than depth - // lhs - index_t width = rows; - index_t depth = cols; - index_t width_stride = matrix.rows_stride(); - index_t depth_stride = matrix.cols_stride(); - if (dst_major == RowMajor) { - // rhs - std::swap(width, depth); - std::swap(width_stride, depth_stride); - } - const float *data = matrix.data(); - float *packed_ptr = packed_matrix; - - const index_t block_size = 8; - const index_t depth_padded = RoundUp(depth, static_cast(4)); - - if (depth_padded > depth) { - memset(packed_ptr + depth * block_size, 
- 0, - sizeof(float) * (depth_padded - depth) * block_size); - } - - if (dst_major == matrix.matrix_major()) { - if (width < block_size) { - const index_t width_remain = block_size - width; - for (index_t d = 0; d < depth; ++d) { - memcpy(packed_ptr, data, sizeof(float) * width); - memset(packed_ptr + width, 0, sizeof(float) * width_remain); - data += depth_stride; - packed_ptr += block_size; - } - } else { - for (index_t d = 0; d < depth; ++d) { - float32x4_t vi = vld1q_f32(data); - vst1q_f32(packed_ptr, vi); - float32x4_t vin = vld1q_f32(data + 4); - vst1q_f32(packed_ptr + 4, vin); - data += depth_stride; - packed_ptr += block_size; - } - } - } else { - if (width < block_size) { - const index_t width_remain = block_size - width; - for (index_t d = 0; d < depth; ++d) { - for (index_t w = 0; w < width; ++w) { - packed_ptr[w] = data[w * width_stride + d]; - } // w - memset(packed_ptr + width, 0, sizeof(float) * width_remain); - packed_ptr += block_size; - } // d - } else { - const float *data0 = data; - const float *data1 = data + width_stride; - const float *data2 = data1 + width_stride; - const float *data3 = data2 + width_stride; - const float *data4 = data3 + width_stride; - const float *data5 = data4 + width_stride; - const float *data6 = data5 + width_stride; - const float *data7 = data6 + width_stride; - - const index_t depth_block = depth / 4; - const index_t depth_remain = depth - depth_block * 4; - for (index_t depth_block_idx = 0; depth_block_idx < depth_block; - ++depth_block_idx) { - float32x4_t v0 = vld1q_f32(data0); - float32x4_t v1 = vld1q_f32(data1); - float32x4_t v2 = vld1q_f32(data2); - float32x4_t v3 = vld1q_f32(data3); - float32x4x2_t v02_intertwined = vzipq_f32(v0, v2); - float32x4x2_t v13_intertwined = vzipq_f32(v1, v3); - float32x4x2_t v0123_intertwined = - vzipq_f32(v02_intertwined.val[0], v13_intertwined.val[0]); - float32x4x2_t v0123n_intertwined = - vzipq_f32(v02_intertwined.val[1], v13_intertwined.val[1]); - - float32x4_t v4 = vld1q_f32(data4); - float32x4_t v5 = vld1q_f32(data5); - float32x4_t v6 = vld1q_f32(data6); - float32x4_t v7 = vld1q_f32(data7); - float32x4x2_t v46_intertwined = vzipq_f32(v4, v6); - float32x4x2_t v57_intertwined = vzipq_f32(v5, v7); - float32x4x2_t v4567_intertwined = - vzipq_f32(v46_intertwined.val[0], v57_intertwined.val[0]); - float32x4x2_t v4567n_intertwined = - vzipq_f32(v46_intertwined.val[1], v57_intertwined.val[1]); - - vst1q_f32(packed_ptr, v0123_intertwined.val[0]); - packed_ptr += 4; - - vst1q_f32(packed_ptr, v4567_intertwined.val[0]); - packed_ptr += 4; - - vst1q_f32(packed_ptr, v0123_intertwined.val[1]); - packed_ptr += 4; - - vst1q_f32(packed_ptr, v4567_intertwined.val[1]); - packed_ptr += 4; - - vst1q_f32(packed_ptr, v0123n_intertwined.val[0]); - packed_ptr += 4; - - vst1q_f32(packed_ptr, v4567n_intertwined.val[0]); - packed_ptr += 4; - - vst1q_f32(packed_ptr, v0123n_intertwined.val[1]); - packed_ptr += 4; - - vst1q_f32(packed_ptr, v4567n_intertwined.val[1]); - packed_ptr += 4; - - data0 += 4; - data1 += 4; - data2 += 4; - data3 += 4; - data4 += 4; - data5 += 4; - data6 += 4; - data7 += 4; - } - for (index_t d = 0; d < depth_remain; ++d) { - float32x4_t vi = {*data0, *data1, *data2, *data3}; - vst1q_f32(packed_ptr, vi); - packed_ptr += 4; - - float32x4_t vin = {*data4, *data5, *data6, *data7}; - vst1q_f32(packed_ptr, vin); - packed_ptr += 4; - - ++data0; - ++data1; - ++data2; - ++data3; - ++data4; - ++data5; - ++data6; - ++data7; - } // d - } - } -} - -template<> -template<> -void Gemm::Unpack<4, 8>(const float 
*packed_output, - MatrixMap *output) { - const index_t rows = output->rows(); - const index_t cols = output->cols(); - index_t row_stride = output->rows_stride(); - index_t col_stride = output->cols_stride(); - - float *output_ptr = output->data(); - const float *packed_ptr = packed_output; - - const index_t block_size = 8; - - // packed_output always has row-major - if (output->matrix_major() == RowMajor) { - if (cols < block_size) { - for (index_t r = 0; r < rows; ++r) { - memcpy(output_ptr, packed_ptr, sizeof(float) * cols); - output_ptr += row_stride; - packed_ptr += block_size; - } - } else { - for (index_t r = 0; r < rows; ++r) { - float32x4_t vi = vld1q_f32(packed_ptr); - vst1q_f32(output_ptr, vi); - float32x4_t vin = vld1q_f32(packed_ptr + 4); - vst1q_f32(output_ptr + 4, vin); - - output_ptr += row_stride; - packed_ptr += block_size; - } - } - } else { - // ColMajor - if (rows < block_size) { - for (index_t c = 0; c < cols; ++c) { - for (index_t r = 0; r < rows; ++r) { - output_ptr[c * col_stride + r] = packed_ptr[r * block_size + c]; - } // r - } // c - } else { - const float *data0 = packed_ptr; - const float *data1 = data0 + block_size; - const float *data2 = data1 + block_size; - const float *data3 = data2 + block_size; - - index_t col_block = cols / 4; - index_t col_remain = cols - col_block * 4; - for (index_t col_block_idx = 0; col_block_idx < col_block; - ++col_block_idx) { - float32x4_t v0 = vld1q_f32(data0); - float32x4_t v1 = vld1q_f32(data1); - float32x4_t v2 = vld1q_f32(data2); - float32x4_t v3 = vld1q_f32(data3); - float32x4x2_t v02_intertwined = vzipq_f32(v0, v2); - float32x4x2_t v13_intertwined = vzipq_f32(v1, v3); - float32x4x2_t v0123_intertwined = - vzipq_f32(v02_intertwined.val[0], v13_intertwined.val[0]); - float32x4x2_t v0123n_intertwined = - vzipq_f32(v02_intertwined.val[1], v13_intertwined.val[1]); - - vst1q_f32(output_ptr, v0123_intertwined.val[0]); - output_ptr += col_stride; - - vst1q_f32(output_ptr, v0123_intertwined.val[1]); - output_ptr += col_stride; - - vst1q_f32(output_ptr, v0123n_intertwined.val[0]); - output_ptr += col_stride; - - vst1q_f32(output_ptr, v0123n_intertwined.val[1]); - output_ptr += col_stride; - - data0 += 4; - data1 += 4; - data2 += 4; - data3 += 4; - } - for (index_t c = 0; c < col_remain; ++c) { - float32x4_t vi = {*data0, *data1, *data2, *data3}; - vst1q_f32(output_ptr, vi); - output_ptr += col_stride; - - ++data0; - ++data1; - ++data2; - ++data3; - } // d - } - } -} - -template<> -template<> -void Gemm::Unpack<8, 8>(const float *packed_output, - MatrixMap *output) { - const index_t rows = output->rows(); - const index_t cols = output->cols(); - index_t row_stride = output->rows_stride(); - index_t col_stride = output->cols_stride(); - - float *output_ptr = output->data(); - const float *packed_ptr = packed_output; - - const index_t block_size = 8; - - // packed_output always has row-major - if (output->matrix_major() == RowMajor) { - if (cols < block_size) { - for (index_t r = 0; r < rows; ++r) { - memcpy(output_ptr, packed_ptr, sizeof(float) * cols); - output_ptr += row_stride; - packed_ptr += block_size; - } - } else { - for (index_t r = 0; r < rows; ++r) { - float32x4_t vi = vld1q_f32(packed_ptr); - vst1q_f32(output_ptr, vi); - float32x4_t vin = vld1q_f32(packed_ptr + 4); - vst1q_f32(output_ptr + 4, vin); - - output_ptr += row_stride; - packed_ptr += block_size; - } - } - } else { - // ColMajor - if (rows < block_size) { - for (index_t c = 0; c < cols; ++c) { - for (index_t r = 0; r < rows; ++r) { - output_ptr[c * col_stride 
+ r] = packed_ptr[r * block_size + c]; - } // r - } // c - } else { - const float *data0 = packed_ptr; - const float *data1 = data0 + block_size; - const float *data2 = data1 + block_size; - const float *data3 = data2 + block_size; - const float *data4 = data3 + block_size; - const float *data5 = data4 + block_size; - const float *data6 = data5 + block_size; - const float *data7 = data6 + block_size; - - index_t col_block = cols / 4; - index_t col_remain = cols - col_block * 4; - for (index_t col_block_idx = 0; col_block_idx < col_block; - ++col_block_idx) { - float32x4_t v0 = vld1q_f32(data0); - float32x4_t v1 = vld1q_f32(data1); - float32x4_t v2 = vld1q_f32(data2); - float32x4_t v3 = vld1q_f32(data3); - float32x4x2_t v02_intertwined = vzipq_f32(v0, v2); - float32x4x2_t v13_intertwined = vzipq_f32(v1, v3); - float32x4x2_t v0123_intertwined = - vzipq_f32(v02_intertwined.val[0], v13_intertwined.val[0]); - float32x4x2_t v0123n_intertwined = - vzipq_f32(v02_intertwined.val[1], v13_intertwined.val[1]); - - float32x4_t v4 = vld1q_f32(data4); - float32x4_t v5 = vld1q_f32(data5); - float32x4_t v6 = vld1q_f32(data6); - float32x4_t v7 = vld1q_f32(data7); - float32x4x2_t v46_intertwined = vzipq_f32(v4, v6); - float32x4x2_t v57_intertwined = vzipq_f32(v5, v7); - float32x4x2_t v4567_intertwined = - vzipq_f32(v46_intertwined.val[0], v57_intertwined.val[0]); - float32x4x2_t v4567n_intertwined = - vzipq_f32(v46_intertwined.val[1], v57_intertwined.val[1]); - - vst1q_f32(output_ptr, v0123_intertwined.val[0]); - vst1q_f32(output_ptr + 4, v4567_intertwined.val[0]); - output_ptr += col_stride; - - vst1q_f32(output_ptr, v0123_intertwined.val[1]); - vst1q_f32(output_ptr + 4, v4567_intertwined.val[1]); - output_ptr += col_stride; - - vst1q_f32(output_ptr, v0123n_intertwined.val[0]); - vst1q_f32(output_ptr + 4, v4567n_intertwined.val[0]); - output_ptr += col_stride; - - vst1q_f32(output_ptr, v0123n_intertwined.val[1]); - vst1q_f32(output_ptr + 4, v4567n_intertwined.val[1]); - output_ptr += col_stride; - - data0 += 4; - data1 += 4; - data2 += 4; - data3 += 4; - data4 += 4; - data5 += 4; - data6 += 4; - data7 += 4; - } - for (index_t c = 0; c < col_remain; ++c) { - float32x4_t vi = {*data0, *data1, *data2, *data3}; - vst1q_f32(output_ptr, vi); - float32x4_t vin = {*data4, *data5, *data6, *data7}; - vst1q_f32(output_ptr + 4, vin); - output_ptr += col_stride; - - ++data0; - ++data1; - ++data2; - ++data3; - ++data4; - ++data5; - ++data6; - ++data7; - } // d - } - } -} - -template<> -void Gemm::PackLhs(const MatrixMap &lhs, - float *packed_lhs) { -#ifdef __aarch64__ - Pack<8, 4>(lhs, ColMajor, packed_lhs); -#else - Pack<4, 4>(lhs, ColMajor, packed_lhs); -#endif -} - -template<> -void Gemm::PackRhs(const MatrixMap &rhs, - float *packed_rhs) { - Pack<8, 4>(rhs, RowMajor, packed_rhs); -} - -template<> -void Gemm::UnpackOutput(const float *packed_output, - MatrixMap *output) { -#ifdef __aarch64__ - Unpack<8, 8>(packed_output, output); -#else - Unpack<4, 8>(packed_output, output); -#endif -} - template<> void Gemm::ComputeBlock(const float *packed_lhs_data, const float *packed_rhs_data, @@ -1008,190 +504,6 @@ void Gemm::ComputeBlock(const float *packed_lhs_data, #endif } -template<> -MaceStatus Gemm::Compute(const OpContext *context, - const Tensor *lhs, - const Tensor *rhs, - const index_t batch, - const index_t rows, - const index_t cols, - const index_t depth, - const MatrixMajor lhs_major, - const MatrixMajor rhs_major, - const MatrixMajor output_major, - const bool lhs_batched, - const bool rhs_batched, - Tensor 
*output) { - MACE_CHECK(output->size() == batch * rows * cols, - "Need resize output tensor before call gemm."); - Tensor::MappingGuard lhs_guard(lhs); - Tensor::MappingGuard rhs_guard(rhs); - Tensor::MappingGuard output_guard(output); - const float *lhs_data = lhs->data(); - const float *rhs_data = rhs->data(); - float *output_data = output->mutable_data(); - -#ifdef __aarch64__ - const index_t row_block_size = 8; -#else - const index_t row_block_size = 4; -#endif - const index_t col_block_size = 8; - const index_t depth_block_size = 4; - const index_t row_block_count = RoundUpDiv(rows, row_block_size); - const index_t col_block_count = RoundUpDiv(cols, col_block_size); - const index_t rows_padded = RoundUp(rows, row_block_size); - const index_t cols_padded = RoundUp(cols, col_block_size); - const index_t depth_padded = RoundUp(depth, depth_block_size); - - ScratchBuffer *scratch = context->device()->scratch_buffer(); - - index_t packed_lhs_size = - PadAlignSize(sizeof(float) * rows_padded * depth_padded); - index_t packed_rhs_size = - PadAlignSize(sizeof(float) * depth_padded * cols_padded); - index_t packed_output_size = - PadAlignSize(sizeof(float) * rows_padded * cols_padded); - // resize to the total size of lhs & rhs & output anyway, - // in case we do not cache const tensor for saving memory - MACE_RETURN_IF_ERROR(scratch->GrowSize( - packed_lhs_size + packed_rhs_size + packed_output_size)); - float *packed_lhs_data = - scratch->Scratch(packed_lhs_size).mutable_data(); - float *packed_rhs_data = - scratch->Scratch(packed_rhs_size).mutable_data(); - float *packed_output_data = - scratch->Scratch(packed_output_size).mutable_data(); - - int cache_side = kNoCache; - if (cached_ == kCacheLhs) { - packed_lhs_data = pack_cache_.mutable_data(); - } else if (cached_ == kCacheRhs) { - packed_rhs_data = pack_cache_.mutable_data(); - } else if (should_cache_pack_) { - if (lhs->is_weight() && (!lhs_batched || batch == 1)) { - cache_side = kCacheLhs; - pack_cache_.Resize(packed_lhs_size); - packed_lhs_data = pack_cache_.mutable_data(); - } else if (rhs->is_weight() && (!rhs_batched || batch == 1)) { - cache_side = kCacheRhs; - pack_cache_.Resize(packed_rhs_size); - packed_rhs_data = pack_cache_.mutable_data(); - } - } - - utils::ThreadPool - &thread_pool = context->device()->cpu_runtime()->thread_pool(); - - for (index_t b = 0; b < batch; ++b) { - MatrixMap - lhs_matrix - (lhs_data + static_cast(lhs_batched) * b * rows * depth, - lhs_major, - rows, - depth); - MatrixMap - rhs_matrix - (rhs_data + static_cast(rhs_batched) * b * depth * cols, - rhs_major, - depth, - cols); - MatrixMap output_matrix - (output_data + b * rows * cols, output_major, rows, cols); - - // pack lhs - if (cached_ != kCacheLhs) { - thread_pool.Compute1D([=, &lhs_matrix](index_t start, - index_t end, - index_t step) { - for (index_t row_block_idx = start; row_block_idx < end; - row_block_idx += step) { - const index_t start_row = row_block_idx * row_block_size; - const index_t - row_block_len = std::min(row_block_size, rows - start_row); - float *packed_lhs_data_block = - packed_lhs_data + row_block_idx * row_block_size * depth_padded; - PackLhs(lhs_matrix.block(start_row, 0, row_block_len, depth), - packed_lhs_data_block); - } - }, 0, row_block_count, 1); - - if (cache_side == kCacheLhs) { - cached_ = kCacheLhs; - if (lhs->UnderlyingBuffer()->OnHost()) { - AdviseFree(reinterpret_cast(const_cast(lhs->data< - float>())), - lhs->raw_size()); - } - } - } - - // pack rhs - if (cached_ != kCacheRhs) { - thread_pool.Compute1D([=, 
&rhs_matrix](index_t start, - index_t end, - index_t step) { - for (index_t col_block_idx = start; col_block_idx < end; - col_block_idx += step) { - const index_t start_col = col_block_idx * col_block_size; - const index_t - col_block_len = std::min(col_block_size, cols - start_col); - float *packed_rhs_data_block = - packed_rhs_data + col_block_idx * col_block_size * depth_padded; - PackRhs(rhs_matrix.block(0, start_col, depth, col_block_len), - packed_rhs_data_block); - } - }, 0, col_block_count, 1); - - if (cache_side == kCacheRhs) { - cached_ = kCacheRhs; - if (rhs->UnderlyingBuffer()->OnHost()) { - AdviseFree(reinterpret_cast(const_cast(rhs->data< - float>())), - rhs->raw_size()); - } - } - } - - // multiply lhs and rhs - thread_pool.Compute1D([=, &output_matrix](index_t start, - index_t end, - index_t step) { - for (index_t row_block_idx = start; row_block_idx < end; - row_block_idx += step) { - const index_t start_row = row_block_idx * row_block_size; - const index_t - row_block_len = std::min(row_block_size, rows - start_row); - const float *packed_lhs_data_block = - packed_lhs_data + row_block_idx * row_block_size * depth_padded; - - for (index_t col_block_idx = 0; col_block_idx < col_block_count; - ++col_block_idx) { - const index_t start_col = col_block_idx * col_block_size; - const index_t - col_block_len = std::min(col_block_size, cols - start_col); - const float *packed_rhs_data_block = - packed_rhs_data + col_block_idx * col_block_size * depth_padded; - float *packed_output_data_block = - packed_output_data + row_block_idx * row_block_size * cols_padded - + col_block_idx * col_block_size; - ComputeBlock(packed_lhs_data_block, - packed_rhs_data_block, - depth_padded, - packed_output_data_block); - MatrixMap output_block = output_matrix.block(start_row, - start_col, - row_block_len, - col_block_len); - UnpackOutput(packed_output_data_block, &output_block); - } // col_block_idx - } // row_block_idx - }, 0, row_block_count, 1); - } // b - - return MaceStatus::MACE_SUCCESS; -} - } // namespace arm } // namespace ops } // namespace mace diff --git a/mace/ops/registry/op_delegators_registry.cc b/mace/ops/registry/op_delegators_registry.cc index 9ef615d5..b7c86133 100644 --- a/mace/ops/registry/op_delegators_registry.cc +++ b/mace/ops/registry/op_delegators_registry.cc @@ -37,9 +37,7 @@ extern void RegisterGemvDelegator(OpDelegatorRegistry *registry); #ifdef MACE_ENABLE_NEON namespace arm { -namespace fp32 { extern void RegisterConv2dK3x3WinogradDelegator(OpDelegatorRegistry *registry); -} // namespace fp32 extern void RegisterActivationDelegator(OpDelegatorRegistry *registry); extern void RegisterBiasAddDelegator(OpDelegatorRegistry *registry); @@ -98,7 +96,7 @@ void RegisterAllOpDelegators(OpDelegatorRegistry *registry) { #endif // MACE_ENABLE_QUANTIZE #ifdef MACE_ENABLE_NEON - arm::fp32::RegisterConv2dK3x3WinogradDelegator(registry); + arm::RegisterConv2dK3x3WinogradDelegator(registry); arm::RegisterActivationDelegator(registry); arm::RegisterBiasAddDelegator(registry); diff --git a/test/ccbenchmark/BUILD.bazel b/test/ccbenchmark/BUILD.bazel index c4581a8b..ea5357ca 100644 --- a/test/ccbenchmark/BUILD.bazel +++ b/test/ccbenchmark/BUILD.bazel @@ -10,6 +10,7 @@ load( "if_android_armv7", "if_hexagon_enabled", "if_neon_enabled", + "if_bfloat16_enabled", "if_opencl_enabled", "if_quantize_enabled", ) @@ -58,6 +59,8 @@ cc_test( "-DMACE_ENABLE_OPENCL", ]) + if_quantize_enabled([ "-DMACE_ENABLE_QUANTIZE", + ]) + if_bfloat16_enabled([ + "-DMACE_ENABLE_BFLOAT16", ]) + if_hexagon_enabled([ 
"-DMACE_ENABLE_HEXAGON", ]), diff --git a/test/ccbenchmark/mace/ops/activation_benchmark.cc b/test/ccbenchmark/mace/ops/activation_benchmark.cc index ee92d352..a991e6d0 100644 --- a/test/ccbenchmark/mace/ops/activation_benchmark.cc +++ b/test/ccbenchmark/mace/ops/activation_benchmark.cc @@ -67,15 +67,24 @@ void ReluBenchmark(int iters, int batch, int channels, int height, int width) { } \ MACE_BENCHMARK(MACE_BM_RELU_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) +#ifdef MACE_ENABLE_BFLOAT16 +#define MACE_BM_RELU_BF16_MACRO(N, C, H, W) \ + MACE_BM_RELU_MACRO(N, C, H, W, BFloat16, CPU) +#else +#define MACE_BM_RELU_BF16_MACRO(N, C, H, W) +#endif // MACE_ENABLE_BFLOAT16 #ifdef MACE_ENABLE_OPENCL -#define MACE_BM_RELU(N, C, H, W) \ - MACE_BM_RELU_MACRO(N, C, H, W, float, CPU); \ - MACE_BM_RELU_MACRO(N, C, H, W, float, GPU); \ +#define MACE_BM_RELU_GPU_MACRO(N, C, H, W) \ + MACE_BM_RELU_MACRO(N, C, H, W, float, GPU); \ MACE_BM_RELU_MACRO(N, C, H, W, half, GPU) #else -#define MACE_BM_RELU(N, C, H, W) \ - MACE_BM_RELU_MACRO(N, C, H, W, float, CPU) -#endif +#define MACE_BM_RELU_GPU_MACRO(N, C, H, W) +#endif // MACE_ENABLE_OPENCL + +#define MACE_BM_RELU(N, C, H, W) \ + MACE_BM_RELU_MACRO(N, C, H, W, float, CPU); \ + MACE_BM_RELU_BF16_MACRO(N, C, H, W); \ + MACE_BM_RELU_GPU_MACRO(N, C, H, W) MACE_BM_RELU(1, 1, 512, 512); MACE_BM_RELU(1, 3, 128, 128); @@ -128,15 +137,24 @@ void ReluxBenchmark(int iters, int batch, int channels, int height, int width) { } \ MACE_BENCHMARK(MACE_BM_RELUX_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) +#ifdef MACE_ENABLE_BFLOAT16 +#define MACE_BM_RELUX_BF16_MACRO(N, C, H, W) \ + MACE_BM_RELUX_MACRO(N, C, H, W, BFloat16, CPU) +#else +#define MACE_BM_RELUX_BF16_MACRO(N, C, H, W) +#endif // MACE_ENABLE_BFLOAT16 #ifdef MACE_ENABLE_OPENCL -#define MACE_BM_RELUX(N, C, H, W) \ - MACE_BM_RELUX_MACRO(N, C, H, W, float, CPU); \ - MACE_BM_RELUX_MACRO(N, C, H, W, float, GPU); \ +#define MACE_BM_RELUX_GPU_MACRO(N, C, H, W) \ + MACE_BM_RELUX_MACRO(N, C, H, W, float, GPU); \ MACE_BM_RELUX_MACRO(N, C, H, W, half, GPU) #else -#define MACE_BM_RELUX(N, C, H, W) \ - MACE_BM_RELUX_MACRO(N, C, H, W, float, CPU) -#endif +#define MACE_BM_RELUX_GPU_MACRO(N, C, H, W) +#endif // MACE_ENABLE_OPENCL + +#define MACE_BM_RELUX(N, C, H, W) \ + MACE_BM_RELUX_MACRO(N, C, H, W, float, CPU); \ + MACE_BM_RELUX_BF16_MACRO(N, C, H, W); \ + MACE_BM_RELUX_GPU_MACRO(N, C, H, W) MACE_BM_RELUX(1, 1, 512, 512); MACE_BM_RELUX(1, 3, 128, 128); @@ -192,15 +210,24 @@ void PreluBenchmark(int iters, int batch, int channels, int height, int width) { } \ MACE_BENCHMARK(MACE_BM_PRELU_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) +#ifdef MACE_ENABLE_BFLOAT16 +#define MACE_BM_PRELU_BF16_MACRO(N, C, H, W) \ + MACE_BM_PRELU_MACRO(N, C, H, W, BFloat16, CPU) +#else +#define MACE_BM_PRELU_BF16_MACRO(N, C, H, W) +#endif // MACE_ENABLE_BFLOAT16 #ifdef MACE_ENABLE_OPENCL -#define MACE_BM_PRELU(N, C, H, W) \ - MACE_BM_PRELU_MACRO(N, C, H, W, float, CPU); \ - MACE_BM_PRELU_MACRO(N, C, H, W, float, GPU); \ +#define MACE_BM_PRELU_GPU_MACRO(N, C, H, W) \ + MACE_BM_PRELU_MACRO(N, C, H, W, float, GPU); \ MACE_BM_PRELU_MACRO(N, C, H, W, half, GPU) #else -#define MACE_BM_PRELU(N, C, H, W) \ - MACE_BM_PRELU_MACRO(N, C, H, W, float, CPU) -#endif +#define MACE_BM_PRELU_GPU_MACRO(N, C, H, W) +#endif // MACE_ENABLE_OPENCL + +#define MACE_BM_PRELU(N, C, H, W) \ + MACE_BM_PRELU_MACRO(N, C, H, W, float, CPU); \ + MACE_BM_PRELU_BF16_MACRO(N, C, H, W); \ + MACE_BM_PRELU_GPU_MACRO(N, C, H, W) MACE_BM_PRELU(1, 1, 512, 512); MACE_BM_PRELU(1, 3, 128, 128); @@ -316,15 
+343,24 @@ void TanhBenchmark(int iters, int batch, int channels, int height, int width) { } \ MACE_BENCHMARK(MACE_BM_TANH_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) +#ifdef MACE_ENABLE_BFLOAT16 +#define MACE_BM_TANH_BF16_MACRO(N, C, H, W) \ + MACE_BM_TANH_MACRO(N, C, H, W, BFloat16, CPU) +#else +#define MACE_BM_TANH_BF16_MACRO(N, C, H, W) +#endif // MACE_ENABLE_BFLOAT16 #ifdef MACE_ENABLE_OPENCL -#define MACE_BM_TANH(N, C, H, W) \ - MACE_BM_TANH_MACRO(N, C, H, W, float, CPU); \ - MACE_BM_TANH_MACRO(N, C, H, W, float, GPU); \ +#define MACE_BM_TANH_GPU_MACRO(N, C, H, W) \ + MACE_BM_TANH_MACRO(N, C, H, W, float, GPU); \ MACE_BM_TANH_MACRO(N, C, H, W, half, GPU) #else -#define MACE_BM_TANH(N, C, H, W) \ - MACE_BM_TANH_MACRO(N, C, H, W, float, CPU) -#endif +#define MACE_BM_TANH_GPU_MACRO(N, C, H, W) +#endif // MACE_ENABLE_OPENCL + +#define MACE_BM_TANH(N, C, H, W) \ + MACE_BM_TANH_MACRO(N, C, H, W, float, CPU); \ + MACE_BM_TANH_BF16_MACRO(N, C, H, W); \ + MACE_BM_TANH_GPU_MACRO(N, C, H, W) MACE_BM_TANH(1, 1, 512, 512); MACE_BM_TANH(1, 3, 128, 128); @@ -377,15 +413,24 @@ void SigmoidBenchmark( } \ MACE_BENCHMARK(MACE_BM_SIGMOID_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) +#ifdef MACE_ENABLE_BFLOAT16 +#define MACE_BM_SIGMOID_BF16_MACRO(N, C, H, W) \ + MACE_BM_SIGMOID_MACRO(N, C, H, W, BFloat16, CPU) +#else +#define MACE_BM_SIGMOID_BF16_MACRO(N, C, H, W) +#endif // MACE_ENABLE_BFLOAT16 #ifdef MACE_ENABLE_OPENCL -#define MACE_BM_SIGMOID(N, C, H, W) \ - MACE_BM_SIGMOID_MACRO(N, C, H, W, float, CPU); \ +#define MACE_BM_SIGMOID_GPU_MACRO(N, C, H, W) \ MACE_BM_SIGMOID_MACRO(N, C, H, W, float, GPU); \ MACE_BM_SIGMOID_MACRO(N, C, H, W, half, GPU) #else +#define MACE_BM_SIGMOID_GPU_MACRO(N, C, H, W) +#endif // MACE_ENABLE_OPENCL + #define MACE_BM_SIGMOID(N, C, H, W) \ - MACE_BM_SIGMOID_MACRO(N, C, H, W, float, CPU) -#endif + MACE_BM_SIGMOID_MACRO(N, C, H, W, float, CPU); \ + MACE_BM_SIGMOID_BF16_MACRO(N, C, H, W); \ + MACE_BM_SIGMOID_GPU_MACRO(N, C, H, W) MACE_BM_SIGMOID(1, 1, 512, 512); MACE_BM_SIGMOID(1, 3, 128, 128); diff --git a/test/ccbenchmark/mace/ops/bias_add_benchmark.cc b/test/ccbenchmark/mace/ops/bias_add_benchmark.cc index 477fc04a..4a7fb2b4 100644 --- a/test/ccbenchmark/mace/ops/bias_add_benchmark.cc +++ b/test/ccbenchmark/mace/ops/bias_add_benchmark.cc @@ -68,24 +68,31 @@ void BiasAdd(int iters, int batch, int channels, int height, int width) { } \ MACE_BENCHMARK(MACE_BM_BIAS_ADD_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) -#if defined(MACE_ENABLE_OPENCL) && defined(MACE_ENABLE_QUANTIZE) -#define MACE_BM_BIAS_ADD(N, C, H, W) \ - MACE_BM_BIAS_ADD_MACRO(N, C, H, W, float, CPU); \ - MACE_BM_BIAS_ADD_MACRO(N, C, H, W, uint8_t, CPU); \ +#ifdef MACE_ENABLE_QUANTIZE +#define MACE_BM_BIAS_ADD_Q8_MACRO(N, C, H, W) \ + MACE_BM_BIAS_ADD_MACRO(N, C, H, W, uint8_t, CPU) +#else +#define MACE_BM_BIAS_ADD_Q8_MACRO(N, C, H, W) +#endif // MACE_ENABLE_QUANTIZE +#ifdef MACE_ENABLE_BFLOAT16 +#define MACE_BM_BIAS_ADD_BF16_MACRO(N, C, H, W) \ + MACE_BM_BIAS_ADD_MACRO(N, C, H, W, BFloat16, CPU) +#else +#define MACE_BM_BIAS_ADD_BF16_MACRO(N, C, H, W) +#endif // MACE_ENABLE_BFLOAT16 +#ifdef MACE_ENABLE_OPENCL +#define MACE_BM_BIAS_ADD_GPU_MACRO(N, C, H, W) \ MACE_BM_BIAS_ADD_MACRO(N, C, H, W, float, GPU); \ - MACE_BM_BIAS_ADD_MACRO(N, C, H, W, half, GPU); -#elif defined(MACE_ENABLE_OPENCL) -#define MACE_BM_BIAS_ADD(N, C, H, W) \ - MACE_BM_BIAS_ADD_MACRO(N, C, H, W, float, CPU); \ - MACE_BM_BIAS_ADD_MACRO(N, C, H, W, float, GPU); \ - MACE_BM_BIAS_ADD_MACRO(N, C, H, W, half, GPU); -#elif defined(MACE_ENABLE_QUANTIZE) + 
MACE_BM_BIAS_ADD_MACRO(N, C, H, W, half, GPU) +#else +#define MACE_BM_BIAS_ADD_GPU_MACRO(N, C, H, W) +#endif // MACE_ENABLE_OPENCL + #define MACE_BM_BIAS_ADD(N, C, H, W) \ MACE_BM_BIAS_ADD_MACRO(N, C, H, W, float, CPU); \ - MACE_BM_BIAS_ADD_MACRO(N, C, H, W, uint8_t, CPU); -#define MACE_BM_BIAS_ADD(N, C, H, W) \ - MACE_BM_BIAS_ADD_MACRO(N, C, H, W, float, CPU); -#endif + MACE_BM_BIAS_ADD_Q8_MACRO(N, C, H, W); \ + MACE_BM_BIAS_ADD_BF16_MACRO(N, C, H, W); \ + MACE_BM_BIAS_ADD_GPU_MACRO(N, C, H, W) MACE_BM_BIAS_ADD(1, 1, 512, 512); MACE_BM_BIAS_ADD(1, 3, 128, 128); diff --git a/test/ccbenchmark/mace/ops/conv_2d_benchmark.cc b/test/ccbenchmark/mace/ops/conv_2d_benchmark.cc index 3de9a756..84ad04d3 100644 --- a/test/ccbenchmark/mace/ops/conv_2d_benchmark.cc +++ b/test/ccbenchmark/mace/ops/conv_2d_benchmark.cc @@ -42,7 +42,7 @@ void Conv2d(int iters, // Add input data if (D == DeviceType::CPU) { - net.AddRandomInput("Input", {batch, channels, height, width}); + net.AddRandomInput("Input", {batch, channels, height, width}); } else if (D == DeviceType::GPU) { net.AddRandomInput("Input", {batch, height, width, channels}); } else { @@ -169,26 +169,31 @@ void Conv2d(int iters, MACE_BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##D##\ DILATION##_##P##_##OC##_##TYPE##_##DEVICE) -#if defined(MACE_ENABLE_OPENCL) && defined(MACE_ENABLE_QUANTIZE) -#define MACE_BM_CONV_2D(N, C, H, W, KH, KW, S, D, P, OC) \ - MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, CPU); \ - MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, GPU); \ - MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, half, GPU); \ +#ifdef MACE_ENABLE_QUANTIZE +#define MACE_BM_CONV_2D_Q8_MACRO(N, C, H, W, KH, KW, S, D, P, OC) \ MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, uint8_t, CPU) -#elif defined(MACE_ENABLE_OPENCL) -#define MACE_BM_CONV_2D(N, C, H, W, KH, KW, S, D, P, OC) \ - MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, CPU); \ +#else +#define MACE_BM_CONV_2D_Q8_MACRO(N, C, H, W, KH, KW, S, D, P, OC) +#endif // MACE_ENABLE_QUANTIZE +#ifdef MACE_ENABLE_BFLOAT16 +#define MACE_BM_CONV_2D_BF16_MACRO(N, C, H, W, KH, KW, S, D, P, OC) \ + MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, BFloat16, CPU) +#else +#define MACE_BM_CONV_2D_BF16_MACRO(N, C, H, W, KH, KW, S, D, P, OC) +#endif // MACE_ENABLE_BFLOAT16 +#ifdef MACE_ENABLE_OPENCL +#define MACE_BM_CONV_2D_GPU_MACRO(N, C, H, W, KH, KW, S, D, P, OC) \ MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, GPU); \ MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, half, GPU) -#elif defined(MACE_ENABLE_QUANTIZE) -#define MACE_BM_CONV_2D(N, C, H, W, KH, KW, S, D, P, OC) \ - MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, CPU); \ - MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, uint8_t, CPU) #else -#define MACE_BM_CONV_2D(N, C, H, W, KH, KW, S, D, P, OC) \ - MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, CPU) -#endif +#define MACE_BM_CONV_2D_GPU_MACRO(N, C, H, W, KH, KW, S, D, P, OC) +#endif // MACE_ENABLE_OPENCL +#define MACE_BM_CONV_2D(N, C, H, W, KH, KW, S, D, P, OC) \ + MACE_BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, CPU); \ + MACE_BM_CONV_2D_Q8_MACRO(N, C, H, W, KH, KW, S, D, P, OC); \ + MACE_BM_CONV_2D_BF16_MACRO(N, C, H, W, KH, KW, S, D, P, OC); \ + MACE_BM_CONV_2D_GPU_MACRO(N, C, H, W, KH, KW, S, D, P, OC) // Filter sizes and data alignments MACE_BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, 1, VALID, 128); diff --git a/test/ccbenchmark/mace/ops/depthwise_conv2d_benchmark.cc 
index 7a72d6e2..3abf5067 100644
--- a/test/ccbenchmark/mace/ops/depthwise_conv2d_benchmark.cc
+++ b/test/ccbenchmark/mace/ops/depthwise_conv2d_benchmark.cc
@@ -128,25 +128,31 @@ void DepthwiseConv2d(int iters,
       MACE_BM_DEPTHWISE_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE\
 ##_##P##_##M##_##TYPE##_##DEVICE)

-#if defined(MACE_ENABLE_OPENCL) && defined(MACE_ENABLE_QUANTIZE)
-#define MACE_BM_DEPTHWISE_CONV_2D(N, C, H, W, KH, KW, S, P, M)              \
-  MACE_BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float, CPU); \
-  MACE_BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float, GPU); \
-  MACE_BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, half, GPU);  \
+#ifdef MACE_ENABLE_QUANTIZE
+#define MACE_BM_DEPTHWISE_CONV_2D_Q8_MACRO(N, C, H, W, KH, KW, S, P, M) \
   MACE_BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, uint8_t, CPU)
-#elif defined(MACE_ENABLE_OPENCL)
-#define MACE_BM_DEPTHWISE_CONV_2D(N, C, H, W, KH, KW, S, P, M)              \
-  MACE_BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float, CPU); \
+#else
+#define MACE_BM_DEPTHWISE_CONV_2D_Q8_MACRO(N, C, H, W, KH, KW, S, P, M)
+#endif  // MACE_ENABLE_QUANTIZE
+#ifdef MACE_ENABLE_BFLOAT16
+#define MACE_BM_DEPTHWISE_CONV_2D_BF16_MACRO(N, C, H, W, KH, KW, S, P, M) \
+  MACE_BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, BFloat16, CPU)
+#else
+#define MACE_BM_DEPTHWISE_CONV_2D_BF16_MACRO(N, C, H, W, KH, KW, S, P, M)
+#endif  // MACE_ENABLE_BFLOAT16
+#ifdef MACE_ENABLE_OPENCL
+#define MACE_BM_DEPTHWISE_CONV_2D_GPU_MACRO(N, C, H, W, KH, KW, S, P, M) \
   MACE_BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float, GPU); \
   MACE_BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, half, GPU)
-#elif defined(MACE_ENABLE_QUANTIZE)
-#define MACE_BM_DEPTHWISE_CONV_2D(N, C, H, W, KH, KW, S, P, M)              \
-  MACE_BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float, CPU); \
-  MACE_BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, uint8_t, CPU)
 #else
+#define MACE_BM_DEPTHWISE_CONV_2D_GPU_MACRO(N, C, H, W, KH, KW, S, P, M)
+#endif  // MACE_ENABLE_OPENCL
+
 #define MACE_BM_DEPTHWISE_CONV_2D(N, C, H, W, KH, KW, S, P, M)              \
-  MACE_BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float, CPU)
-#endif
+  MACE_BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float, CPU); \
+  MACE_BM_DEPTHWISE_CONV_2D_Q8_MACRO(N, C, H, W, KH, KW, S, P, M);          \
+  MACE_BM_DEPTHWISE_CONV_2D_BF16_MACRO(N, C, H, W, KH, KW, S, P, M);        \
+  MACE_BM_DEPTHWISE_CONV_2D_GPU_MACRO(N, C, H, W, KH, KW, S, P, M)

 MACE_BM_DEPTHWISE_CONV_2D(1, 32, 112, 112, 3, 3, 1, SAME, 1);
 MACE_BM_DEPTHWISE_CONV_2D(1, 32, 56, 56, 3, 3, 2, VALID, 1);
diff --git a/test/ccunit/BUILD.bazel b/test/ccunit/BUILD.bazel
index 75a6ce78..50d12d4e 100644
--- a/test/ccunit/BUILD.bazel
+++ b/test/ccunit/BUILD.bazel
@@ -11,6 +11,7 @@ load(
     "if_hexagon_enabled",
     "if_hta_enabled",
     "if_neon_enabled",
+    "if_bfloat16_enabled",
     "if_opencl_enabled",
     "if_quantize_enabled",
 )
@@ -37,13 +38,19 @@ cc_test(
             "mace/ops/arm/q8/*.cc",
             "mace/ops/fixpoint_test.cc",
         ]
+    )) + if_bfloat16_enabled(glob(
+        [
+            "mace/ops/arm/bf16/*.cc",
+        ]
     )) + if_opencl_enabled(glob(
         [
             "mace/ops/opencl/*.cc",
         ]
-    )) + if_hta_enabled([
-        "mace/core/runtime/hexagon/hta_transform_test.cc",
-    ]),
+    )) + if_hta_enabled(
+        [
+            "mace/core/runtime/hexagon/hta_transform_test.cc",
+        ]
+    ),
     copts = [
         "-Werror",
         "-Wextra",
@@ -57,6 +64,8 @@ cc_test(
         "-DMACE_ENABLE_OPENCL",
     ]) + if_quantize_enabled([
         "-DMACE_ENABLE_QUANTIZE",
+    ]) + if_bfloat16_enabled([
"-DMACE_ENABLE_BFLOAT16", ]) + if_hexagon_enabled([ "-DMACE_ENABLE_HEXAGON", ]) + if_hta_enabled([ diff --git a/test/ccunit/mace/ops/activation_test.cc b/test/ccunit/mace/ops/activation_test.cc index dfa978e7..f9b8bdb0 100644 --- a/test/ccunit/mace/ops/activation_test.cc +++ b/test/ccunit/mace/ops/activation_test.cc @@ -417,6 +417,61 @@ TEST_F(ActivationOpTest, Quantized) { TestQuantized(37, "RELUX"); } +#ifdef MACE_ENABLE_BFLOAT16 +namespace { +void TestBFloat16(const char *activation) { + OpsTestNet net; + + static unsigned int seed = time(NULL); + index_t batch = 3 + (rand_r(&seed) % 10); + index_t channels = 3 + (rand_r(&seed) % 10); + index_t height = 3 + (rand_r(&seed) % 10); + index_t width = 3 + (rand_r(&seed) % 10); + + // Add input data + net.AddRandomInput("Input", {batch, channels, height, width}); + net.AddRandomInput("Alpha", {channels}, true); + net.Cast("Input", "BF16Input"); + net.Cast("Alpha", "BF16Alpha"); + + OpDefBuilder("Activation", "ActivationTest") + .Input("Input") + .Input("Alpha") + .Output("Output") + .AddStringArg("activation", activation) + .AddFloatArg("leakyrelu_coefficient", 0.1) + .AddFloatArg("max_limit", 6) + .AddIntArg("T", static_cast(DT_FLOAT)) + .Finalize(net.NewOperatorDef()); + net.RunOp(CPU); + + OpDefBuilder("Activation", "BF16ActivationTest") + .Input("BF16Input") + .Input("BF16Alpha") + .Output("BF16Output") + .AddStringArg("activation", activation) + .AddFloatArg("leakyrelu_coefficient", 0.1) + .AddFloatArg("max_limit", 6) + .AddIntArg("T", static_cast(DT_BFLOAT16)) + .Finalize(net.NewOperatorDef()); + net.RunOp(CPU); + + net.Cast("BF16Output", "CastOutput"); + + ExpectTensorSimilar(*net.GetOutput("Output"), + *net.GetTensor("CastOutput"), 1e-5); +} +} // namespace + +TEST_F(ActivationOpTest, BFloat16) { + TestBFloat16("RELU"); + TestBFloat16("LEAKYRELU"); + TestBFloat16("RELUX"); + TestBFloat16("PRELU"); + TestBFloat16("TANH"); + TestBFloat16("SIGMOID"); +} +#endif // MACE_ENABLE_BFLOAT16 } // namespace test } // namespace ops } // namespace mace diff --git a/test/ccunit/mace/ops/arm/bf16/gemm_test.cc b/test/ccunit/mace/ops/arm/bf16/gemm_test.cc new file mode 100644 index 00000000..eb6d6bb2 --- /dev/null +++ b/test/ccunit/mace/ops/arm/bf16/gemm_test.cc @@ -0,0 +1,106 @@ +// Copyright 2020 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mace/ops/delegator/gemm.h" + +#include + +#include "mace/core/ops/op_context.h" +#include "mace/core/tensor.h" +#include "mace/ops/ops_test_util.h" +#include "mace/ops/testing/test_utils.h" + +namespace mace { +namespace ops { +namespace test { + +void TestGemmBFloat16(const index_t batch, + const index_t rows, + const index_t cols, + const index_t depth, + const MatrixMajor lhs_major, + const MatrixMajor rhs_major, + const MatrixMajor output_major, + const bool lhs_batched, + const bool rhs_batched) { + Tensor lhs(GetCPUAllocator(), DT_BFLOAT16); + Tensor rhs(GetCPUAllocator(), DT_BFLOAT16); + Tensor output(GetCPUAllocator(), DT_BFLOAT16); + lhs.Resize({lhs_batched ? batch : 1, rows, depth}); + rhs.Resize({rhs_batched ? batch : 1, depth, cols}); + output.Resize({batch, rows, cols}); + { + Tensor::MappingGuard lhs_guard(&lhs); + Tensor::MappingGuard rhs_guard(&rhs); + auto lhs_data = lhs.mutable_data(); + auto rhs_data = rhs.mutable_data(); + auto output_data = output.mutable_data(); + GenerateRandomRealTypeData(lhs.shape(), lhs_data); + GenerateRandomRealTypeData(rhs.shape(), rhs_data); + GenerateRandomRealTypeData(output.shape(), output_data); + } + + utils::ThreadPool thread_pool(1, AFFINITY_NONE); + thread_pool.Init(); + CPUDevice cpu_device(1, AFFINITY_NONE, &thread_pool); + OpsTestNet net; + OpContext context(net.ws(), &cpu_device); + std::unique_ptr gemm = delegator::Gemm::Create( + context.workspace(), + MACE_DELEGATOR_KEY(Gemm, DeviceType::CPU, BFloat16, ImplType::NEON), + delegator::GemmParam()); + gemm->Compute(&context, &lhs, &rhs, batch, rows, cols, depth, lhs_major, + rhs_major, output_major, lhs_batched, rhs_batched, &output); + + Tensor expected_output(GetCPUAllocator(), DataType::DT_BFLOAT16); + expected_output.Resize({batch, rows, cols}); + std::unique_ptr gemm_ref = delegator::Gemm::Create( + context.workspace(), + MACE_DELEGATOR_KEY(Gemm, DeviceType::CPU, BFloat16, ImplType::REF), + delegator::GemmParam()); + gemm_ref->Compute(&context, &lhs, &rhs, batch, rows, cols, depth, lhs_major, + rhs_major, output_major, lhs_batched, rhs_batched, + &expected_output); + + ExpectTensorSimilar(expected_output, output, 1e-4); +} + +TEST(ArmGemm, TestGemmBF16) { + TestGemmBFloat16(1, 47, 69, 37, RowMajor, RowMajor, RowMajor, true, true); + TestGemmBFloat16(1, 47, 69, 37, RowMajor, RowMajor, ColMajor, true, true); + TestGemmBFloat16(1, 47, 69, 37, RowMajor, ColMajor, RowMajor, true, true); + TestGemmBFloat16(1, 47, 69, 37, RowMajor, ColMajor, ColMajor, true, true); + TestGemmBFloat16(1, 47, 69, 37, ColMajor, RowMajor, RowMajor, true, true); + TestGemmBFloat16(1, 47, 69, 37, ColMajor, RowMajor, ColMajor, true, true); + TestGemmBFloat16(1, 47, 69, 37, ColMajor, ColMajor, RowMajor, true, true); + TestGemmBFloat16(1, 47, 69, 37, ColMajor, ColMajor, ColMajor, true, true); + + TestGemmBFloat16(3, 47, 69, 37, RowMajor, RowMajor, RowMajor, true, true); + TestGemmBFloat16(3, 47, 69, 37, RowMajor, RowMajor, ColMajor, true, true); + TestGemmBFloat16(3, 47, 69, 37, RowMajor, ColMajor, RowMajor, true, true); + TestGemmBFloat16(3, 47, 69, 37, RowMajor, ColMajor, ColMajor, true, true); + TestGemmBFloat16(3, 47, 69, 37, ColMajor, RowMajor, RowMajor, true, true); + TestGemmBFloat16(3, 47, 69, 37, ColMajor, RowMajor, ColMajor, true, true); + TestGemmBFloat16(3, 47, 69, 37, ColMajor, ColMajor, RowMajor, true, true); + TestGemmBFloat16(3, 47, 69, 37, ColMajor, ColMajor, ColMajor, true, true); + + TestGemmBFloat16(3, 47, 69, 37, RowMajor, RowMajor, RowMajor, true, false); + TestGemmBFloat16(3, 
+
+  TestGemmBFloat16(16, 31, 61, 67, RowMajor, ColMajor, RowMajor, true, true);
+}
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git a/test/ccunit/mace/ops/bias_add_test.cc b/test/ccunit/mace/ops/bias_add_test.cc
index a1d37e00..45ebed95 100644
--- a/test/ccunit/mace/ops/bias_add_test.cc
+++ b/test/ccunit/mace/ops/bias_add_test.cc
@@ -309,6 +309,50 @@ TEST_F(BiasAddOpTest, Quantized) {
   TestQuantized(true, true);
 }

+#ifdef MACE_ENABLE_BFLOAT16
+TEST_F(BiasAddOpTest, BFloat16) {
+  // generate random input
+  static unsigned int seed = time(NULL);
+  index_t batch = 1 + rand_r(&seed) % 10;
+  index_t channels = 3 + rand_r(&seed) % 50;
+  index_t height = 103 + rand_r(&seed) % 100;
+  index_t width = 113 + rand_r(&seed) % 100;
+
+  OpsTestNet net;
+
+  // Add input data
+  net.AddRandomInput("Input",
+                     {batch, channels, height, width});
+  net.AddRandomInput("Bias", {channels}, true);
+
+  net.Cast("Input", "BF16Input");
+  net.Cast("Bias", "BF16Bias");
+
+  // Construct graph
+  OpDefBuilder("BiasAdd", "BiasAddTest")
+      .Input("Input")
+      .Input("Bias")
+      .AddIntArg("has_data_format", 1)
+      .Output("Output")
+      .AddIntArg("T", static_cast(DT_FLOAT))
+      .Finalize(net.NewOperatorDef());
+  net.RunOp();
+
+  OpDefBuilder("BiasAdd", "BF16BiasAddTest")
+      .Input("BF16Input")
+      .Input("BF16Bias")
+      .AddIntArg("has_data_format", 1)
+      .Output("BF16Output")
+      .AddIntArg("T", static_cast(DT_BFLOAT16))
+      .Finalize(net.NewOperatorDef());
+  net.RunOp();
+
+  net.Cast("BF16Output", "CastOutput");
+
+  ExpectTensorSimilar(*net.GetOutput("Output"),
+                      *net.GetTensor("CastOutput"), 1e-5);
+}
+#endif  // MACE_ENABLE_BFLOAT16

 }  // namespace test
 }  // namespace ops
diff --git a/test/ccunit/mace/ops/conv_2d_test.cc b/test/ccunit/mace/ops/conv_2d_test.cc
index f9823265..8205c590 100644
--- a/test/ccunit/mace/ops/conv_2d_test.cc
+++ b/test/ccunit/mace/ops/conv_2d_test.cc
@@ -1367,6 +1367,74 @@ TEST_F(Conv2dOpTest, Quant) {
   TestQuant(1, 128, 64, 32, 32, 7, 7, SAME, {3, 3});
 }

+#ifdef MACE_ENABLE_BFLOAT16
+namespace {
+void TestBFloat16(const index_t batch,
+                  const index_t out_channels,
+                  const index_t in_channels,
+                  const index_t in_height,
+                  const index_t in_width,
+                  const index_t k_height,
+                  const index_t k_width,
+                  enum Padding padding_type,
+                  const std::vector &strides) {
+  OpsTestNet net;
+  net.AddRandomInput("Input",
+                     {batch, in_channels, in_height, in_width});
+  net.AddRandomInput(
+      "Filter", {out_channels, in_channels, k_height, k_width}, true);
+  net.AddRandomInput("Bias", {out_channels}, true);
+  net.Cast("Input", "BF16Input");
+  net.Cast("Filter", "BF16Filter");
+  net.Cast("Bias", "BF16Bias");
+
+  OpDefBuilder("Conv2D", "Conv2dTest")
+      .Input("Input")
+      .Input("Filter")
+      .Input("Bias")
+      .Output("Output")
+      .AddIntsArg("strides", strides)
+      .AddIntArg("padding", padding_type)
+      .AddIntsArg("dilations", {1, 1})
+      .AddIntArg("T", static_cast(DT_FLOAT))
+      .Finalize(net.NewOperatorDef());
+  net.RunOp(CPU);
+
+  OpDefBuilder("Conv2D", "BF16Conv2dTest")
+      .Input("BF16Input")
+      .Input("BF16Filter")
+      .Input("BF16Bias")
+      .Output("BF16Output")
+      .AddIntsArg("strides", strides)
+      .AddIntArg("padding", padding_type)
+      .AddIntsArg("dilations", {1, 1})
+      .AddIntArg("T", static_cast(DT_BFLOAT16))
+      .Finalize(net.NewOperatorDef());
+  net.RunOp(CPU);
+
+  net.Cast("BF16Output", "CastOutput");
+
+  ExpectTensorSimilar(*net.GetOutput("Output"),
+                      *net.GetTensor("CastOutput"), 1e-4);
+}
+}  // namespace
+
+TEST_F(Conv2dOpTest, BFloat16) {
+  TestBFloat16(1, 128, 64, 32, 32, 1, 1, VALID, {1, 1});
+  TestBFloat16(1, 128, 64, 32, 32, 3, 3, VALID, {1, 1});
+  TestBFloat16(1, 128, 64, 32, 32, 3, 3, SAME, {1, 1});
+  TestBFloat16(1, 128, 64, 32, 32, 3, 3, FULL, {1, 1});
+  TestBFloat16(1, 128, 64, 32, 32, 3, 3, SAME, {2, 2});
+  TestBFloat16(1, 129, 63, 33, 31, 3, 3, SAME, {1, 1});
+  TestBFloat16(9, 128, 64, 32, 32, 3, 3, SAME, {1, 1});
+  TestBFloat16(1, 128, 64, 32, 32, 1, 5, SAME, {1, 1});
+  TestBFloat16(1, 128, 64, 32, 32, 5, 5, SAME, {1, 1});
+  TestBFloat16(1, 128, 64, 32, 32, 5, 1, SAME, {1, 1});
+  TestBFloat16(1, 128, 64, 32, 32, 7, 7, SAME, {1, 1});
+  TestBFloat16(1, 128, 64, 32, 32, 7, 7, SAME, {2, 2});
+  TestBFloat16(1, 128, 64, 32, 32, 7, 7, SAME, {3, 3});
+}
+#endif  // MACE_ENABLE_BFLOAT16
 }  // namespace test
 }  // namespace ops
 }  // namespace mace
diff --git a/test/ccunit/mace/ops/depthwise_conv2d_test.cc b/test/ccunit/mace/ops/depthwise_conv2d_test.cc
index a91d7961..ae1bac04 100644
--- a/test/ccunit/mace/ops/depthwise_conv2d_test.cc
+++ b/test/ccunit/mace/ops/depthwise_conv2d_test.cc
@@ -492,6 +492,71 @@ TEST_F(DepthwiseConv2dOpTest, Quant) {
   TestQuant(3, 1, 128, 56, 56, 3, 3, SAME, {2, 2});
 }

+namespace {
+void TestBFloat16(const index_t batch,
+                  const index_t multiplier,
+                  const index_t in_channels,
+                  const index_t in_height,
+                  const index_t in_width,
+                  const index_t k_height,
+                  const index_t k_width,
+                  enum Padding padding_type,
+                  const std::vector &strides) {
+  OpsTestNet net;
+  const index_t out_channels = multiplier * in_channels;
+  net.AddRandomInput(
+      "Input", {batch, in_channels, in_height, in_width}, false, false);
+  net.AddRandomInput(
+      "Filter", {multiplier, in_channels, k_height, k_width}, true, false);
+  net.AddRandomInput("Bias", {out_channels}, true);
+  net.Cast("Input", "BF16Input");
+  net.Cast("Filter", "BF16Filter");
+  net.Cast("Bias", "BF16Bias");
+
+  OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
+      .Input("Input")
+      .Input("Filter")
+      .Input("Bias")
+      .Output("Output")
+      .AddIntsArg("strides", strides)
+      .AddIntArg("padding", padding_type)
+      .AddIntsArg("dilations", {1, 1})
+      .AddIntArg("T", static_cast(DT_FLOAT))
+      .Finalize(net.NewOperatorDef());
+  net.RunOp(CPU);
+
+  OpDefBuilder("DepthwiseConv2d", "BF16DepthwiseConv2DTest")
+      .Input("BF16Input")
+      .Input("BF16Filter")
+      .Input("BF16Bias")
+      .Output("BF16Output")
+      .AddIntsArg("strides", strides)
+      .AddIntArg("padding", padding_type)
+      .AddIntsArg("dilations", {1, 1})
+      .AddIntArg("T", static_cast(DT_BFLOAT16))
+      .Finalize(net.NewOperatorDef());
+  net.RunOp(CPU);
+
+  net.Cast("BF16Output", "CastOutput");
+
+  ExpectTensorSimilar(*net.GetOutput("Output"),
+                      *net.GetTensor("CastOutput"), 1e-4);
+}
+}  // namespace
+
+TEST_F(DepthwiseConv2dOpTest, BFloat16) {
+  TestBFloat16(1, 1, 1024, 7, 7, 3, 3, VALID, {1, 1});
+  TestBFloat16(1, 1, 1024, 7, 7, 3, 3, SAME, {1, 1});
+  TestBFloat16(1, 1, 1024, 7, 7, 3, 3, FULL, {1, 1});
+  TestBFloat16(1, 2, 1024, 7, 7, 3, 3, SAME, {1, 1});
+  TestBFloat16(1, 2, 1024, 7, 7, 3, 3, SAME, {2, 2});
+  TestBFloat16(1, 1, 512, 14, 14, 3, 3, SAME, {1, 1});
+  TestBFloat16(1, 1, 512, 14, 13, 5, 5, SAME, {2, 2});
+  TestBFloat16(1, 1, 256, 28, 28, 3, 3, SAME, {1, 1});
+  TestBFloat16(1, 1, 128, 56, 56, 3, 3, SAME, {2, 2});
+  TestBFloat16(3, 1, 128, 56, 56, 3, 3, SAME, {2, 2});
+}
+
 }  // namespace test
 }  // namespace ops
 }  // namespace mace
diff --git a/test/ccutils/mace/ops/ops_test_util.h b/test/ccutils/mace/ops/ops_test_util.h
index e1e56342..94972ce0 100644
--- a/test/ccutils/mace/ops/ops_test_util.h
+++ b/test/ccutils/mace/ops/ops_test_util.h
@@ -377,6 +377,22 @@ class OpsTestNet {
     }
   }

+  template
+  void Cast(const std::string &src_name, const std::string &dst_name) {
+    Tensor *input = ws_.GetTensor(src_name);
+    Tensor *output = ws_.CreateTensor(
+        dst_name, OpTestContext::Get()->GetDevice(D)->allocator(),
+        DataTypeToEnum::v(), input->is_weight());
+    output->Resize(input->shape());
+    Tensor::MappingGuard input_guard(input);
+    Tensor::MappingGuard output_guard(output);
+    auto input_data = input->data();
+    auto output_data = output->mutable_data();
+    for (index_t i = 0; i < input->size(); ++i) {
+      output_data[i] = input_data[i];
+    }
+  }
+
   // Create standalone tensor on device D with T type.
   template
   std::unique_ptr CreateTensor(
diff --git a/tools/bazel_adb_run.py b/tools/bazel_adb_run.py
index a0bed249..7526314b 100644
--- a/tools/bazel_adb_run.py
+++ b/tools/bazel_adb_run.py
@@ -95,6 +95,11 @@ def parse_args():
         type=str2bool,
         default=True,
         help="Whether to use quantization ops")
+    parser.add_argument(
+        "--enable_bfloat16",
+        type=str2bool,
+        default=True,
+        help="Whether to use bfloat16")
     parser.add_argument(
         "--enable_rpcmem",
         type=str2bool,
@@ -174,6 +179,7 @@ def main(unused_args):
         toolchain=toolchain,
         enable_neon=FLAGS.enable_neon,
         enable_quantize=FLAGS.enable_quantize,
+        enable_bfloat16=FLAGS.enable_bfloat16,
         enable_rpcmem=FLAGS.enable_rpcmem,
         enable_hta=FLAGS.enable_hta,
         address_sanitizer=FLAGS.address_sanitizer,
diff --git a/tools/bazel_build_standalone_lib.sh b/tools/bazel_build_standalone_lib.sh
index 6e92af24..c1a67a36 100755
--- a/tools/bazel_build_standalone_lib.sh
+++ b/tools/bazel_build_standalone_lib.sh
@@ -16,23 +16,23 @@ declare -r NORMAL=$(tput sgr0)
 helper() {
   echo -e "usage:\t$0 ["${BOLD}"--abi"${NORMAL}"=abi]\
 ["${BOLD}"--runtimes"${NORMAL}"=rt1,rt2,...]["${BOLD}"--static"${NORMAL}"]"
-  
+
   echo -e "\t"${BOLD}"--abi:"${NORMAL}" specifies the targeted ABI, supported \
 ABIs are:\n\t\tarmeabi-v7a, arm64-v8a, arm_linux_gnueabihf, aarch64_linux_gnu \
 or \n\t\thost if the library is built for the host machine (linux-x86-64).\n\t\
 \tThe default ABI is arm64-v8a."
-  
+
   echo -e "\t"${BOLD}"--runtimes:"${NORMAL}" specifies the runtimes, supported \
 runtimes are:\n\t\tcpu, gpu, dsp, apu, hta. By default, the library is built to\
 run on CPU."
-  
+
   echo -e "\t"${BOLD}"--static:"${NORMAL}" option to generate the corresponding\
 static library.\n\t\tIf the option is omitted, a shared library is built."
-  
+
   exit 0
 }

-# configuration variables
+# default configuration variables
 abi=arm64-v8a
 enable_neon=true
 enable_hta=false
@@ -41,6 +41,7 @@ enable_gpu=false
 enable_dsp=false
 enable_apu=false
 enable_quantize=true
+enable_bfloat16=true
 enable_rpcmem=true
 static_lib=false
 symbol_hidden=
@@ -139,7 +140,8 @@ case "${abi}" in
       --config optimization mace/libmace:libmace_"${lib_type}"\
       --define neon="${enable_neon}" \
      --define opencl="${enable_gpu}" \
-      --define quantize="${enable_quantize}"
+      --define quantize="${enable_quantize}" \
+      --define bfloat16="${enable_bfloat16}"
     ;;
   linux-x86-64)
     bazel build mace/libmace:libmace_"${lib_type}" --config linux --config \
@@ -152,7 +154,8 @@ case "${abi}" in
       --define neon="${enable_neon}" --define hta="${enable_hta}" \
       --define opencl="${enable_gpu}" --define apu="${enable_apu}" \
      --define hexagon="${enable_dsp}" --define quantize="${enable_quantize}" \
-      --cpu="${abi}" --define rpcmem="${enable_rpcmem}"
+      --define rpcmem="${enable_rpcmem}" --define bfloat16="${enable_bfloat16}" \
+      --cpu="${abi}"
     if [[ "${enable_dsp}" == true ]];then
       cp third_party/nnlib/"${abi}"/libhexagon_controller.so \
       "${LIB_DIR}"/"${abi}"/
-- 
GitLab
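
Usage note: with the defaults changed above, the bf16 kernels are compiled into the standalone
library unless bfloat16 is switched off. A minimal sketch of how the updated script might be
invoked; the flag spellings come from the helper text and defaults in
tools/bazel_build_standalone_lib.sh, and the exact bazel targets are resolved inside the script:

    # Build an arm64-v8a static library with the default configuration,
    # which now passes --define bfloat16=true to bazel.
    bash tools/bazel_build_standalone_lib.sh --abi=arm64-v8a --static

    # To leave the bf16 kernels out (for example, to shrink the library),
    # flip the default before building:
    #   enable_bfloat16=false   (in tools/bazel_build_standalone_lib.sh)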