diff --git a/mace/kernels/addn.h b/mace/kernels/addn.h
index af2e3542d98e80591cbbd35798e041bf3b31b741..3d978d197da8acf9f2d606dbcba95f186f93556b 100644
--- a/mace/kernels/addn.h
+++ b/mace/kernels/addn.h
@@ -23,6 +23,11 @@ struct AddNFunctor {
   }
 };
 
+template <>
+void AddNFunctor<DeviceType::NEON, float>::operator()(const vector<const float *> &inputs,
+                                                      float *output,
+                                                      index_t size);
+
 }  // namespace kernels
 }  // namespace mace
 
diff --git a/mace/kernels/batch_norm.h b/mace/kernels/batch_norm.h
index 84ca48d4a76bc477258ce0d9ec152d5f313709a9..0c1c2ef0b72a317091e6358f5255755953050d2d 100644
--- a/mace/kernels/batch_norm.h
+++ b/mace/kernels/batch_norm.h
@@ -11,19 +11,12 @@
 namespace mace {
 namespace kernels {
 
-template <DeviceType D>
-struct BatchNormFunctorBase {
-  BatchNormFunctorBase(const float variance_epsilon)
-      :variance_epsilon_(variance_epsilon){}
-
+template <DeviceType D, typename T>
+struct BatchNormFunctor {
   float variance_epsilon_;
-};
-
-template <DeviceType D, typename T>
-struct BatchNormFunctor : public BatchNormFunctorBase<D> {
   BatchNormFunctor(const float variance_epsilon)
-      :BatchNormFunctorBase<D>(variance_epsilon){}
+      : variance_epsilon_(variance_epsilon){}
 
   void operator()(const T* input,
                   const T* scale,
@@ -57,9 +50,20 @@ struct BatchNormFunctor : public BatchNormFunctorBase<D> {
       }
     }
   }
-
 };
 
+template <>
+void BatchNormFunctor<DeviceType::NEON, float>::operator()(const float* input,
+                                                           const float* scale,
+                                                           const float* offset,
+                                                           const float* mean,
+                                                           const float* var,
+                                                           const index_t n,
+                                                           const index_t channel,
+                                                           const index_t sample_size,
+                                                           float* output);
+
+
 }  // namepsace kernels
 }  // namespace mace
 
diff --git a/mace/kernels/neon/batch_norm_neon.cc b/mace/kernels/neon/batch_norm_neon.cc
index a306fdbc804e0c5995846fa89dd5bb681d31e1ed..61cbed1ae3166301b8cddacbe59f0d0e418fc868 100644
--- a/mace/kernels/neon/batch_norm_neon.cc
+++ b/mace/kernels/neon/batch_norm_neon.cc
@@ -2,28 +2,22 @@
 // Copyright (c) 2017 XiaoMi All rights reserved.
 //
 
-#if __ARM_NEON
 #include <arm_neon.h>
 #include "mace/kernels/batch_norm.h"
 
 namespace mace {
 namespace kernels {
 
-template <typename T>
-struct BatchNormFunctor<DeviceType::NEON, T> : public BatchNormFunctorBase<DeviceType::NEON> {
-  BatchNormFunctor(const float variance_epsilon)
-      :BatchNormFunctorBase<DeviceType::NEON>(variance_epsilon){}
-
-  void operator()(const T* input,
-                  const T* scale,
-                  const T* offset,
-                  const T* mean,
-                  const T* var,
-                  const int n,
-                  const int channel,
-                  const int sample_size,
-                  T* output) {
-
+template <>
+void BatchNormFunctor<DeviceType::NEON, float>::operator()(const float* input,
+                                                           const float* scale,
+                                                           const float* offset,
+                                                           const float* mean,
+                                                           const float* var,
+                                                           const index_t n,
+                                                           const index_t channel,
+                                                           const index_t sample_size,
+                                                           float* output) {
   // Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
   // The calculation formula for inference is
   // Y = \frac{ \scale } { \sqrt{var+\variance_epsilon} } * X +
@@ -31,39 +25,37 @@ struct BatchNormFunctor<DeviceType::NEON, T> : public BatchNormFunctorBase<DeviceType::NEON> {
-    float new_scale, new_offset;
-    int count = sample_size >> 2;
-    int remain_count = sample_size - count;
-    for (index_t c = 0; c < channel; ++c) {
-      new_scale = scale[c] / std::sqrt(var[c] + this->variance_epsilon_);
-      new_offset = offset[c] - mean[c] * new_scale;
-      index_t pos = c * sample_size;
-
-      float32x4_t new_scale_f = vdupq_n_f32(new_scale);
-      float32x4_t new_offset_f = vdupq_n_f32(new_offset);
-      for (index_t i = 0; i < n; ++i) {
-        const float* input_sample_ptr = input + pos;
-        float* output_sample_ptr = output + pos;
-
-        for(index_t j = 0; j < count; ++j) {
-          float32x4_t input_f = vld1q_f32(input_sample_ptr);
-          float32x4_t output_f = new_offset_f;
-          output_f = vfmaq_f32(output_f, input_f, new_scale_f);
-          vst1q_f32(output_sample_ptr, output_f);
-          input_sample_ptr += 4;
-          output_sample_ptr += 4;
-        }
-        for(index_t j = 0; j < remain_count; ++j) {
-          *output_sample_ptr = new_scale * *input_sample_ptr + new_offset;
-          ++output_sample_ptr;
-          ++input_sample_ptr;
-        }
-        pos += channel * sample_size;
+  float new_scale, new_offset;
+  int count = sample_size >> 2;
+  int remain_count = sample_size - count;
+  for (index_t c = 0; c < channel; ++c) {
+    new_scale = scale[c] / std::sqrt(var[c] + this->variance_epsilon_);
+    new_offset = offset[c] - mean[c] * new_scale;
+    index_t pos = c * sample_size;
+
+    float32x4_t new_scale_f = vdupq_n_f32(new_scale);
+    float32x4_t new_offset_f = vdupq_n_f32(new_offset);
+    for (index_t i = 0; i < n; ++i) {
+      const float *input_sample_ptr = input + pos;
+      float *output_sample_ptr = output + pos;
+
+      for (index_t j = 0; j < count; ++j) {
+        float32x4_t input_f = vld1q_f32(input_sample_ptr);
+        float32x4_t output_f = new_offset_f;
+        output_f = vfmaq_f32(output_f, input_f, new_scale_f);
+        vst1q_f32(output_sample_ptr, output_f);
+        input_sample_ptr += 4;
+        output_sample_ptr += 4;
+      }
+      for (index_t j = 0; j < remain_count; ++j) {
+        *output_sample_ptr = new_scale * *input_sample_ptr + new_offset;
+        ++output_sample_ptr;
+        ++input_sample_ptr;
       }
+      pos += channel * sample_size;
     }
   }
 };
 
 }  // namespace kernels
-}  // namespace mace
-#endif  // __ARM_NEON
+}  // namespace mace
\ No newline at end of file
diff --git a/mace/kernels/relu.h b/mace/kernels/relu.h
index fd845c1fd6acd8722e340b7cb7ae1aa004a97f1e..8eed29a9839c7628bc0944890489c5512e32bfd3 100644
--- a/mace/kernels/relu.h
+++ b/mace/kernels/relu.h
@@ -19,6 +19,11 @@ struct ReluFunctor {
   }
 };
 
+template <>
+void ReluFunctor<DeviceType::NEON, float>::operator()(const float *input,
+                                                      float *output,
+                                                      index_t size);
+
 }  // namespace kernels
 }  // namespace mace
 
diff --git a/mace/kernels/test/addn_neon_test.cc b/mace/kernels/test/addn_neon_test.cc
index c6425595745e9eba9c64c4f7767185926863de56..6aebb9013dd7a8fdd00d3c3a52b3b9368ca389a6 100644
--- a/mace/kernels/test/addn_neon_test.cc
+++ b/mace/kernels/test/addn_neon_test.cc
@@ -10,6 +10,7 @@ using namespace mace;
 using namespace mace::kernels;
 
 TEST(NeonTest, AddN) {
+  testing::internal::LogToStderr();
   std::random_device rd;
   std::mt19937 gen(rd());
   std::normal_distribution<float> nd(0, 1);
diff --git a/mace/kernels/test/relu_neon_test.cc b/mace/kernels/test/relu_neon_test.cc
index 2e98b62a3b461099aa2dabbce68b8e4637eebe88..d5200ff127bb66ae7ba1d477e3f77ed02079b37f 100644
--- a/mace/kernels/test/relu_neon_test.cc
+++ b/mace/kernels/test/relu_neon_test.cc
@@ -10,6 +10,7 @@ using namespace mace;
 using namespace mace::kernels;
 
 TEST(NeonTest, Relu) {
+  testing::internal::LogToStderr();
   std::random_device rd;
   std::mt19937 gen(rd());
   std::normal_distribution<float> nd(0, 1);
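For reference, the transform the NEON batch-norm kernel above vectorizes folds each channel's parameters into one multiply-add: new_scale = scale[c] / sqrt(var[c] + variance_epsilon), new_offset = offset[c] - mean[c] * new_scale, then y = new_scale * x + new_offset for every element of that channel. Below is a minimal scalar sketch of that reference behaviour; it assumes the same NCHW layout as the kernel, and the function name, the standalone index_t alias, and the explicit variance_epsilon parameter are illustrative rather than part of the patch.

#include <cmath>
#include <cstdint>

using index_t = int64_t;  // assumption: a signed integer index type, as used by the kernels above

// Scalar reference for inference-time batch norm over NCHW data:
// y = new_scale * x + new_offset, computed channel by channel.
void BatchNormScalarReference(const float *input, const float *scale,
                              const float *offset, const float *mean,
                              const float *var, index_t n, index_t channel,
                              index_t sample_size, float variance_epsilon,
                              float *output) {
  for (index_t c = 0; c < channel; ++c) {
    const float new_scale = scale[c] / std::sqrt(var[c] + variance_epsilon);
    const float new_offset = offset[c] - mean[c] * new_scale;
    index_t pos = c * sample_size;     // start of channel c in image 0
    for (index_t i = 0; i < n; ++i) {
      for (index_t j = 0; j < sample_size; ++j) {
        output[pos + j] = new_scale * input[pos + j] + new_offset;
      }
      pos += channel * sample_size;    // same channel in the next image
    }
  }
}

One detail worth keeping in mind when comparing this with the NEON body: count = sample_size >> 2 is the number of four-lane iterations, so the scalar tail they leave behind holds sample_size - (count << 2) (that is, sample_size & 3) elements; the plain loop above sidesteps the split by visiting every element once.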