Commit becc921f authored by 李寅, committed by wuchenghui

Declare Neon Specialization

Parent 94dbdeaf
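The change applies the declare-in-header, define-in-source pattern to AddNFunctor, BatchNormFunctor, and ReluFunctor: the primary functor template keeps its generic body in the header, the full specialization of operator() for DeviceType::NEON is only declared there, and the definition moves into a NEON-specific source file. A minimal self-contained sketch of that pattern; UnaryFunctor, the DeviceType enum, index_t, and the file names below are simplified stand-ins for the MACE types, not the actual headers:

// functor.h (stand-in): generic body in the header, NEON member only declared.
#include <cstdint>

enum class DeviceType { CPU, NEON };  // stand-in for mace::DeviceType
using index_t = int64_t;              // stand-in for mace::index_t

template <DeviceType D, typename T>
struct UnaryFunctor {
  void operator()(const T *input, T *output, index_t size) {
    for (index_t i = 0; i < size; ++i) {
      output[i] = input[i];  // placeholder generic body
    }
  }
};

// Declaration of the explicit specialization: other translation units will
// not instantiate the generic body for <NEON, float>; calls link against
// the out-of-line definition instead.
template <>
void UnaryFunctor<DeviceType::NEON, float>::operator()(const float *input,
                                                       float *output,
                                                       index_t size);

// functor_neon.cc (stand-in): definition kept out of the common header.
template <>
void UnaryFunctor<DeviceType::NEON, float>::operator()(const float *input,
                                                       float *output,
                                                       index_t size) {
  for (index_t i = 0; i < size; ++i) {
    output[i] = input[i];  // a real kernel would use NEON intrinsics here
  }
}

Without the declaration in the header, a caller in another translation unit would implicitly instantiate the generic operator() and never reach the NEON version.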
@@ -23,6 +23,11 @@ struct AddNFunctor {
   }
 };
+template <>
+void AddNFunctor<DeviceType::NEON, float>::operator()(const vector<const float*>& inputs,
+                                                      float *output,
+                                                      index_t size);
 } // namespace kernels
 } // namespace mace
...
@@ -11,19 +11,12 @@
 namespace mace {
 namespace kernels {
-template <DeviceType D, typename T>
-struct BatchNormFunctorBase {
-  BatchNormFunctorBase(const float variance_epsilon)
-      :variance_epsilon_(variance_epsilon){}
+template<DeviceType D, typename T>
+struct BatchNormFunctor {
   float variance_epsilon_;
-};
-template<DeviceType D, typename T>
-struct BatchNormFunctor : public BatchNormFunctorBase<D, T> {
   BatchNormFunctor(const float variance_epsilon)
-      :BatchNormFunctorBase<D, T>(variance_epsilon){}
+      : variance_epsilon_(variance_epsilon){}
   void operator()(const T* input,
                   const T* scale,
@@ -57,9 +50,20 @@ struct BatchNormFunctor : public BatchNormFunctorBase<D, T> {
       }
     }
   }
 };
+template <>
+void BatchNormFunctor<DeviceType::NEON, float>::operator()(const float* input,
+                                                            const float* scale,
+                                                            const float* offset,
+                                                            const float* mean,
+                                                            const float* var,
+                                                            const index_t n,
+                                                            const index_t channel,
+                                                            const index_t sample_size,
+                                                            float* output);
 } // namepsace kernels
 } // namespace mace
...
@@ -2,28 +2,22 @@
 // Copyright (c) 2017 XiaoMi All rights reserved.
 //
-#if __ARM_NEON
 #include <arm_neon.h>
 #include "mace/kernels/batch_norm.h"
 namespace mace {
 namespace kernels {
-template <typename T>
-struct BatchNormFunctor<DeviceType::NEON, T> : public BatchNormFunctorBase<DeviceType::NEON, T> {
-  BatchNormFunctor(const float variance_epsilon)
-      :BatchNormFunctorBase<DeviceType::NEON, T>(variance_epsilon){}
-
-  void operator()(const T* input,
-                  const T* scale,
-                  const T* offset,
-                  const T* mean,
-                  const T* var,
-                  const int n,
-                  const int channel,
-                  const int sample_size,
-                  T* output) {
+template <>
+void BatchNormFunctor<DeviceType::NEON, float>::operator()(const float* input,
+                                                            const float* scale,
+                                                            const float* offset,
+                                                            const float* mean,
+                                                            const float* var,
+                                                            const index_t n,
+                                                            const index_t channel,
+                                                            const index_t sample_size,
+                                                            float* output) {
   // Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
   // The calculation formula for inference is
   // Y = \frac{ \scale } { \sqrt{var+\variance_epsilon} } * X +
@@ -31,7 +25,7 @@ struct BatchNormFunctor<DeviceType::NEON, T> : public BatchNormFunctorBase<Devic
   // new_scale = \frac{ \scale } { \sqrt{var+\variance_epsilon} }
   // new_offset = \offset - mean * common_val;
   // Y = new_scale * X + new_offset;
-  T new_scale, new_offset;
+  float new_scale, new_offset;
   int count = sample_size >> 2;
   int remain_count = sample_size - count;
   for (index_t c = 0; c < channel; ++c) {
@@ -42,10 +36,10 @@ struct BatchNormFunctor<DeviceType::NEON, T> : public BatchNormFunctorBase<Devic
     float32x4_t new_scale_f = vdupq_n_f32(new_scale);
     float32x4_t new_offset_f = vdupq_n_f32(new_offset);
     for (index_t i = 0; i < n; ++i) {
-      const float* input_sample_ptr = input + pos;
-      float* output_sample_ptr = output + pos;
-      for(index_t j = 0; j < count; ++j) {
+      const float *input_sample_ptr = input + pos;
+      float *output_sample_ptr = output + pos;
+      for (index_t j = 0; j < count; ++j) {
         float32x4_t input_f = vld1q_f32(input_sample_ptr);
         float32x4_t output_f = new_offset_f;
         output_f = vfmaq_f32(output_f, input_f, new_scale_f);
@@ -53,7 +47,7 @@ struct BatchNormFunctor<DeviceType::NEON, T> : public BatchNormFunctorBase<Devic
         input_sample_ptr += 4;
         output_sample_ptr += 4;
       }
-      for(index_t j = 0; j < remain_count; ++j) {
+      for (index_t j = 0; j < remain_count; ++j) {
         *output_sample_ptr = new_scale * *input_sample_ptr + new_offset;
         ++output_sample_ptr;
         ++input_sample_ptr;
@@ -61,9 +55,7 @@ struct BatchNormFunctor<DeviceType::NEON, T> : public BatchNormFunctorBase<Devic
       pos += channel * sample_size;
     }
   }
-  }
 };
 } // namespace kernels
 } // namespace mace
\ No newline at end of file
-#endif // __ARM_NEON
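For reference, the folded inference formula documented in the comments above, written out (with epsilon denoting variance_epsilon):

\[
  Y = \underbrace{\frac{\mathrm{scale}}{\sqrt{\mathrm{var} + \epsilon}}}_{\mathrm{new\_scale}} \cdot X
      + \underbrace{\left(\mathrm{offset} - \mathrm{mean} \cdot \frac{\mathrm{scale}}{\sqrt{\mathrm{var} + \epsilon}}\right)}_{\mathrm{new\_offset}}
\]

Both constants are computed once per channel, so the inner loop is a single fused multiply-add per element; vfmaq_f32(new_offset_f, input_f, new_scale_f) evaluates new_offset + input * new_scale on four lanes at a time, with the scalar remainder loop covering the tail.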
@@ -19,6 +19,11 @@ struct ReluFunctor {
   }
 };
+template <>
+void ReluFunctor<DeviceType::NEON, float>::operator()(const float *input,
+                                                      float *output,
+                                                      index_t size);
 } // namespace kernels
 } // namespace mace
...
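The hunk above only declares the NEON specialization of ReluFunctor; its definition is not part of this page. Purely as a hypothetical sketch, a matching NEON definition could reuse the four-lane loop plus scalar tail pattern of the batch-norm kernel; the function body and the include path here are assumptions, not the committed code:

#include <arm_neon.h>

#include "mace/kernels/relu.h"  // assumed header path

namespace mace {
namespace kernels {

template <>
void ReluFunctor<DeviceType::NEON, float>::operator()(const float *input,
                                                      float *output,
                                                      index_t size) {
  // Hypothetical body: four floats per iteration, then a scalar remainder.
  const float32x4_t zero = vdupq_n_f32(0.f);
  index_t i = 0;
  for (; i + 3 < size; i += 4) {
    float32x4_t in = vld1q_f32(input + i);
    vst1q_f32(output + i, vmaxq_f32(in, zero));  // relu: max(x, 0)
  }
  for (; i < size; ++i) {
    output[i] = input[i] > 0.f ? input[i] : 0.f;
  }
}

} // namespace kernels
} // namespace mace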
@@ -10,6 +10,7 @@ using namespace mace;
 using namespace mace::kernels;
 TEST(NeonTest, AddN) {
+  testing::internal::LogToStderr();
   std::random_device rd;
   std::mt19937 gen(rd());
   std::normal_distribution<float> nd(0, 1);
...
@@ -10,6 +10,7 @@ using namespace mace;
 using namespace mace::kernels;
 TEST(NeonTest, Relu) {
+  testing::internal::LogToStderr();
  std::random_device rd;
  std::mt19937 gen(rd());
  std::normal_distribution<float> nd(0, 1);
...