Commit 292f80aa authored by Liangliang He

Update BatchNorm CPU kernel

Parent 4f407e70
@@ -14,9 +14,7 @@ cc_library(
     srcs = glob([
         "*.cc",
         "opencl/*.cc",
-    ]) + if_neon_enabled(glob([
-        "neon/batch_norm_neon.cc",
-    ])),
+    ]),
     hdrs = glob([
         "*.h",
         "opencl/*.h",
......
@@ -5,11 +5,15 @@
 #ifndef MACE_KERNELS_BATCH_NORM_H_
 #define MACE_KERNELS_BATCH_NORM_H_
+#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
+#include <arm_neon.h>
+#endif
 #include "mace/core/future.h"
 #include "mace/core/public/mace.h"
+#include "mace/core/runtime/opencl/cl2_header.h"
 #include "mace/core/tensor.h"
 #include "mace/kernels/activation.h"
-#include "mace/core/runtime/opencl/cl2_header.h"
 namespace mace {
 namespace kernels {
@@ -86,17 +90,44 @@ struct BatchNormFunctor : BatchNormFunctorBase {
       }
     }
-    const T *scale_data = folded_constant_ ? scale_ptr : new_scale.data();
-    const T *offset_data = folded_constant_ ? offset_ptr : new_offset.data();
-#pragma omp parallel for collapse(4)
-    for (index_t n = 0; n < batch; ++n) {
-      for (index_t h = 0; h < height; ++h) {
-        for (index_t w = 0; w < width; ++w) {
-          for (index_t c = 0; c < channels; ++c) {
-            index_t pos = (((n * height) + h) * width + w) * channels + c;
+    const T *scale_data = folded_constant_ ? scale_ptr : new_scale.data();
+    const T *offset_data = folded_constant_ ? offset_ptr : new_offset.data();
+    const int elements = batch * height * width;
+    constexpr int c_tile_size = 4;
+    const int c_tiles = channels / c_tile_size;
+    const index_t remains_start = c_tiles * c_tile_size;
+    if (c_tiles > 0) {
+#pragma omp parallel for collapse(2)
+      for (index_t i = 0; i < elements; ++i) {
+        for (int cb = 0; cb < c_tiles; ++cb) {
+#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
+          static_assert(c_tile_size == 4, "channels tile size must be 4");
+          int c = cb * c_tile_size;
+          int pos = i * channels + c;
+          float32x4_t scales = vld1q_f32(scale_data + c);
+          float32x4_t offsets = vld1q_f32(offset_data + c);
+          float32x4_t in = vld1q_f32(input_ptr + pos);
+          float32x4_t out = vfmaq_f32(offsets, scales, in);
+          vst1q_f32(output_ptr + pos, out);
+#else
+          for (int ci = 0; ci < c_tile_size; ++ci) {
+            int c = cb * c_tile_size + ci;
+            index_t pos = i * channels + c;
+            output_ptr[pos] = scale_data[c] * input_ptr[pos] + offset_data[c];
+          }
+#endif
+        }
+      }
+    }
+    if (remains_start < channels) {
+#pragma omp parallel for collapse(2)
+      for (index_t i = 0; i < elements; ++i) {
+        for (index_t c = remains_start; c < channels; ++c) {
+          index_t pos = i * channels + c;
+          output_ptr[pos] = scale_data[c] * input_ptr[pos] + offset_data[c];
+        }
+      }
+    }
......
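For reference, here is a minimal standalone sketch of the tiling scheme introduced in the kernel above. It assumes float data in NHWC layout, guards the vector path with the compiler's __ARM_NEON macro rather than MACE_ENABLE_NEON so it builds on its own, and uses illustrative names (apply_batch_norm is not a MACE function). The math is the folded per-channel form the kernel applies, output = scale * input + offset, where scale and offset are either given directly (folded_constant_) or produced by the usual inference-time folding of mean, variance, and epsilon.

```cpp
// Standalone sketch of the channel-tiled batch norm apply step, mirroring
// the structure of the kernel above. NHWC layout is assumed; all names here
// are illustrative, not MACE APIs.
#include <cstdio>
#include <vector>

#if defined(__ARM_NEON) && defined(__aarch64__)
#include <arm_neon.h>
#endif

// output[i, c] = scale[c] * input[i, c] + offset[c], where i runs over
// batch * height * width and scale/offset already fold mean/var/epsilon.
void apply_batch_norm(const float *input, const float *scale,
                      const float *offset, float *output,
                      int elements, int channels) {
  constexpr int kTile = 4;                // channels processed per step
  const int c_tiles = channels / kTile;   // number of full 4-wide tiles
  const int remains_start = c_tiles * kTile;
  for (int i = 0; i < elements; ++i) {
    for (int cb = 0; cb < c_tiles; ++cb) {
      const int c = cb * kTile;
      const int pos = i * channels + c;
#if defined(__ARM_NEON) && defined(__aarch64__)
      // Fused multiply-add on 4 channels at once: out = offset + scale * in.
      float32x4_t scales = vld1q_f32(scale + c);
      float32x4_t offsets = vld1q_f32(offset + c);
      float32x4_t in = vld1q_f32(input + pos);
      vst1q_f32(output + pos, vfmaq_f32(offsets, scales, in));
#else
      // Scalar fallback, same math as the vector path.
      for (int ci = 0; ci < kTile; ++ci) {
        output[pos + ci] = scale[c + ci] * input[pos + ci] + offset[c + ci];
      }
#endif
    }
    // Leftover channels when channels is not a multiple of 4.
    for (int c = remains_start; c < channels; ++c) {
      const int pos = i * channels + c;
      output[pos] = scale[c] * input[pos] + offset[c];
    }
  }
}

int main() {
  const int elements = 2, channels = 5;   // 5 exercises the remainder path
  std::vector<float> in(elements * channels, 1.0f);
  std::vector<float> scale(channels, 2.0f), offset(channels, 0.5f);
  std::vector<float> out(elements * channels, 0.0f);
  apply_batch_norm(in.data(), scale.data(), offset.data(), out.data(),
                   elements, channels);
  std::printf("out[0] = %f\n", out[0]);   // expect 2.5
  return 0;
}
```

Tiling the channel dimension by 4 lets one NEON fused multiply-add (vfmaq_f32) cover four channels per load/store pair, and the trailing loop picks up channel counts that are not multiples of 4, matching the remainder path above. The sketch omits the `#pragma omp parallel for collapse(2)` parallelization the kernel applies over the elements-by-tiles loop nest.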
@@ -13,14 +13,6 @@ void Register_BatchNorm(OperatorRegistry *op_registry) {
                         .Build(),
                     BatchNormOp<DeviceType::CPU, float>);
-#if MACE_ENABLE_NEON
-  REGISTER_OPERATOR(op_registry, OpKeyBuilder("BatchNorm")
-                        .Device(DeviceType::NEON)
-                        .TypeConstraint<float>("T")
-                        .Build(),
-                    BatchNormOp<DeviceType::NEON, float>);
-#endif  // MACE_ENABLE_NEON
   REGISTER_OPERATOR(op_registry, OpKeyBuilder("BatchNorm")
                        .Device(DeviceType::OPENCL)
                        .TypeConstraint<float>("T")
@@ -34,4 +26,4 @@ void Register_BatchNorm(OperatorRegistry *op_registry) {
BatchNormOp<DeviceType::OPENCL, half>);
}
} // namespace mace
} // namespace mace
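The hunk above drops the separate NEON device registration, so on ARM the "BatchNorm" key now resolves to the CPU entry, whose kernel selects NEON intrinsics at compile time (see the header change earlier in this commit). A toy sketch of that kind of device-keyed lookup follows; MyRegistry, Device, and the lambdas are hypothetical stand-ins, not MACE's OperatorRegistry or OpKeyBuilder.

```cpp
// Toy illustration of device-keyed operator registration: one name can map
// to several device entries, and removing a key changes which kernel runs.
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <utility>

enum class Device { CPU, OPENCL };

using OpFactory = std::function<void()>;

struct MyRegistry {
  std::map<std::pair<std::string, Device>, OpFactory> factories;

  void Register(const std::string &name, Device device, OpFactory factory) {
    factories[{name, device}] = std::move(factory);
  }

  void Run(const std::string &name, Device device) {
    auto it = factories.find({name, device});
    if (it == factories.end()) {
      std::cout << name << ": no kernel registered for this device\n";
      return;
    }
    it->second();
  }
};

int main() {
  MyRegistry registry;
  // After this commit only CPU and OPENCL entries exist; the CPU kernel
  // decides internally (at compile time) whether to use NEON intrinsics.
  registry.Register("BatchNorm", Device::CPU,
                    [] { std::cout << "BatchNorm CPU kernel\n"; });
  registry.Register("BatchNorm", Device::OPENCL,
                    [] { std::cout << "BatchNorm OpenCL kernel\n"; });

  registry.Run("BatchNorm", Device::CPU);
  registry.Run("BatchNorm", Device::OPENCL);
  return 0;
}
```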