提交 3f803f84 编写于 作者: 刘琦

Merge branch 'bn' into 'master'

Update BatchNorm CPU kernel

See merge request !238
...@@ -14,9 +14,7 @@ cc_library( ...@@ -14,9 +14,7 @@ cc_library(
srcs = glob([ srcs = glob([
"*.cc", "*.cc",
"opencl/*.cc", "opencl/*.cc",
]) + if_neon_enabled(glob([ ]),
"neon/batch_norm_neon.cc",
])),
hdrs = glob([ hdrs = glob([
"*.h", "*.h",
"opencl/*.h", "opencl/*.h",
......
...@@ -5,11 +5,15 @@ ...@@ -5,11 +5,15 @@
#ifndef MACE_KERNELS_BATCH_NORM_H_ #ifndef MACE_KERNELS_BATCH_NORM_H_
#define MACE_KERNELS_BATCH_NORM_H_ #define MACE_KERNELS_BATCH_NORM_H_
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
#include <arm_neon.h>
#endif
#include "mace/core/future.h" #include "mace/core/future.h"
#include "mace/core/public/mace.h" #include "mace/core/public/mace.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/kernels/activation.h" #include "mace/kernels/activation.h"
#include "mace/core/runtime/opencl/cl2_header.h"
namespace mace { namespace mace {
namespace kernels { namespace kernels {
...@@ -89,14 +93,41 @@ struct BatchNormFunctor : BatchNormFunctorBase { ...@@ -89,14 +93,41 @@ struct BatchNormFunctor : BatchNormFunctorBase {
const T *scale_data = folded_constant_ ? scale_ptr : new_scale.data(); const T *scale_data = folded_constant_ ? scale_ptr : new_scale.data();
const T *offset_data = folded_constant_ ? offset_ptr : new_offset.data(); const T *offset_data = folded_constant_ ? offset_ptr : new_offset.data();
#pragma omp parallel for collapse(4) const int elements = batch * height * width;
for (index_t n = 0; n < batch; ++n) { constexpr int c_tile_size = 4;
for (index_t h = 0; h < height; ++h) { const int c_tiles = channels / c_tile_size;
for (index_t w = 0; w < width; ++w) { const index_t remains_start = c_tiles * c_tile_size;
for (index_t c = 0; c < channels; ++c) {
index_t pos = (((n * height) + h) * width + w) * channels + c; if (c_tiles > 0) {
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < elements; ++i) {
for (int cb = 0; cb < c_tiles; ++cb) {
#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
static_assert(c_tile_size == 4, "channels tile size must be 4");
int c = cb * c_tile_size;
int pos = i * channels + c;
float32x4_t scales = vld1q_f32(scale_data + c);
float32x4_t offsets = vld1q_f32(offset_data + c);
float32x4_t in = vld1q_f32(input_ptr + pos);
float32x4_t out = vfmaq_f32(offsets, scales, in);
vst1q_f32(output_ptr + pos, out);
#else
for (int ci = 0; ci < c_tile_size; ++ci) {
int c = cb * c_tile_size + ci;
index_t pos = i * channels + c;
output_ptr[pos] = scale_data[c] * input_ptr[pos] + offset_data[c]; output_ptr[pos] = scale_data[c] * input_ptr[pos] + offset_data[c];
} }
#endif
}
}
}
if (remains_start < channels) {
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < elements; ++i) {
for (index_t c = remains_start; c < channels; ++c) {
index_t pos = i * channels + c;
output_ptr[pos] = scale_data[c] * input_ptr[pos] + offset_data[c];
} }
} }
} }
......
...@@ -13,14 +13,6 @@ void Register_BatchNorm(OperatorRegistry *op_registry) { ...@@ -13,14 +13,6 @@ void Register_BatchNorm(OperatorRegistry *op_registry) {
.Build(), .Build(),
BatchNormOp<DeviceType::CPU, float>); BatchNormOp<DeviceType::CPU, float>);
#if MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("BatchNorm")
.Device(DeviceType::NEON)
.TypeConstraint<float>("T")
.Build(),
BatchNormOp<DeviceType::NEON, float>);
#endif // MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("BatchNorm") REGISTER_OPERATOR(op_registry, OpKeyBuilder("BatchNorm")
.Device(DeviceType::OPENCL) .Device(DeviceType::OPENCL)
.TypeConstraint<float>("T") .TypeConstraint<float>("T")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册