提交 a3abf2f4 编写于 作者: L liuqi

Optimize the address calculation for batch_norm.

上级 6772190f
......@@ -45,14 +45,15 @@ struct BatchNormFunctor : public BatchNormFunctorBase<D, T> {
for (TIndex c = 0; c < channel; ++c) {
new_scale = scale[c] / std::sqrt(var[c] + this->variance_epsilon_);
new_offset = offset[c] - mean[c] * new_scale;
TIndex pos = c * sample_size;
for (TIndex i = 0; i < n; ++i) {
TIndex pos = (i * channel + c) * sample_size;
const T* input_sample_ptr = input + pos;
T* output_sample_ptr = output + pos;
for (TIndex j = 0; j < sample_size; ++j) {
output_sample_ptr[j] = new_scale * input_sample_ptr[j] + new_offset;
}
pos += channel * sample_size;
}
}
}
......
......@@ -2,7 +2,7 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//
//#if __ARM_NEON
#if __ARM_NEON
#include <arm_neon.h>
#include "mace/kernels/batch_norm.h"
......@@ -10,7 +10,7 @@ namespace mace {
namespace kernels {
template <typename T>
struct BatchNormFunctor<DeviceType::NEON> : public BatchNormFunctorBase<DeviceType::NEON, T> {
struct BatchNormFunctor<DeviceType::NEON, T> : public BatchNormFunctorBase<DeviceType::NEON, T> {
BatchNormFunctor(const float variance_epsilon)
:BatchNormFunctorBase<DeviceType::NEON, T>(variance_epsilon){}
......@@ -35,13 +35,13 @@ struct BatchNormFunctor<DeviceType::NEON> : public BatchNormFunctorBase<DeviceTy
int count = sample_size >> 2;
int remain_count = sample_size - count;
for (TIndex c = 0; c < channel; ++c) {
new_scale = scale[c] / std::sqrt(var[c] + variance_epsilon_);
new_scale = scale[c] / std::sqrt(var[c] + this->variance_epsilon_);
new_offset = offset[c] - mean[c] * new_scale;
TIndex pos = c * sample_size;
float32x4_t new_scale_f = vdupq_n_f32(new_scale);
float32x4_t new_offset_f = vdupq_n_f32(new_offset);
for (TIndex i = 0; i < n; ++i) {
TIndex pos = (i * channel + c) * sample_size;
const float* input_sample_ptr = input + pos;
float* output_sample_ptr = output + pos;
......@@ -58,6 +58,7 @@ struct BatchNormFunctor<DeviceType::NEON> : public BatchNormFunctorBase<DeviceTy
++output_sample_ptr;
++input_sample_ptr;
}
pos += channel * sample_size;
}
}
}
......@@ -65,4 +66,4 @@ struct BatchNormFunctor<DeviceType::NEON> : public BatchNormFunctorBase<DeviceTy
} // namespace kernels
} // namespace mace
//#endif // __ARM_NEON
#endif // __ARM_NEON
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册