Commit bcec92d0 authored by liuqi

Add OpenCL batch norm kernel and fix bugs.

Parent 129608cc
@@ -13,16 +13,13 @@ namespace kernels {
 
 template <DeviceType D, typename T>
 struct BatchNormFunctor {
-  void operator()(const T *input,
-                  const T *scale,
-                  const T *offset,
-                  const T *mean,
-                  const T *var,
-                  const float variance_epsilon,
-                  const index_t n,
-                  const index_t channel,
-                  const index_t sample_size,
-                  T *output) {
+  void operator()(const Tensor *input,
+                  const Tensor *scale,
+                  const Tensor *offset,
+                  const Tensor *mean,
+                  const Tensor *var,
+                  const Tensor *epsilon,
+                  Tensor *output) {
     // Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
     // The calculation formula for inference is
     // Y = \frac{ \scale } { \sqrt{var+\variance_epsilon} } * X +
@@ -31,16 +28,35 @@ struct BatchNormFunctor {
     // new_scale = \frac{ \scale } { \sqrt{var+\variance_epsilon} }
     // new_offset = \offset - mean * common_val;
     // Y = new_scale * X + new_offset;
-    T new_scale, new_offset;
+    const index_t n = input->dim(0);
+    const index_t channel = input->dim(1);
+    const index_t sample_size = input->dim(2) * input->dim(3);
+
+    Tensor::MappingGuard input_mapper(input);
+    Tensor::MappingGuard scale_mapper(scale);
+    Tensor::MappingGuard offset_mapper(offset);
+    Tensor::MappingGuard mean_mapper(mean);
+    Tensor::MappingGuard var_mapper(var);
+    Tensor::MappingGuard epsilon_mapper(epsilon);
+    Tensor::MappingGuard output_mapper(output);
+
+    const T *input_ptr = input->data<T>();
+    const T *scale_ptr = scale->data<T>();
+    const T *offset_ptr = offset->data<T>();
+    const T *mean_ptr = mean->data<T>();
+    const T *var_ptr = var->data<T>();
+    const T *epsilon_ptr = epsilon->data<T>();
+    T *output_ptr = output->mutable_data<T>();
+
 #pragma omp parallel for
     for (index_t c = 0; c < channel; ++c) {
-      new_scale = scale[c] / std::sqrt(var[c] + variance_epsilon);
-      new_offset = offset[c] - mean[c] * new_scale;
+      T new_scale = scale_ptr[c] / std::sqrt(var_ptr[c] + *epsilon_ptr);
+      T new_offset = offset_ptr[c] - mean_ptr[c] * new_scale;
       index_t pos = c * sample_size;
 
       for (index_t i = 0; i < n; ++i) {
-        const T *input_sample_ptr = input + pos;
-        T *output_sample_ptr = output + pos;
+        const T *input_sample_ptr = input_ptr + pos;
+        T *output_sample_ptr = output_ptr + pos;
         for (index_t j = 0; j < sample_size; ++j) {
           output_sample_ptr[j] = new_scale * input_sample_ptr[j] + new_offset;
         }
@@ -52,16 +68,23 @@ struct BatchNormFunctor {
 
 template <>
 void BatchNormFunctor<DeviceType::NEON, float>::operator()(
-    const float *input,
-    const float *scale,
-    const float *offset,
-    const float *mean,
-    const float *var,
-    const float variance_epsilon,
-    const index_t n,
-    const index_t channel,
-    const index_t sample_size,
-    float *output);
+    const Tensor *input,
+    const Tensor *scale,
+    const Tensor *offset,
+    const Tensor *mean,
+    const Tensor *var,
+    const Tensor *epsilon,
+    Tensor *output);
+
+template <>
+void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
+    const Tensor *input,
+    const Tensor *scale,
+    const Tensor *offset,
+    const Tensor *mean,
+    const Tensor *var,
+    const Tensor *epsilon,
+    Tensor *output);
 
 } // namespace kernels
 } // namespace mace
...
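Aside for readers of this diff: after the refactor the functor still computes the folded inference form of batch normalization described in the comments above. A minimal standalone C++ sketch of that per-channel folding (illustrative names only, not code from this commit):

// Illustrative sketch (not part of the commit) of the folded inference
// formula: new_scale = scale / sqrt(var + epsilon),
//          new_offset = offset - mean * new_scale,
//          y[i]       = new_scale * x[i] + new_offset.
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> BatchNormChannel(const std::vector<float> &x,
                                    float scale, float offset,
                                    float mean, float var, float epsilon) {
  const float new_scale = scale / std::sqrt(var + epsilon);
  const float new_offset = offset - mean * new_scale;
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    y[i] = new_scale * x[i] + new_offset;
  }
  return y;
}

Folding scale and offset once per channel reduces the inner loop to a single multiply-add per element, which is exactly what the NEON and OpenCL paths below vectorize.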
@@ -10,38 +10,46 @@ namespace kernels {
 
 template <>
 void BatchNormFunctor<DeviceType::NEON, float>::operator()(
-    const float *input,
-    const float *scale,
-    const float *offset,
-    const float *mean,
-    const float *var,
-    const float variance_epsilon,
-    const index_t n,
-    const index_t channel,
-    const index_t sample_size,
-    float *output) {
+    const Tensor *input,
+    const Tensor *scale,
+    const Tensor *offset,
+    const Tensor *mean,
+    const Tensor *var,
+    const Tensor *epsilon,
+    Tensor *output) {
   // Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
   // The calculation formula for inference is
-  // Y = \frac{ \scale } { \sqrt{var+\variance_epsilon} } * X +
-  // ( \offset - \frac { \scale * mean } { \sqrt{var+\variance_epsilon}
+  // Y = \frac{ \scale } { \sqrt{var+\epsilon} } * X +
+  // ( \offset - \frac { \scale * mean } { \sqrt{var+\epsilon}
   // }
-  // new_scale = \frac{ \scale } { \sqrt{var+\variance_epsilon} }
+  // new_scale = \frac{ \scale } { \sqrt{var+\epsilon} }
   // new_offset = \offset - mean * common_val;
   // Y = new_scale * X + new_offset;
-  float new_scale, new_offset;
+  const index_t n = input->dim(0);
+  const index_t channel = input->dim(1);
+  const index_t sample_size = input->dim(2) * input->dim(3);
+
+  const float *input_ptr = input->data<float>();
+  const float *scale_ptr = scale->data<float>();
+  const float *offset_ptr = offset->data<float>();
+  const float *mean_ptr = mean->data<float>();
+  const float *var_ptr = var->data<float>();
+  const float *epsilon_ptr = epsilon->data<float>();
+  float *output_ptr = output->mutable_data<float>();
+
   index_t count = sample_size >> 2;
   index_t remain_count = sample_size - (count << 2);
 #pragma omp parallel for
   for (index_t c = 0; c < channel; ++c) {
-    new_scale = scale[c] / std::sqrt(var[c] + variance_epsilon);
-    new_offset = offset[c] - mean[c] * new_scale;
+    float new_scale = scale_ptr[c] / std::sqrt(var_ptr[c] + *epsilon_ptr);
+    float new_offset = offset_ptr[c] - mean_ptr[c] * new_scale;
     index_t pos = c * sample_size;
 
     float32x4_t new_scale_f = vdupq_n_f32(new_scale);
     float32x4_t new_offset_f = vdupq_n_f32(new_offset);
     for (index_t i = 0; i < n; ++i) {
-      const float *input_sample_ptr = input + pos;
-      float *output_sample_ptr = output + pos;
+      const float *input_sample_ptr = input_ptr + pos;
+      float *output_sample_ptr = output_ptr + pos;
       for (index_t j = 0; j < count; ++j) {
         float32x4_t input_f = vld1q_f32(input_sample_ptr);
...
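The NEON hunk is truncated at this point. As a hedged sketch of how such a vectorized inner loop is typically finished — using only standard arm_neon.h intrinsics, and not necessarily the exact code hidden by the truncation:

// Sketch only (not taken from this commit): vectorized body plus scalar tail,
// assuming the same count / remain_count split computed in the hunk above.
#include <arm_neon.h>
#include <cstddef>

void BatchNormRow(const float *in, float *out, std::size_t len,
                  float new_scale, float new_offset) {
  const float32x4_t scale_f = vdupq_n_f32(new_scale);
  const float32x4_t offset_f = vdupq_n_f32(new_offset);
  std::size_t count = len >> 2;             // groups of 4 floats
  std::size_t remain = len - (count << 2);  // leftover elements
  for (std::size_t j = 0; j < count; ++j) {
    float32x4_t x = vld1q_f32(in);
    // vmlaq_f32(a, b, c) computes a + b * c, i.e. offset + x * scale.
    vst1q_f32(out, vmlaq_f32(offset_f, x, scale_f));
    in += 4;
    out += 4;
  }
  for (std::size_t j = 0; j < remain; ++j) {
    *out++ = new_scale * *in++ + new_offset;
  }
}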
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//

#include "mace/kernels/batch_norm.h"
#include "mace/core/runtime/opencl/cl2.hpp"
#include "mace/core/runtime/opencl/opencl_runtime.h"

namespace mace {
namespace kernels {

template <>
void BatchNormFunctor<DeviceType::OPENCL, float>::operator()(
    const Tensor *input,
    const Tensor *scale,
    const Tensor *offset,
    const Tensor *mean,
    const Tensor *var,
    const Tensor *epsilon,
    Tensor *output) {
  const index_t n = input->dim(0);
  const index_t channel = input->dim(1);
  const index_t sample_size = input->dim(2) * input->dim(3);

  auto runtime = OpenCLRuntime::Get();
  auto program = runtime->program();
  auto batch_norm_kernel =
      cl::KernelFunctor<cl::Buffer, cl::Buffer, cl::Buffer,
                        cl::Buffer, cl::Buffer, cl::Buffer,
                        int, int, cl::Buffer>(program, "batch_norm");
  cl_int error;
  auto res_event = batch_norm_kernel(cl::EnqueueArgs(runtime->command_queue(),
                                                     cl::NDRange(n * channel * sample_size),
                                                     cl::NDRange(128)),
                                     *(static_cast<const cl::Buffer *>(input->buffer())),
                                     *(static_cast<cl::Buffer *>(scale->buffer())),
                                     *(static_cast<cl::Buffer *>(offset->buffer())),
                                     *(static_cast<cl::Buffer *>(mean->buffer())),
                                     *(static_cast<cl::Buffer *>(var->buffer())),
                                     *(static_cast<cl::Buffer *>(epsilon->buffer())),
                                     static_cast<int>(channel),
                                     static_cast<int>(sample_size),
                                     *(static_cast<cl::Buffer *>(output->buffer())),
                                     error);
  res_event.wait();

  MACE_CHECK(error == CL_SUCCESS);
}

} // namespace kernels
} // namespace mace
\ No newline at end of file
void kernel batch_norm(global const float *input,
                       global const float *scale,
                       global const float *offset,
                       global const float *mean,
                       global const float *var,
                       global const float *epsilon,
                       private const int channels,
                       private const int pixels,
                       global float *output) {
  int idx = get_global_id(0);
  int channel = (idx % (channels * pixels)) / pixels;

  global const float *input_ptr = input + idx;
  const float new_scale = scale[channel] * rsqrt(var[channel] + *epsilon);
  const float new_offset = offset[channel] - mean[channel] * new_scale;
  global float *output_ptr = output + idx;
  *output_ptr = new_scale * *input_ptr + new_offset;
}
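One launch-configuration note on the host code above (an observation, not part of the commit): the kernel is enqueued with a global size of n * channel * sample_size and a fixed local size of 128, and OpenCL 1.x rejects a global size that is not a multiple of the local size. A common defensive pattern, sketched here with a hypothetical helper, is to round the global size up and guard the tail inside the kernel:

// Hypothetical helper, not in this commit: pad the global work size up to a
// multiple of the chosen work-group size. The kernel would then also take the
// real element count and return early for padded work items.
#include <cstdint>

inline uint64_t RoundUpGlobalSize(uint64_t total_elements, uint64_t local_size) {
  return (total_elements + local_size - 1) / local_size * local_size;
}

With padding like this, the kernel's first statement would be an early return whenever idx is at or beyond the real element count.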
@@ -12,4 +12,6 @@ REGISTER_CPU_OPERATOR(BatchNorm, BatchNormOp<DeviceType::CPU, float>);
 REGISTER_NEON_OPERATOR(BatchNorm, BatchNormOp<DeviceType::NEON, float>);
 #endif // __ARM_NEON
 
+REGISTER_OPENCL_OPERATOR(BatchNorm, BatchNormOp<DeviceType::OPENCL, float>);
+
 } // namespace mace
\ No newline at end of file
@@ -40,20 +40,7 @@ class BatchNormOp : public Operator<D, T> {
     Tensor *output = this->Output(0);
     output->ResizeLike(input);
 
-    const index_t n = input->dim(0);
-    const index_t channel = input->dim(1);
-    const index_t sample_size = input->dim(2) * input->dim(3);
-
-    const T *input_ptr = input->data<T>();
-    const T *scale_ptr = scale->data<T>();
-    const T *offset_ptr = offset->data<T>();
-    const T *mean_ptr = mean->data<T>();
-    const T *var_ptr = var->data<T>();
-    const T *epsilon_ptr = epsilon->data<T>();
-    T *output_ptr = output->mutable_data<T>();
-
-    functor_(input_ptr, scale_ptr, offset_ptr, mean_ptr, var_ptr, *epsilon_ptr,
-             n, channel, sample_size, output_ptr);
+    functor_(input, scale, offset, mean, var, epsilon, output);
 
     return true;
   }
...
@@ -24,12 +24,12 @@ static void BatchNorm(
       .Finalize(net.operator_def());
 
   // Add input data
-  net.AddRandomInput<DeviceType::CPU, T>("Input", {batch, channels, height, width});
-  net.AddRandomInput<DeviceType::CPU, T>("Scale", {channels});
-  net.AddRandomInput<DeviceType::CPU, T>("Offset", {channels});
-  net.AddRandomInput<DeviceType::CPU, T>("Mean", {channels});
-  net.AddRandomInput<DeviceType::CPU, T>("Var", {channels}, true);
-  net.AddInputFromArray<DeviceType::CPU, float>("Epsilon", {}, {1e-3});
+  net.AddRandomInput<D, T>("Input", {batch, channels, height, width});
+  net.AddRandomInput<D, T>("Scale", {channels});
+  net.AddRandomInput<D, T>("Offset", {channels});
+  net.AddRandomInput<D, T>("Mean", {channels});
+  net.AddRandomInput<D, T>("Var", {channels}, true);
+  net.AddInputFromArray<D, float>("Epsilon", {}, {1e-3});
 
   // Warm-up
   for (int i = 0; i < 5; ++i) {
@@ -54,7 +54,8 @@
 
 #define BM_BATCH_NORM(N, C, H, W, TYPE) \
   BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, CPU); \
-  BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, NEON);
+  BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, NEON); \
+  BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, OPENCL);
 
 BM_BATCH_NORM(1, 1, 512, 512, float);
 BM_BATCH_NORM(1, 3, 128, 128, float);
...
@@ -9,9 +9,10 @@ namespace mace {
 
 class BatchNormOpTest : public OpsTestBase {};
 
-TEST_F(BatchNormOpTest, SimpleCPU) {
+template <DeviceType D>
+void Simple() {
   // Construct graph
-  auto &net = test_net();
+  OpsTestNet net;
   OpDefBuilder("BatchNorm", "BatchNormTest")
       .Input("Input")
       .Input("Scale")
@@ -23,26 +24,79 @@ TEST_F(BatchNormOpTest, SimpleCPU) {
       .Finalize(net.operator_def());
 
   // Add input data
-  net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 1, 6, 2},
+  net.AddInputFromArray<D, float>("Input", {1, 1, 6, 2},
                                   {5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
-  net.AddInputFromArray<DeviceType::CPU, float>("Scale", {1}, {4.0f});
-  net.AddInputFromArray<DeviceType::CPU, float>("Offset", {1}, {2.0});
-  net.AddInputFromArray<DeviceType::CPU, float>("Mean", {1}, {10});
-  net.AddInputFromArray<DeviceType::CPU, float>("Var", {1}, {11.67f});
-  net.AddInputFromArray<DeviceType::CPU, float>("Epsilon", {}, {1e-3});
+  net.AddInputFromArray<D, float>("Scale", {1}, {4.0f});
+  net.AddInputFromArray<D, float>("Offset", {1}, {2.0});
+  net.AddInputFromArray<D, float>("Mean", {1}, {10});
+  net.AddInputFromArray<D, float>("Var", {1}, {11.67f});
+  net.AddInputFromArray<D, float>("Epsilon", {}, {1e-3});
 
   // Run
-  net.RunOp();
+  net.RunOp(D);
 
   // Check
   auto expected =
       CreateTensor<float>({1, 1, 6, 2}, {-3.86, -3.86, -1.51, -1.51, 0.83, 0.83,
                                          3.17, 3.17, 5.51, 5.51, 7.86, 7.86});
 
-  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.01);
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-2);
+}
+
+TEST_F(BatchNormOpTest, SimpleCPU) {
+  Simple<DeviceType::CPU>();
+}
+
+TEST_F(BatchNormOpTest, SimpleNEON) {
+  Simple<DeviceType::NEON>();
+}
+
+TEST_F(BatchNormOpTest, SimpleOPENCL) {
+  Simple<DeviceType::OPENCL>();
 }
 
-TEST_F(BatchNormOpTest, SimpleNeon) {
+TEST_F(BatchNormOpTest, SimpleRandomNeon) {
+  srand(time(NULL));
+
+  // generate random input
+  index_t batch = 1 + rand() % 10;
+  index_t channels = 3 + rand() % 50;
+  index_t height = 64;
+  index_t width = 64;
+
+  // Construct graph
+  auto &net = test_net();
+  OpDefBuilder("BatchNorm", "BatchNormTest")
+      .Input("Input")
+      .Input("Scale")
+      .Input("Offset")
+      .Input("Mean")
+      .Input("Var")
+      .Input("Epsilon")
+      .Output("Output")
+      .Finalize(net.operator_def());
+
+  // Add input data
+  net.AddRandomInput<DeviceType::CPU, float>("Input", {batch, channels, height, width});
+  net.AddRandomInput<DeviceType::CPU, float>("Scale", {channels});
+  net.AddRandomInput<DeviceType::CPU, float>("Offset", {channels});
+  net.AddRandomInput<DeviceType::CPU, float>("Mean", {channels});
+  net.AddRandomInput<DeviceType::CPU, float>("Var", {channels}, true);
+  net.AddInputFromArray<DeviceType::CPU, float>("Epsilon", {}, {1e-3});
+
+  // run cpu
+  net.RunOp();
+
+  // Check
+  Tensor expected;
+  expected.Copy(*net.GetOutput("Output"));
+
+  // Run NEON
+  net.RunOp(DeviceType::NEON);
+
+  ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 1e-2);
+}
+
+TEST_F(BatchNormOpTest, ComplexRandomNeon) {
   srand(time(NULL));
 
   // generate random input
@@ -74,11 +128,95 @@ TEST_F(BatchNormOpTest, SimpleNeon) {
   net.RunOp();
 
   // Check
-  Tensor *expected = net.GetOutput("Output");
+  Tensor expected;
+  expected.Copy(*net.GetOutput("Output"));
 
   // Run NEON
   net.RunOp(DeviceType::NEON);
 
-  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
+  ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 1e-2);
 }
 
+TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
+  srand(time(NULL));
+
+  // generate random input
+  index_t batch = 1 + rand() % 10;
+  index_t channels = 3 + rand() % 50;
+  index_t height = 64;
+  index_t width = 64;
+
+  // Construct graph
+  auto &net = test_net();
+  OpDefBuilder("BatchNorm", "BatchNormTest")
+      .Input("Input")
+      .Input("Scale")
+      .Input("Offset")
+      .Input("Mean")
+      .Input("Var")
+      .Input("Epsilon")
+      .Output("Output")
+      .Finalize(net.operator_def());
+
+  // Add input data
+  net.AddRandomInput<DeviceType::OPENCL, float>("Input", {batch, channels, height, width});
+  net.AddRandomInput<DeviceType::OPENCL, float>("Scale", {channels});
+  net.AddRandomInput<DeviceType::OPENCL, float>("Offset", {channels});
+  net.AddRandomInput<DeviceType::OPENCL, float>("Mean", {channels});
+  net.AddRandomInput<DeviceType::OPENCL, float>("Var", {channels}, true);
+  net.AddInputFromArray<DeviceType::OPENCL, float>("Epsilon", {}, {1e-3});
+
+  // Run OPENCL
+  net.RunOp(DeviceType::OPENCL);
+
+  // Check
+  Tensor expected;
+  expected.Copy(*net.GetOutput("Output"));
+
+  // run cpu
+  net.RunOp();
+
+  ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 1e-2);
+}
+
+TEST_F(BatchNormOpTest, ComplexRandomOPENCL) {
+  srand(time(NULL));
+
+  // generate random input
+  index_t batch = 1 + rand() % 10;
+  index_t channels = 3 + rand() % 50;
+  index_t height = 103;
+  index_t width = 113;
+
+  // Construct graph
+  auto &net = test_net();
+  OpDefBuilder("BatchNorm", "BatchNormTest")
+      .Input("Input")
+      .Input("Scale")
+      .Input("Offset")
+      .Input("Mean")
+      .Input("Var")
+      .Input("Epsilon")
+      .Output("Output")
+      .Finalize(net.operator_def());
+
+  // Add input data
+  net.AddRandomInput<DeviceType::OPENCL, float>("Input", {batch, channels, height, width});
+  net.AddRandomInput<DeviceType::OPENCL, float>("Scale", {channels});
+  net.AddRandomInput<DeviceType::OPENCL, float>("Offset", {channels});
+  net.AddRandomInput<DeviceType::OPENCL, float>("Mean", {channels});
+  net.AddRandomInput<DeviceType::OPENCL, float>("Var", {channels}, true);
+  net.AddInputFromArray<DeviceType::OPENCL, float>("Epsilon", {}, {1e-3});
+
+  // Run OPENCL
+  net.RunOp(DeviceType::OPENCL);
+
+  // Check
+  Tensor expected;
+  expected.Copy(*net.GetOutput("Output"));
+
+  // run cpu
+  net.RunOp();
+
+  ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 1e-2);
+}
 
 }