From a3abf2f4defacffd0495c35b1e39898c8aecc706 Mon Sep 17 00:00:00 2001 From: liuqi Date: Thu, 7 Sep 2017 10:58:20 +0800 Subject: [PATCH] Optimize the address calculation for batch_norm. --- mace/kernels/batch_norm.h | 3 ++- mace/kernels/neon/batch_norm_neon.cc | 11 ++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/mace/kernels/batch_norm.h b/mace/kernels/batch_norm.h index 136163f9..d2899d76 100644 --- a/mace/kernels/batch_norm.h +++ b/mace/kernels/batch_norm.h @@ -45,14 +45,15 @@ struct BatchNormFunctor : public BatchNormFunctorBase { for (TIndex c = 0; c < channel; ++c) { new_scale = scale[c] / std::sqrt(var[c] + this->variance_epsilon_); new_offset = offset[c] - mean[c] * new_scale; + TIndex pos = c * sample_size; for (TIndex i = 0; i < n; ++i) { - TIndex pos = (i * channel + c) * sample_size; const T* input_sample_ptr = input + pos; T* output_sample_ptr = output + pos; for (TIndex j = 0; j < sample_size; ++j) { output_sample_ptr[j] = new_scale * input_sample_ptr[j] + new_offset; } + pos += channel * sample_size; } } } diff --git a/mace/kernels/neon/batch_norm_neon.cc b/mace/kernels/neon/batch_norm_neon.cc index 7121c23d..9db63f68 100644 --- a/mace/kernels/neon/batch_norm_neon.cc +++ b/mace/kernels/neon/batch_norm_neon.cc @@ -2,7 +2,7 @@ // Copyright (c) 2017 XiaoMi All rights reserved. // -//#if __ARM_NEON +#if __ARM_NEON #include #include "mace/kernels/batch_norm.h" @@ -10,7 +10,7 @@ namespace mace { namespace kernels { template -struct BatchNormFunctor : public BatchNormFunctorBase { +struct BatchNormFunctor : public BatchNormFunctorBase { BatchNormFunctor(const float variance_epsilon) :BatchNormFunctorBase(variance_epsilon){} @@ -35,13 +35,13 @@ struct BatchNormFunctor : public BatchNormFunctorBase> 2; int remain_count = sample_size - count; for (TIndex c = 0; c < channel; ++c) { - new_scale = scale[c] / std::sqrt(var[c] + variance_epsilon_); + new_scale = scale[c] / std::sqrt(var[c] + this->variance_epsilon_); new_offset = offset[c] - mean[c] * new_scale; + TIndex pos = c * sample_size; float32x4_t new_scale_f = vdupq_n_f32(new_scale); float32x4_t new_offset_f = vdupq_n_f32(new_offset); for (TIndex i = 0; i < n; ++i) { - TIndex pos = (i * channel + c) * sample_size; const float* input_sample_ptr = input + pos; float* output_sample_ptr = output + pos; @@ -58,6 +58,7 @@ struct BatchNormFunctor : public BatchNormFunctorBase : public BatchNormFunctorBase