From 73af2b224c19255768c89f75af52e80226c8f825 Mon Sep 17 00:00:00 2001 From: liuqi Date: Wed, 8 Nov 2017 15:30:14 +0800 Subject: [PATCH] Finish relu opencl kernel. --- mace/kernels/neon/relu_neon.cc | 8 +- mace/kernels/opencl/cl/relu.cl | 32 +++++++ mace/kernels/opencl/relu_opencl.cc | 59 ++++++++++++ mace/kernels/relu.h | 17 ++-- mace/ops/relu.cc | 1 + mace/ops/relu.h | 5 +- mace/ops/relu_benchmark.cc | 7 +- mace/ops/relu_test.cc | 139 ++++++++++++++++++++++++----- 8 files changed, 231 insertions(+), 37 deletions(-) create mode 100644 mace/kernels/opencl/cl/relu.cl create mode 100644 mace/kernels/opencl/relu_opencl.cc diff --git a/mace/kernels/neon/relu_neon.cc b/mace/kernels/neon/relu_neon.cc index 7898e9b0..e2d983dd 100644 --- a/mace/kernels/neon/relu_neon.cc +++ b/mace/kernels/neon/relu_neon.cc @@ -9,9 +9,11 @@ namespace mace { namespace kernels { template <> -void ReluFunctor::operator()(const float *input, - float *output, - index_t size) { +void ReluFunctor::operator()(const Tensor *input_tensor, + Tensor *output_tensor) { + const float *input = input_tensor->data(); + float *output = output_tensor->mutable_data(); + index_t size = input_tensor->size(); if (max_limit_ < 0) { #pragma omp parallel for for (int64_t i = 0; i < size; i += kCostPerGroup) { diff --git a/mace/kernels/opencl/cl/relu.cl b/mace/kernels/opencl/cl/relu.cl new file mode 100644 index 00000000..390c8454 --- /dev/null +++ b/mace/kernels/opencl/cl/relu.cl @@ -0,0 +1,32 @@ +__kernel void relu(__global const float *input, + __private const int size, + __global float *output) { + int idx = get_global_id(0); + + if (idx + 4 > size) { + for(; idx < size; ++idx) { + *(output+idx) = fmax(*(input+idx), 0); + } + } else { + float4 data = vload4(idx, input); + data = fmax(data, (float4)0); + vstore4(data, idx, output); + } +} + +__kernel void relux(__global const float *input, + __private const float max_limit, + __private const int size, + __global float *output) { + int idx = get_global_id(0); + + if (idx + 4 > size) { + for(; idx < size; ++idx) { + *(output+idx) = clamp(*(input+idx), 0.0f, max_limit); + } + } else { + float4 data = vload4(idx, input); + data = clamp(data, (float4)0, (float4)max_limit); + vstore4(data, idx, output); + } +} diff --git a/mace/kernels/opencl/relu_opencl.cc b/mace/kernels/opencl/relu_opencl.cc new file mode 100644 index 00000000..c4d22ae8 --- /dev/null +++ b/mace/kernels/opencl/relu_opencl.cc @@ -0,0 +1,59 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/kernels/relu.h" +#include "mace/core/runtime/opencl/cl2_header.h" +#include "mace/core/runtime/opencl/opencl_runtime.h" + +namespace mace { +namespace kernels { + +template <> +void ReluFunctor::operator()(const Tensor *input, + Tensor *output) { + + index_t element_size = input->NumElements(); + index_t blocks = (element_size + 3) / 4; + + const uint32_t gws = blocks; + + auto runtime = OpenCLRuntime::Get(); + auto program = runtime->program(); + + if (max_limit_ < 0) { + auto relu_kernel = cl::Kernel(program, "relu"); + + const uint32_t lws = runtime->GetKernelMaxWorkGroupSize(relu_kernel); + + uint32_t idx = 0; + relu_kernel.setArg(idx++, *(static_cast(input->buffer()))); + relu_kernel.setArg(idx++, static_cast(element_size)); + relu_kernel.setArg(idx++, *(static_cast(output->buffer()))); + + cl_int error = runtime->command_queue().enqueueNDRangeKernel( + relu_kernel, cl::NullRange, + cl::NDRange(gws), + cl::NDRange(lws)); + MACE_CHECK(error == CL_SUCCESS); + } else { + auto relu_kernel = cl::Kernel(program, "relux"); + + const uint32_t lws = runtime->GetKernelMaxWorkGroupSize(relu_kernel); + + uint32_t idx = 0; + relu_kernel.setArg(idx++, *(static_cast(input->buffer()))); + relu_kernel.setArg(idx++, max_limit_); + relu_kernel.setArg(idx++, static_cast(element_size)); + relu_kernel.setArg(idx++, *(static_cast(output->buffer()))); + + cl_int error = runtime->command_queue().enqueueNDRangeKernel( + relu_kernel, cl::NullRange, + cl::NDRange(gws), + cl::NDRange(lws)); + MACE_CHECK(error == CL_SUCCESS); + } +} + +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/relu.h b/mace/kernels/relu.h index 71cd07ab..347d586b 100644 --- a/mace/kernels/relu.h +++ b/mace/kernels/relu.h @@ -14,23 +14,28 @@ template struct ReluFunctor { T max_limit_; - void operator()(const T *input, T *output, index_t size) { + void operator()(const Tensor *input, Tensor *output) { + const T *input_ptr = input->data(); + T *output_ptr = output->mutable_data(); + index_t size = input->size(); if (max_limit_ < 0) { for (index_t i = 0; i < size; ++i) { - output[i] = std::max(input[i], static_cast(0)); + output_ptr[i] = std::max(input_ptr[i], static_cast(0)); } } else { for (index_t i = 0; i < size; ++i) { - output[i] = std::min(std::max(input[i], static_cast(0)), max_limit_); + output_ptr[i] = std::min(std::max(input_ptr[i], static_cast(0)), max_limit_); } } } }; template <> -void ReluFunctor::operator()(const float *input, - float *output, - index_t size); +void ReluFunctor::operator()(const Tensor *input, + Tensor *output); +template <> +void ReluFunctor::operator()(const Tensor *input, + Tensor *output); } // namespace kernels } // namespace mace diff --git a/mace/ops/relu.cc b/mace/ops/relu.cc index 8602f932..c86fb38f 100644 --- a/mace/ops/relu.cc +++ b/mace/ops/relu.cc @@ -12,4 +12,5 @@ REGISTER_CPU_OPERATOR(Relu, ReluOp); REGISTER_NEON_OPERATOR(Relu, ReluOp); #endif // __ARM_NEON +REGISTER_OPENCL_OPERATOR(Relu, ReluOp); } // namespace mace diff --git a/mace/ops/relu.h b/mace/ops/relu.h index 654130fa..fea49c8d 100644 --- a/mace/ops/relu.h +++ b/mace/ops/relu.h @@ -22,11 +22,8 @@ class ReluOp : public Operator { const Tensor *input_tensor = this->inputs_[0]; Tensor *output_tensor = this->outputs_[0]; output_tensor->ResizeLike(input_tensor); - const T *input = input_tensor->data(); - T *output = output_tensor->mutable_data(); - index_t size = input_tensor->size(); - functor_(input, output, size); + functor_(input_tensor, output_tensor); return true; } diff --git a/mace/ops/relu_benchmark.cc b/mace/ops/relu_benchmark.cc index a1fc6ed4..14badcd9 100644 --- a/mace/ops/relu_benchmark.cc +++ b/mace/ops/relu_benchmark.cc @@ -19,17 +19,19 @@ static void ReluBenchmark(int iters, int size) { .Finalize(net.NewOperatorDef()); // Add input data - net.AddRandomInput("Input", {size}); + net.AddRandomInput("Input", {size}); // Warm-up for (int i = 0; i < 5; ++i) { net.RunOp(D); } + net.Sync(); mace::testing::StartTiming(); while (iters--) { net.RunOp(D); } + net.Sync(); } #define BM_RELU_MACRO(SIZE, TYPE, DEVICE) \ @@ -43,7 +45,8 @@ static void ReluBenchmark(int iters, int size) { #define BM_RELU(SIZE, TYPE) \ BM_RELU_MACRO(SIZE, TYPE, CPU); \ - BM_RELU_MACRO(SIZE, TYPE, NEON); + BM_RELU_MACRO(SIZE, TYPE, NEON);\ + BM_RELU_MACRO(SIZE, TYPE, OPENCL); BM_RELU(1000, float); BM_RELU(100000, float); diff --git a/mace/ops/relu_test.cc b/mace/ops/relu_test.cc index 5a6eb7ca..56aace07 100644 --- a/mace/ops/relu_test.cc +++ b/mace/ops/relu_test.cc @@ -9,51 +9,146 @@ namespace mace { class ReluOpTest : public OpsTestBase {}; -TEST_F(ReluOpTest, ReluOp) { - // Construct graph - auto &net = test_net(); +template +void TestSimple() { + OpsTestNet net; OpDefBuilder("Relu", "ReluTest") .Input("Input") .Output("Output") .Finalize(net.NewOperatorDef()); // Add input data - net.AddRandomInput("Input", {1, 2, 3, 5}); + net.AddInputFromArray("Input", + {2, 2, 2, 2}, + {-7, 7, -6, 6, -5, 5, -4, 4, + -3, 3, -2, 2, -1, 1, 0, 0}); // Run - net.RunOp(); + net.RunOp(D); - Tensor expected; - expected.Copy(*net.GetOutput("Output")); + auto expected = CreateTensor({2, 2, 2, 2}, + {0, 7, 0, 6, 0, 5, 0, 4, + 0, 3, 0, 2, 0, 1, 0, 0}); - // Check - net.RunOp(DeviceType::NEON); + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); +} + +TEST_F(ReluOpTest, CPUSimple) { + TestSimple(); +} + +TEST_F(ReluOpTest, NEONSimple) { + TestSimple(); +} + +TEST_F(ReluOpTest, OPENCLSimple) { + TestSimple(); +} + +template +void TestUnalignedSimple() { + OpsTestNet net; + OpDefBuilder("Relu", "ReluTest") + .Input("Input") + .Output("Output") + .Finalize(net.NewOperatorDef()); - ExpectTensorNear(expected, *net.GetOutput("Output"), 0.01); + // Add input data + net.AddInputFromArray("Input", + {1, 1, 3, 2}, + {-7, 7, -6, 6, -5, 5}); + + // Run + net.RunOp(D); + + auto expected = CreateTensor({1, 1, 3, 2}, + {0, 7, 0, 6, 0, 5}); + + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); +} + +TEST_F(ReluOpTest, CPUUnalignedSimple) { + TestUnalignedSimple(); +} + +TEST_F(ReluOpTest, NEONUnalignedSimple) { + TestUnalignedSimple(); } -TEST_F(ReluOpTest, ReluOpWithMax) { - // Construct graph - auto &net = test_net(); - OpDefBuilder("Relu", "ReluTestWithMax") +TEST_F(ReluOpTest, OPENCLUnalignedSimple) { + TestUnalignedSimple(); +} + +template +void TestSimpleReluX() { + OpsTestNet net; + OpDefBuilder("Relu", "ReluTest") .Input("Input") .Output("Output") - .AddFloatArg("max_limit", 0.5) + .AddFloatArg("max_limit", 6) .Finalize(net.NewOperatorDef()); // Add input data - net.AddRandomInput("Input", {1, 2, 3, 5}); + net.AddInputFromArray("Input", + {2, 2, 2, 2}, + {-7, 7, -6, 6, -5, 5, -4, 4, + -3, 3, -2, 2, -1, 1, 0, 0}); // Run - net.RunOp(); + net.RunOp(D); - Tensor expected; - expected.Copy(*net.GetOutput("Output")); + auto expected = CreateTensor({2, 2, 2, 2}, + {0, 6, 0, 6, 0, 5, 0, 4, + 0, 3, 0, 2, 0, 1, 0, 0}); - // Check - net.RunOp(DeviceType::NEON); + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); +} + +TEST_F(ReluOpTest, CPUSimpleReluX) { + TestSimpleReluX(); +} + +TEST_F(ReluOpTest, NEONSimpleReluX) { + TestSimpleReluX(); +} + +TEST_F(ReluOpTest, OPENCLSimpleReluX) { + TestSimpleReluX(); +} + +template +void TestUnalignedSimpleReluX() { + OpsTestNet net; + OpDefBuilder("Relu", "ReluTest") + .Input("Input") + .Output("Output") + .AddFloatArg("max_limit", 6) + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddInputFromArray("Input", + {1, 1, 1, 7}, + {-7, 7, -6, 6, -5, 5, -4}); + + // Run + net.RunOp(D); + + auto expected = CreateTensor({1, 1, 1, 7}, + {0, 6, 0, 6, 0, 5, 0}); + + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); +} + +TEST_F(ReluOpTest, CPUUnalignedSimpleReluX) { + TestUnalignedSimpleReluX(); +} + +TEST_F(ReluOpTest, NEONUnalignedSimpleReluX) { + TestUnalignedSimpleReluX(); +} - ExpectTensorNear(expected, *net.GetOutput("Output"), 0.01); +TEST_F(ReluOpTest, OPENCLUnalignedSimpleReluX) { + TestUnalignedSimpleReluX(); } } // namespace mace -- GitLab