提交 73af2b22 编写于 作者: L liuqi

Finish relu opencl kernel.

上级 1998fd46
...@@ -9,9 +9,11 @@ namespace mace { ...@@ -9,9 +9,11 @@ namespace mace {
namespace kernels { namespace kernels {
template <> template <>
void ReluFunctor<DeviceType::NEON, float>::operator()(const float *input, void ReluFunctor<DeviceType::NEON, float>::operator()(const Tensor *input_tensor,
float *output, Tensor *output_tensor) {
index_t size) { const float *input = input_tensor->data<float>();
float *output = output_tensor->mutable_data<float>();
index_t size = input_tensor->size();
if (max_limit_ < 0) { if (max_limit_ < 0) {
#pragma omp parallel for #pragma omp parallel for
for (int64_t i = 0; i < size; i += kCostPerGroup) { for (int64_t i = 0; i < size; i += kCostPerGroup) {
......
// Element-wise ReLU: output[i] = max(input[i], 0) for every i in [0, size).
//
// Each work-item processes one group of four consecutive floats; the
// work-item's global id is the GROUP index, so its first element lives at
// element offset `group * 4` (this matches vload4/vstore4, which address
// `p + offset * 4`).
//
//   input  - source buffer holding at least `size` floats
//   size   - number of valid floats in `input` / `output`
//   output - destination buffer holding at least `size` floats
__kernel void relu(__global const float *input,
                   __private const int size,
                   __global float *output) {
  const int group = get_global_id(0);
  // Convert the group index to an element offset before bounds checking;
  // comparing the raw group index against `size` lets the last, partially
  // filled group take the vector path and touch memory past the buffer.
  const int offset = group * 4;
  if (offset + 4 > size) {
    // Partial (or empty) trailing group: handle remaining elements one by one.
    for (int i = offset; i < size; ++i) {
      output[i] = fmax(input[i], 0.0f);
    }
  } else {
    float4 data = vload4(group, input);
    data = fmax(data, (float4)0);
    vstore4(data, group, output);
  }
}
// Element-wise bounded ReLU: output[i] = clamp(input[i], 0, max_limit) for
// every i in [0, size).
//
// Each work-item processes one group of four consecutive floats; the
// work-item's global id is the GROUP index, so its first element lives at
// element offset `group * 4` (this matches vload4/vstore4, which address
// `p + offset * 4`).
//
//   input     - source buffer holding at least `size` floats
//   max_limit - upper bound applied to every output value
//   size      - number of valid floats in `input` / `output`
//   output    - destination buffer holding at least `size` floats
__kernel void relux(__global const float *input,
                    __private const float max_limit,
                    __private const int size,
                    __global float *output) {
  const int group = get_global_id(0);
  // Convert the group index to an element offset before bounds checking;
  // comparing the raw group index against `size` lets the last, partially
  // filled group take the vector path and touch memory past the buffer.
  const int offset = group * 4;
  if (offset + 4 > size) {
    // Partial (or empty) trailing group: handle remaining elements one by one.
    for (int i = offset; i < size; ++i) {
      output[i] = clamp(input[i], 0.0f, max_limit);
    }
  } else {
    float4 data = vload4(group, input);
    data = clamp(data, (float4)0, (float4)max_limit);
    vstore4(data, group, output);
  }
}
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/relu.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
namespace mace {
namespace kernels {
// OpenCL specialization of the ReLU functor.
//
// Enqueues the "relu" kernel when max_limit_ is negative, otherwise the
// "relux" kernel (which additionally receives max_limit_). One work-item is
// launched per group of four elements, rounded up. The call only enqueues
// the kernel on the runtime's command queue and checks the enqueue status.
//
// NOTE(review): the local size is the kernel's maximum work-group size while
// the global size is just the group count; OpenCL 1.x requires the global
// size to be a multiple of the local size -- confirm the target runtime
// accepts this combination.
template <>
void ReluFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input,
                                                        Tensor *output) {
  const index_t element_size = input->NumElements();
  const index_t group_count = (element_size + 3) / 4;
  const uint32_t gws = group_count;

  auto runtime = OpenCLRuntime::Get();
  auto program = runtime->program();

  // A negative limit means "no upper bound"; anything else selects relux.
  const bool bounded = !(max_limit_ < 0);
  cl::Kernel relu_kernel(program, bounded ? "relux" : "relu");
  const uint32_t lws = runtime->GetKernelMaxWorkGroupSize(relu_kernel);

  // Kernel argument order: input, [max_limit], size, output.
  uint32_t arg_index = 0;
  relu_kernel.setArg(arg_index++,
                     *(static_cast<const cl::Buffer *>(input->buffer())));
  if (bounded) {
    relu_kernel.setArg(arg_index++, max_limit_);
  }
  relu_kernel.setArg(arg_index++, static_cast<int32_t>(element_size));
  relu_kernel.setArg(arg_index++,
                     *(static_cast<cl::Buffer *>(output->buffer())));

  cl_int error = runtime->command_queue().enqueueNDRangeKernel(
      relu_kernel, cl::NullRange, cl::NDRange(gws), cl::NDRange(lws));
  MACE_CHECK(error == CL_SUCCESS);
}
} // namespace kernels
} // namespace mace
...@@ -14,23 +14,28 @@ template <DeviceType D, typename T> ...@@ -14,23 +14,28 @@ template <DeviceType D, typename T>
struct ReluFunctor { struct ReluFunctor {
T max_limit_; T max_limit_;
void operator()(const T *input, T *output, index_t size) { void operator()(const Tensor *input, Tensor *output) {
const T *input_ptr = input->data<T>();
T *output_ptr = output->mutable_data<T>();
index_t size = input->size();
if (max_limit_ < 0) { if (max_limit_ < 0) {
for (index_t i = 0; i < size; ++i) { for (index_t i = 0; i < size; ++i) {
output[i] = std::max(input[i], static_cast<T>(0)); output_ptr[i] = std::max(input_ptr[i], static_cast<T>(0));
} }
} else { } else {
for (index_t i = 0; i < size; ++i) { for (index_t i = 0; i < size; ++i) {
output[i] = std::min(std::max(input[i], static_cast<T>(0)), max_limit_); output_ptr[i] = std::min(std::max(input_ptr[i], static_cast<T>(0)), max_limit_);
} }
} }
} }
}; };
template <> template <>
void ReluFunctor<DeviceType::NEON, float>::operator()(const float *input, void ReluFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
float *output, Tensor *output);
index_t size); template <>
void ReluFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input,
Tensor *output);
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
......
...@@ -12,4 +12,5 @@ REGISTER_CPU_OPERATOR(Relu, ReluOp<DeviceType::CPU, float>); ...@@ -12,4 +12,5 @@ REGISTER_CPU_OPERATOR(Relu, ReluOp<DeviceType::CPU, float>);
REGISTER_NEON_OPERATOR(Relu, ReluOp<DeviceType::NEON, float>); REGISTER_NEON_OPERATOR(Relu, ReluOp<DeviceType::NEON, float>);
#endif // __ARM_NEON #endif // __ARM_NEON
REGISTER_OPENCL_OPERATOR(Relu, ReluOp<DeviceType::OPENCL, float>);
} // namespace mace } // namespace mace
...@@ -22,11 +22,8 @@ class ReluOp : public Operator<D, T> { ...@@ -22,11 +22,8 @@ class ReluOp : public Operator<D, T> {
const Tensor *input_tensor = this->inputs_[0]; const Tensor *input_tensor = this->inputs_[0];
Tensor *output_tensor = this->outputs_[0]; Tensor *output_tensor = this->outputs_[0];
output_tensor->ResizeLike(input_tensor); output_tensor->ResizeLike(input_tensor);
const T *input = input_tensor->data<T>();
T *output = output_tensor->mutable_data<T>();
index_t size = input_tensor->size();
functor_(input, output, size); functor_(input_tensor, output_tensor);
return true; return true;
} }
......
...@@ -19,17 +19,19 @@ static void ReluBenchmark(int iters, int size) { ...@@ -19,17 +19,19 @@ static void ReluBenchmark(int iters, int size) {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Add input data // Add input data
net.AddRandomInput<DeviceType::CPU, float>("Input", {size}); net.AddRandomInput<D, float>("Input", {size});
// Warm-up // Warm-up
for (int i = 0; i < 5; ++i) { for (int i = 0; i < 5; ++i) {
net.RunOp(D); net.RunOp(D);
} }
net.Sync();
mace::testing::StartTiming(); mace::testing::StartTiming();
while (iters--) { while (iters--) {
net.RunOp(D); net.RunOp(D);
} }
net.Sync();
} }
#define BM_RELU_MACRO(SIZE, TYPE, DEVICE) \ #define BM_RELU_MACRO(SIZE, TYPE, DEVICE) \
...@@ -43,7 +45,8 @@ static void ReluBenchmark(int iters, int size) { ...@@ -43,7 +45,8 @@ static void ReluBenchmark(int iters, int size) {
#define BM_RELU(SIZE, TYPE) \ #define BM_RELU(SIZE, TYPE) \
BM_RELU_MACRO(SIZE, TYPE, CPU); \ BM_RELU_MACRO(SIZE, TYPE, CPU); \
BM_RELU_MACRO(SIZE, TYPE, NEON); BM_RELU_MACRO(SIZE, TYPE, NEON);\
BM_RELU_MACRO(SIZE, TYPE, OPENCL);
BM_RELU(1000, float); BM_RELU(1000, float);
BM_RELU(100000, float); BM_RELU(100000, float);
......
...@@ -9,51 +9,146 @@ namespace mace { ...@@ -9,51 +9,146 @@ namespace mace {
class ReluOpTest : public OpsTestBase {}; class ReluOpTest : public OpsTestBase {};
TEST_F(ReluOpTest, ReluOp) { template <DeviceType D>
// Construct graph void TestSimple() {
auto &net = test_net(); OpsTestNet net;
OpDefBuilder("Relu", "ReluTest") OpDefBuilder("Relu", "ReluTest")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Add input data // Add input data
net.AddRandomInput<DeviceType::CPU, float>("Input", {1, 2, 3, 5}); net.AddInputFromArray<D, float>("Input",
{2, 2, 2, 2},
{-7, 7, -6, 6, -5, 5, -4, 4,
-3, 3, -2, 2, -1, 1, 0, 0});
// Run // Run
net.RunOp(); net.RunOp(D);
Tensor expected; auto expected = CreateTensor<float>({2, 2, 2, 2},
expected.Copy(*net.GetOutput("Output")); {0, 7, 0, 6, 0, 5, 0, 4,
0, 3, 0, 2, 0, 1, 0, 0});
// Check ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
net.RunOp(DeviceType::NEON); }
// Runs the shared TestSimple<D> case on the CPU device.
TEST_F(ReluOpTest, CPUSimple) {
TestSimple<DeviceType::CPU>();
}
// Runs the shared TestSimple<D> case on the NEON device.
TEST_F(ReluOpTest, NEONSimple) {
TestSimple<DeviceType::NEON>();
}
// Runs the shared TestSimple<D> case on the OPENCL device.
TEST_F(ReluOpTest, OPENCLSimple) {
TestSimple<DeviceType::OPENCL>();
}
template <DeviceType D>
void TestUnalignedSimple() {
OpsTestNet net;
OpDefBuilder("Relu", "ReluTest")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.01); // Add input data
net.AddInputFromArray<D, float>("Input",
{1, 1, 3, 2},
{-7, 7, -6, 6, -5, 5});
// Run
net.RunOp(D);
auto expected = CreateTensor<float>({1, 1, 3, 2},
{0, 7, 0, 6, 0, 5});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
// Runs the shared TestUnalignedSimple<D> case (6 elements, not a multiple of
// four) on the CPU device.
TEST_F(ReluOpTest, CPUUnalignedSimple) {
TestUnalignedSimple<DeviceType::CPU>();
}
TEST_F(ReluOpTest, NEONUnalignedSimple) {
TestUnalignedSimple<DeviceType::NEON>();
} }
TEST_F(ReluOpTest, ReluOpWithMax) { TEST_F(ReluOpTest, OPENCLUnalignedSimple) {
// Construct graph TestUnalignedSimple<DeviceType::OPENCL>();
auto &net = test_net(); }
OpDefBuilder("Relu", "ReluTestWithMax")
template <DeviceType D>
void TestSimpleReluX() {
OpsTestNet net;
OpDefBuilder("Relu", "ReluTest")
.Input("Input") .Input("Input")
.Output("Output") .Output("Output")
.AddFloatArg("max_limit", 0.5) .AddFloatArg("max_limit", 6)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Add input data // Add input data
net.AddRandomInput<DeviceType::CPU, float>("Input", {1, 2, 3, 5}); net.AddInputFromArray<D, float>("Input",
{2, 2, 2, 2},
{-7, 7, -6, 6, -5, 5, -4, 4,
-3, 3, -2, 2, -1, 1, 0, 0});
// Run // Run
net.RunOp(); net.RunOp(D);
Tensor expected; auto expected = CreateTensor<float>({2, 2, 2, 2},
expected.Copy(*net.GetOutput("Output")); {0, 6, 0, 6, 0, 5, 0, 4,
0, 3, 0, 2, 0, 1, 0, 0});
// Check ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
net.RunOp(DeviceType::NEON); }
// Runs the shared TestSimpleReluX<D> case (max_limit = 6) on the CPU device.
TEST_F(ReluOpTest, CPUSimpleReluX) {
TestSimpleReluX<DeviceType::CPU>();
}
// Runs the shared TestSimpleReluX<D> case (max_limit = 6) on the NEON device.
TEST_F(ReluOpTest, NEONSimpleReluX) {
TestSimpleReluX<DeviceType::NEON>();
}
// Runs the shared TestSimpleReluX<D> case (max_limit = 6) on the OPENCL device.
TEST_F(ReluOpTest, OPENCLSimpleReluX) {
TestSimpleReluX<DeviceType::OPENCL>();
}
template <DeviceType D>
void TestUnalignedSimpleReluX() {
OpsTestNet net;
OpDefBuilder("Relu", "ReluTest")
.Input("Input")
.Output("Output")
.AddFloatArg("max_limit", 6)
.Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<D, float>("Input",
{1, 1, 1, 7},
{-7, 7, -6, 6, -5, 5, -4});
// Run
net.RunOp(D);
auto expected = CreateTensor<float>({1, 1, 1, 7},
{0, 6, 0, 6, 0, 5, 0});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
// Runs the shared TestUnalignedSimpleReluX<D> case (7 elements, max_limit = 6)
// on the CPU device.
TEST_F(ReluOpTest, CPUUnalignedSimpleReluX) {
TestUnalignedSimpleReluX<DeviceType::CPU>();
}
// Runs the shared TestUnalignedSimpleReluX<D> case (7 elements, max_limit = 6)
// on the NEON device.
TEST_F(ReluOpTest, NEONUnalignedSimpleReluX) {
TestUnalignedSimpleReluX<DeviceType::NEON>();
}
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.01); TEST_F(ReluOpTest, OPENCLUnalignedSimpleReluX) {
TestUnalignedSimpleReluX<DeviceType::OPENCL>();
} }
} // namespace mace } // namespace mace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册