From bcec92d0c164639eadf3595b0ed21fc1b403a085 Mon Sep 17 00:00:00 2001 From: liuqi Date: Wed, 25 Oct 2017 13:53:17 +0800 Subject: [PATCH] Add opencl batch norm kernel and fix bugs. --- mace/kernels/batch_norm.h | 73 ++++++---- mace/kernels/neon/batch_norm_neon.cc | 44 +++--- mace/kernels/opencl/batch_norm_opencl.cc | 50 +++++++ mace/kernels/opencl/cl/batch_norm.cl | 19 +++ mace/ops/batch_norm.cc | 2 + mace/ops/batch_norm.h | 15 +-- mace/ops/batch_norm_benchmark.cc | 15 ++- mace/ops/batch_norm_test.cc | 164 +++++++++++++++++++++-- 8 files changed, 305 insertions(+), 77 deletions(-) create mode 100644 mace/kernels/opencl/batch_norm_opencl.cc create mode 100644 mace/kernels/opencl/cl/batch_norm.cl diff --git a/mace/kernels/batch_norm.h b/mace/kernels/batch_norm.h index 5c838be4..cd3fb4b9 100644 --- a/mace/kernels/batch_norm.h +++ b/mace/kernels/batch_norm.h @@ -13,16 +13,13 @@ namespace kernels { template struct BatchNormFunctor { - void operator()(const T *input, - const T *scale, - const T *offset, - const T *mean, - const T *var, - const float variance_epsilon, - const index_t n, - const index_t channel, - const index_t sample_size, - T *output) { + void operator()(const Tensor *input, + const Tensor *scale, + const Tensor *offset, + const Tensor *mean, + const Tensor *var, + const Tensor *epsilon, + Tensor *output) { // Batch normalization in the paper https://arxiv.org/abs/1502.03167 . // The calculation formula for inference is // Y = \frac{ \scale } { \sqrt{var+\variance_epsilon} } * X + @@ -31,16 +28,35 @@ struct BatchNormFunctor { // new_scale = \frac{ \scale } { \sqrt{var+\variance_epsilon} } // new_offset = \offset - mean * common_val; // Y = new_scale * X + new_offset; - T new_scale, new_offset; + const index_t n = input->dim(0); + const index_t channel = input->dim(1); + const index_t sample_size = input->dim(2) * input->dim(3); + + Tensor::MappingGuard input_mapper(input); + Tensor::MappingGuard scale_mapper(scale); + Tensor::MappingGuard offset_mapper(offset); + Tensor::MappingGuard mean_mapper(mean); + Tensor::MappingGuard var_mapper(var); + Tensor::MappingGuard epsilon_mapper(epsilon); + Tensor::MappingGuard output_mapper(output); + + const T *input_ptr = input->data(); + const T *scale_ptr = scale->data(); + const T *offset_ptr = offset->data(); + const T *mean_ptr = mean->data(); + const T *var_ptr = var->data(); + const T *epsilon_ptr = epsilon->data(); + T *output_ptr = output->mutable_data(); + #pragma omp parallel for for (index_t c = 0; c < channel; ++c) { - new_scale = scale[c] / std::sqrt(var[c] + variance_epsilon); - new_offset = offset[c] - mean[c] * new_scale; + T new_scale = scale_ptr[c] / std::sqrt(var_ptr[c] + *epsilon_ptr); + T new_offset = offset_ptr[c] - mean_ptr[c] * new_scale; index_t pos = c * sample_size; for (index_t i = 0; i < n; ++i) { - const T *input_sample_ptr = input + pos; - T *output_sample_ptr = output + pos; + const T *input_sample_ptr = input_ptr + pos; + T *output_sample_ptr = output_ptr + pos; for (index_t j = 0; j < sample_size; ++j) { output_sample_ptr[j] = new_scale * input_sample_ptr[j] + new_offset; } @@ -52,16 +68,23 @@ struct BatchNormFunctor { template <> void BatchNormFunctor::operator()( - const float *input, - const float *scale, - const float *offset, - const float *mean, - const float *var, - const float variance_epsilon, - const index_t n, - const index_t channel, - const index_t sample_size, - float *output); + const Tensor *input, + const Tensor *scale, + const Tensor *offset, + const Tensor *mean, + const Tensor *var, + const Tensor *epsilon, + Tensor *output); + +template <> +void BatchNormFunctor::operator()( + const Tensor *input, + const Tensor *scale, + const Tensor *offset, + const Tensor *mean, + const Tensor *var, + const Tensor *epsilon, + Tensor *output); } // namepsace kernels } // namespace mace diff --git a/mace/kernels/neon/batch_norm_neon.cc b/mace/kernels/neon/batch_norm_neon.cc index cd5fff22..295cc59d 100644 --- a/mace/kernels/neon/batch_norm_neon.cc +++ b/mace/kernels/neon/batch_norm_neon.cc @@ -10,38 +10,46 @@ namespace kernels { template <> void BatchNormFunctor::operator()( - const float *input, - const float *scale, - const float *offset, - const float *mean, - const float *var, - const float variance_epsilon, - const index_t n, - const index_t channel, - const index_t sample_size, - float *output) { + const Tensor *input, + const Tensor *scale, + const Tensor *offset, + const Tensor *mean, + const Tensor *var, + const Tensor *epsilon, + Tensor *output) { // Batch normalization in the paper https://arxiv.org/abs/1502.03167 . // The calculation formula for inference is - // Y = \frac{ \scale } { \sqrt{var+\variance_epsilon} } * X + - // ( \offset - \frac { \scale * mean } { \sqrt{var+\variance_epsilon} + // Y = \frac{ \scale } { \sqrt{var+\epsilon} } * X + + // ( \offset - \frac { \scale * mean } { \sqrt{var+\epsilon} // } - // new_scale = \frac{ \scale } { \sqrt{var+\variance_epsilon} } + // new_scale = \frac{ \scale } { \sqrt{var+\epsilon} } // new_offset = \offset - mean * common_val; // Y = new_scale * X + new_offset; - float new_scale, new_offset; + const index_t n = input->dim(0); + const index_t channel = input->dim(1); + const index_t sample_size = input->dim(2) * input->dim(3); + + const float *input_ptr = input->data(); + const float *scale_ptr = scale->data(); + const float *offset_ptr = offset->data(); + const float *mean_ptr = mean->data(); + const float *var_ptr = var->data(); + const float *epsilon_ptr = epsilon->data(); + float *output_ptr = output->mutable_data(); + index_t count = sample_size >> 2; index_t remain_count = sample_size - (count << 2); #pragma omp parallel for for (index_t c = 0; c < channel; ++c) { - new_scale = scale[c] / std::sqrt(var[c] + variance_epsilon); - new_offset = offset[c] - mean[c] * new_scale; + float new_scale = scale_ptr[c] / std::sqrt(var_ptr[c] + *epsilon_ptr); + float new_offset = offset_ptr[c] - mean_ptr[c] * new_scale; index_t pos = c * sample_size; float32x4_t new_scale_f = vdupq_n_f32(new_scale); float32x4_t new_offset_f = vdupq_n_f32(new_offset); for (index_t i = 0; i < n; ++i) { - const float *input_sample_ptr = input + pos; - float *output_sample_ptr = output + pos; + const float *input_sample_ptr = input_ptr + pos; + float *output_sample_ptr = output_ptr + pos; for (index_t j = 0; j < count; ++j) { float32x4_t input_f = vld1q_f32(input_sample_ptr); diff --git a/mace/kernels/opencl/batch_norm_opencl.cc b/mace/kernels/opencl/batch_norm_opencl.cc new file mode 100644 index 00000000..8999750f --- /dev/null +++ b/mace/kernels/opencl/batch_norm_opencl.cc @@ -0,0 +1,50 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/kernels/batch_norm.h" +#include "mace/core/runtime/opencl/cl2.hpp" +#include "mace/core/runtime/opencl/opencl_runtime.h" + +namespace mace { +namespace kernels { + +template <> +void BatchNormFunctor::operator()( + const Tensor *input, + const Tensor *scale, + const Tensor *offset, + const Tensor *mean, + const Tensor *var, + const Tensor *epsilon, + Tensor *output) { + const index_t n = input->dim(0); + const index_t channel = input->dim(1); + const index_t sample_size = input->dim(2) * input->dim(3); + + auto runtime = OpenCLRuntime::Get(); + auto program = runtime->program(); + auto batch_norm_kernel = + cl::KernelFunctor(program, "batch_norm"); + cl_int error; + auto res_event = batch_norm_kernel(cl::EnqueueArgs(runtime->command_queue(), + cl::NDRange(n * channel * sample_size), + cl::NDRange(128)), + *(static_cast(input->buffer())), + *(static_cast(scale->buffer())), + *(static_cast(offset->buffer())), + *(static_cast(mean->buffer())), + *(static_cast(var->buffer())), + *(static_cast(epsilon->buffer())), + static_cast(channel), + static_cast(sample_size), + *(static_cast(output->buffer())), + error); + res_event.wait(); + MACE_CHECK(error == CL_SUCCESS); +} + +} // namespace kernels +} // namespace mace \ No newline at end of file diff --git a/mace/kernels/opencl/cl/batch_norm.cl b/mace/kernels/opencl/cl/batch_norm.cl new file mode 100644 index 00000000..f0d1b77e --- /dev/null +++ b/mace/kernels/opencl/cl/batch_norm.cl @@ -0,0 +1,19 @@ +void kernel batch_norm(global const float *input, + global const float *scale, + global const float *offset, + global const float *mean, + global const float *var, + global const float *epsilon, + private const int channels, + private const int pixels, + global float *output) { + int idx = get_global_id(0); + int channel = (idx % (channels * pixels)) / pixels; + + const float *input_ptr = input + idx; + const float new_scale = scale[channel] * rsqrt(var[channel] + *epsilon); + const float new_offset = offset[channel] - mean[channel] * new_scale; + float *output_ptr = output + idx; + *output_ptr = new_scale * *input_ptr + new_offset; +} + diff --git a/mace/ops/batch_norm.cc b/mace/ops/batch_norm.cc index f5b050f1..1ce9b1e0 100644 --- a/mace/ops/batch_norm.cc +++ b/mace/ops/batch_norm.cc @@ -12,4 +12,6 @@ REGISTER_CPU_OPERATOR(BatchNorm, BatchNormOp); REGISTER_NEON_OPERATOR(BatchNorm, BatchNormOp); #endif // __ARM_NEON +REGISTER_OPENCL_OPERATOR(BatchNorm, BatchNormOp); + } // namespace mace \ No newline at end of file diff --git a/mace/ops/batch_norm.h b/mace/ops/batch_norm.h index a7292601..1452bc72 100644 --- a/mace/ops/batch_norm.h +++ b/mace/ops/batch_norm.h @@ -40,20 +40,7 @@ class BatchNormOp : public Operator { Tensor *output = this->Output(0); output->ResizeLike(input); - const index_t n = input->dim(0); - const index_t channel = input->dim(1); - const index_t sample_size = input->dim(2) * input->dim(3); - - const T *input_ptr = input->data(); - const T *scale_ptr = scale->data(); - const T *offset_ptr = offset->data(); - const T *mean_ptr = mean->data(); - const T *var_ptr = var->data(); - const T *epsilon_ptr = epsilon->data(); - T *output_ptr = output->mutable_data(); - - functor_(input_ptr, scale_ptr, offset_ptr, mean_ptr, var_ptr, *epsilon_ptr, - n, channel, sample_size, output_ptr); + functor_(input, scale, offset, mean, var, epsilon, output); return true; } diff --git a/mace/ops/batch_norm_benchmark.cc b/mace/ops/batch_norm_benchmark.cc index 8fc24797..2276eaeb 100644 --- a/mace/ops/batch_norm_benchmark.cc +++ b/mace/ops/batch_norm_benchmark.cc @@ -24,12 +24,12 @@ static void BatchNorm( .Finalize(net.operator_def()); // Add input data - net.AddRandomInput("Input", {batch, channels, height, width}); - net.AddRandomInput("Scale", {channels}); - net.AddRandomInput("Offset", {channels}); - net.AddRandomInput("Mean", {channels}); - net.AddRandomInput("Var", {channels}, true); - net.AddInputFromArray("Epsilon", {}, {1e-3}); + net.AddRandomInput("Input", {batch, channels, height, width}); + net.AddRandomInput("Scale", {channels}); + net.AddRandomInput("Offset", {channels}); + net.AddRandomInput("Mean", {channels}); + net.AddRandomInput("Var", {channels}, true); + net.AddInputFromArray("Epsilon", {}, {1e-3}); // Warm-up for (int i = 0; i < 5; ++i) { @@ -54,7 +54,8 @@ static void BatchNorm( #define BM_BATCH_NORM(N, C, H, W, TYPE) \ BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, CPU); \ - BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, NEON); + BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, NEON); \ + BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, OPENCL); BM_BATCH_NORM(1, 1, 512, 512, float); BM_BATCH_NORM(1, 3, 128, 128, float); diff --git a/mace/ops/batch_norm_test.cc b/mace/ops/batch_norm_test.cc index 99338778..ceb49639 100644 --- a/mace/ops/batch_norm_test.cc +++ b/mace/ops/batch_norm_test.cc @@ -9,9 +9,10 @@ namespace mace { class BatchNormOpTest : public OpsTestBase {}; -TEST_F(BatchNormOpTest, SimpleCPU) { +template +void Simple() { // Construct graph - auto &net = test_net(); + OpsTestNet net; OpDefBuilder("BatchNorm", "BatchNormTest") .Input("Input") .Input("Scale") @@ -23,26 +24,79 @@ TEST_F(BatchNormOpTest, SimpleCPU) { .Finalize(net.operator_def()); // Add input data - net.AddInputFromArray("Input", {1, 1, 6, 2}, + net.AddInputFromArray("Input", {1, 1, 6, 2}, {5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15}); - net.AddInputFromArray("Scale", {1}, {4.0f}); - net.AddInputFromArray("Offset", {1}, {2.0}); - net.AddInputFromArray("Mean", {1}, {10}); - net.AddInputFromArray("Var", {1}, {11.67f}); - net.AddInputFromArray("Epsilon", {}, {1e-3}); + net.AddInputFromArray("Scale", {1}, {4.0f}); + net.AddInputFromArray("Offset", {1}, {2.0}); + net.AddInputFromArray("Mean", {1}, {10}); + net.AddInputFromArray("Var", {1}, {11.67f}); + net.AddInputFromArray("Epsilon", {}, {1e-3}); // Run - net.RunOp(); + net.RunOp(D); // Check auto expected = CreateTensor({1, 1, 6, 2}, {-3.86, -3.86, -1.51, -1.51, 0.83, 0.83, 3.17, 3.17, 5.51, 5.51, 7.86, 7.86}); - ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.01); + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-2); +} + +TEST_F(BatchNormOpTest, SimpleCPU) { + Simple(); +} + +TEST_F(BatchNormOpTest, SimpleNEON) { + Simple(); +} + +TEST_F(BatchNormOpTest, SimpleOPENCL) { + Simple(); } -TEST_F(BatchNormOpTest, SimpleNeon) { +TEST_F(BatchNormOpTest, SimpleRandomNeon) { + srand(time(NULL)); + + // generate random input + index_t batch = 1 + rand() % 10; + index_t channels = 3 + rand() % 50; + index_t height = 64; + index_t width = 64; + // Construct graph + auto &net = test_net(); + OpDefBuilder("BatchNorm", "BatchNormTest") + .Input("Input") + .Input("Scale") + .Input("Offset") + .Input("Mean") + .Input("Var") + .Input("Epsilon") + .Output("Output") + .Finalize(net.operator_def()); + + // Add input data + net.AddRandomInput("Input", {batch, channels, height, width}); + net.AddRandomInput("Scale", {channels}); + net.AddRandomInput("Offset", {channels}); + net.AddRandomInput("Mean", {channels}); + net.AddRandomInput("Var", {channels}, true); + net.AddInputFromArray("Epsilon", {}, {1e-3}); + + // run cpu + net.RunOp(); + + // Check + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + // Run NEON + net.RunOp(DeviceType::NEON); + + ExpectTensorNear(expected, *net.GetOutput("Output"), 1e-2); +} + +TEST_F(BatchNormOpTest, ComplexRandomNeon) { srand(time(NULL)); // generate random input @@ -74,11 +128,95 @@ TEST_F(BatchNormOpTest, SimpleNeon) { net.RunOp(); // Check - Tensor *expected = net.GetOutput("Output"); + Tensor expected; + expected.Copy(*net.GetOutput("Output")); // Run NEON net.RunOp(DeviceType::NEON); - ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); + ExpectTensorNear(expected, *net.GetOutput("Output"), 1e-2); } + +TEST_F(BatchNormOpTest, SimpleRandomOPENCL) { + srand(time(NULL)); + + // generate random input + index_t batch = 1 + rand() % 10; + index_t channels = 3 + rand() % 50; + index_t height = 64; + index_t width = 64; + // Construct graph + auto &net = test_net(); + OpDefBuilder("BatchNorm", "BatchNormTest") + .Input("Input") + .Input("Scale") + .Input("Offset") + .Input("Mean") + .Input("Var") + .Input("Epsilon") + .Output("Output") + .Finalize(net.operator_def()); + + // Add input data + net.AddRandomInput("Input", {batch, channels, height, width}); + net.AddRandomInput("Scale", {channels}); + net.AddRandomInput("Offset", {channels}); + net.AddRandomInput("Mean", {channels}); + net.AddRandomInput("Var", {channels}, true); + net.AddInputFromArray("Epsilon", {}, {1e-3}); + + // Run NEON + net.RunOp(DeviceType::OPENCL); + + // Check + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + // run cpu + net.RunOp(); + + ExpectTensorNear(expected, *net.GetOutput("Output"), 1e-2); +} + +TEST_F(BatchNormOpTest, ComplexRandomOPENCL) { + srand(time(NULL)); + + // generate random input + index_t batch = 1 + rand() % 10; + index_t channels = 3 + rand() % 50; + index_t height = 103; + index_t width = 113; + // Construct graph + auto &net = test_net(); + OpDefBuilder("BatchNorm", "BatchNormTest") + .Input("Input") + .Input("Scale") + .Input("Offset") + .Input("Mean") + .Input("Var") + .Input("Epsilon") + .Output("Output") + .Finalize(net.operator_def()); + + // Add input data + net.AddRandomInput("Input", {batch, channels, height, width}); + net.AddRandomInput("Scale", {channels}); + net.AddRandomInput("Offset", {channels}); + net.AddRandomInput("Mean", {channels}); + net.AddRandomInput("Var", {channels}, true); + net.AddInputFromArray("Epsilon", {}, {1e-3}); + + // Run NEON + net.RunOp(DeviceType::OPENCL); + + // Check + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + // run cpu + net.RunOp(); + + ExpectTensorNear(expected, *net.GetOutput("Output"), 1e-2); +} + } -- GitLab