diff --git a/mace/kernels/BUILD b/mace/kernels/BUILD index 54ed3fcd3f73d0a6cfab668a45dead37b88f09e8..7689f4e5c4363aac414baf201d5541af4c9bebd4 100644 --- a/mace/kernels/BUILD +++ b/mace/kernels/BUILD @@ -14,9 +14,7 @@ cc_library( srcs = glob([ "*.cc", "opencl/*.cc", - ]) + if_neon_enabled(glob([ - "neon/batch_norm_neon.cc", - ])), + ]), hdrs = glob([ "*.h", "opencl/*.h", diff --git a/mace/kernels/batch_norm.h b/mace/kernels/batch_norm.h index 107b3242bc29a5f34b133c7412c6c20d2e9a1134..978120b24194b1b3fe91e9ebb811463b213dab27 100644 --- a/mace/kernels/batch_norm.h +++ b/mace/kernels/batch_norm.h @@ -5,11 +5,15 @@ #ifndef MACE_KERNELS_BATCH_NORM_H_ #define MACE_KERNELS_BATCH_NORM_H_ +#if defined(MACE_ENABLE_NEON) && defined(__aarch64__) +#include +#endif + #include "mace/core/future.h" #include "mace/core/public/mace.h" +#include "mace/core/runtime/opencl/cl2_header.h" #include "mace/core/tensor.h" #include "mace/kernels/activation.h" -#include "mace/core/runtime/opencl/cl2_header.h" namespace mace { namespace kernels { @@ -86,17 +90,44 @@ struct BatchNormFunctor : BatchNormFunctorBase { } } - const T *scale_data = folded_constant_ ? scale_ptr : new_scale.data(); - const T *offset_data = folded_constant_ ? offset_ptr : new_offset.data(); - -#pragma omp parallel for collapse(4) - for (index_t n = 0; n < batch; ++n) { - for (index_t h = 0; h < height; ++h) { - for (index_t w = 0; w < width; ++w) { - for (index_t c = 0; c < channels; ++c) { - index_t pos = (((n * height) + h) * width + w) * channels + c; + const T *scale_data = folded_constant_ ? scale_ptr : new_scale.data(); + const T *offset_data = folded_constant_ ? offset_ptr : new_offset.data(); + + const int elements = batch * height * width; + constexpr int c_tile_size = 4; + const int c_tiles = channels / c_tile_size; + const index_t remains_start = c_tiles * c_tile_size; + + if (c_tiles > 0) { +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < elements; ++i) { + for (int cb = 0; cb < c_tiles; ++cb) { +#if defined(MACE_ENABLE_NEON) && defined(__aarch64__) + static_assert(c_tile_size == 4, "channels tile size must be 4"); + int c = cb * c_tile_size; + int pos = i * channels + c; + + float32x4_t scales = vld1q_f32(scale_data + c); + float32x4_t offsets = vld1q_f32(offset_data + c); + float32x4_t in = vld1q_f32(input_ptr + pos); + float32x4_t out = vfmaq_f32(offsets, scales, in); + vst1q_f32(output_ptr + pos, out); +#else + for (int ci = 0; ci < c_tile_size; ++ci) { + int c = cb * c_tile_size + ci; + index_t pos = i * channels + c; output_ptr[pos] = scale_data[c] * input_ptr[pos] + offset_data[c]; } +#endif + } + } + } + if (remains_start < channels) { +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < elements; ++i) { + for (index_t c = remains_start; c < channels; ++c) { + index_t pos = i * channels + c; + output_ptr[pos] = scale_data[c] * input_ptr[pos] + offset_data[c]; } } } diff --git a/mace/ops/batch_norm.cc b/mace/ops/batch_norm.cc index ade5c7c7c5d107a15bdf55e68133adb227fd9f64..e0754fee8c502f542a25ff17eb2bfc3b828885e0 100644 --- a/mace/ops/batch_norm.cc +++ b/mace/ops/batch_norm.cc @@ -13,14 +13,6 @@ void Register_BatchNorm(OperatorRegistry *op_registry) { .Build(), BatchNormOp); -#if MACE_ENABLE_NEON - REGISTER_OPERATOR(op_registry, OpKeyBuilder("BatchNorm") - .Device(DeviceType::NEON) - .TypeConstraint("T") - .Build(), - BatchNormOp); -#endif // MACE_ENABLE_NEON - REGISTER_OPERATOR(op_registry, OpKeyBuilder("BatchNorm") .Device(DeviceType::OPENCL) .TypeConstraint("T") @@ -34,4 +26,4 @@ void Register_BatchNorm(OperatorRegistry *op_registry) { BatchNormOp); } -} // namespace mace +} // namespace mace