提交 a3abf2f4 作者: liuqi

Optimize the address calculation for batch_norm.

上级 6772190f
...@@ -45,14 +45,15 @@ struct BatchNormFunctor : public BatchNormFunctorBase<D, T> { ...@@ -45,14 +45,15 @@ struct BatchNormFunctor : public BatchNormFunctorBase<D, T> {
for (TIndex c = 0; c < channel; ++c) { for (TIndex c = 0; c < channel; ++c) {
new_scale = scale[c] / std::sqrt(var[c] + this->variance_epsilon_); new_scale = scale[c] / std::sqrt(var[c] + this->variance_epsilon_);
new_offset = offset[c] - mean[c] * new_scale; new_offset = offset[c] - mean[c] * new_scale;
TIndex pos = c * sample_size;
for (TIndex i = 0; i < n; ++i) { for (TIndex i = 0; i < n; ++i) {
TIndex pos = (i * channel + c) * sample_size;
const T* input_sample_ptr = input + pos; const T* input_sample_ptr = input + pos;
T* output_sample_ptr = output + pos; T* output_sample_ptr = output + pos;
for (TIndex j = 0; j < sample_size; ++j) { for (TIndex j = 0; j < sample_size; ++j) {
output_sample_ptr[j] = new_scale * input_sample_ptr[j] + new_offset; output_sample_ptr[j] = new_scale * input_sample_ptr[j] + new_offset;
} }
pos += channel * sample_size;
} }
} }
} }
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Copyright (c) 2017 XiaoMi All rights reserved. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
//#if __ARM_NEON #if __ARM_NEON
#include <arm_neon.h> #include <arm_neon.h>
#include "mace/kernels/batch_norm.h" #include "mace/kernels/batch_norm.h"
...@@ -10,7 +10,7 @@ namespace mace { ...@@ -10,7 +10,7 @@ namespace mace {
namespace kernels { namespace kernels {
template <typename T> template <typename T>
struct BatchNormFunctor<DeviceType::NEON> : public BatchNormFunctorBase<DeviceType::NEON, T> { struct BatchNormFunctor<DeviceType::NEON, T> : public BatchNormFunctorBase<DeviceType::NEON, T> {
BatchNormFunctor(const float variance_epsilon) BatchNormFunctor(const float variance_epsilon)
:BatchNormFunctorBase<DeviceType::NEON, T>(variance_epsilon){} :BatchNormFunctorBase<DeviceType::NEON, T>(variance_epsilon){}
...@@ -35,13 +35,13 @@ struct BatchNormFunctor<DeviceType::NEON> : public BatchNormFunctorBase<DeviceTy ...@@ -35,13 +35,13 @@ struct BatchNormFunctor<DeviceType::NEON> : public BatchNormFunctorBase<DeviceTy
int count = sample_size >> 2; int count = sample_size >> 2;
int remain_count = sample_size - count; int remain_count = sample_size - count;
for (TIndex c = 0; c < channel; ++c) { for (TIndex c = 0; c < channel; ++c) {
new_scale = scale[c] / std::sqrt(var[c] + variance_epsilon_); new_scale = scale[c] / std::sqrt(var[c] + this->variance_epsilon_);
new_offset = offset[c] - mean[c] * new_scale; new_offset = offset[c] - mean[c] * new_scale;
TIndex pos = c * sample_size;
float32x4_t new_scale_f = vdupq_n_f32(new_scale); float32x4_t new_scale_f = vdupq_n_f32(new_scale);
float32x4_t new_offset_f = vdupq_n_f32(new_offset); float32x4_t new_offset_f = vdupq_n_f32(new_offset);
for (TIndex i = 0; i < n; ++i) { for (TIndex i = 0; i < n; ++i) {
TIndex pos = (i * channel + c) * sample_size;
const float* input_sample_ptr = input + pos; const float* input_sample_ptr = input + pos;
float* output_sample_ptr = output + pos; float* output_sample_ptr = output + pos;
...@@ -58,6 +58,7 @@ struct BatchNormFunctor<DeviceType::NEON> : public BatchNormFunctorBase<DeviceTy ...@@ -58,6 +58,7 @@ struct BatchNormFunctor<DeviceType::NEON> : public BatchNormFunctorBase<DeviceTy
++output_sample_ptr; ++output_sample_ptr;
++input_sample_ptr; ++input_sample_ptr;
} }
pos += channel * sample_size;
} }
} }
} }
...@@ -65,4 +66,4 @@ struct BatchNormFunctor<DeviceType::NEON> : public BatchNormFunctorBase<DeviceTy ...@@ -65,4 +66,4 @@ struct BatchNormFunctor<DeviceType::NEON> : public BatchNormFunctorBase<DeviceTy
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
//#endif // __ARM_NEON #endif // __ARM_NEON
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册