From 2aae87decfd8bc7528f89e2d9921c135de9e94f8 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 10 Aug 2020 23:18:59 -0700 Subject: [PATCH] add opencl leakyrelu kernel add opencl prelu kernel --- .../kernel/opencl/cl/fp32/leaky_relu.cl | 29 +++++ .../kernel/opencl/kernel/convolution.cc | 10 +- .../kernel/opencl/kernel/leaky_relu.cc | 114 ++++++++++++++++++ .../runtime/kernel/opencl/kernel/leaky_relu.h | 49 ++++++++ mindspore/lite/test/CMakeLists.txt | 2 + .../runtime/kernel/opencl/leakyrelu_tests.cc | 109 +++++++++++++++++ 6 files changed, 311 insertions(+), 2 deletions(-) create mode 100644 mindspore/lite/src/runtime/kernel/opencl/cl/fp32/leaky_relu.cl create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.cc create mode 100644 mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.h create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/opencl/leakyrelu_tests.cc diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/leaky_relu.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/leaky_relu.cl new file mode 100644 index 000000000..0330b8590 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/leaky_relu.cl @@ -0,0 +1,29 @@ +#pragma OPENCL EXTENSION cl_arm_printf : enable + +#define SLICES 4 +#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) +#define FLT4 float4 +#define READ_FLT4 read_imagef +#define WRITE_FLT4 write_imagef +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + +__kernel void LeakyRelu(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, + const float alpha) { + // int B = input_shape.x; // size + // int H = input_shape.y; // + // int W = input_shape.z; + int C = input_shape.w; + + int Y = get_global_id(0); // height id + int X = get_global_id(1); // weight id + + for (int num = 0; num < UP_DIV(C, SLICES); ++num) { + FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC + FLT4 tmp; + tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * alpha; + tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * alpha; + tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * alpha; + tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * alpha; + WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC + } +} diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.cc index 9fc71ba38..238b1dba5 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.cc @@ -29,10 +29,12 @@ using mindspore::schema::PrimitiveType_Conv2D; namespace mindspore::kernel { int ConvolutionOpenCLKernel::Init() { + static int count = 0; std::cout << "ConvolutionOpenCLKernel::Init()\n"; std::set build_options; std::string source = CodeGen(); - std::string program_name = "convolution"; + std::string program_name = "convolution" + std::to_string(count); + count++; std::string kernel_name = "convolution"; auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); @@ -151,7 +153,11 @@ std::string ConvolutionOpenCLKernel::CodeGen() { " }\n" " }\n\n"; code += " FLT4 out0_c4_bias = out0_c4 + bias[co_slice];\n"; - + if (param->is_relu_) { + code += " out0_c4_bias = max(out0_c4_bias, (FLT4)(0.0f));\n"; + } else if (param->is_relu6_) { + code += " out0_c4_bias = clamp(out0_c4_bias, (FLT4)(0.0f), (FLT4)(6.0f));\n"; + } // NHWC4 NHC4W4 NC4HW4 if (OW * CO_SLICES < 65536) { code += " WRITE_FLT4(output, (int2)(ow * CO_SLICES + co_slice, oh), out0_c4_bias);// NHWC4: H WC\n}"; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.cc new file mode 100644 index 000000000..d0c630e97 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.cc @@ -0,0 +1,114 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/runtime/kernel/opencl/kernel/leaky_relu.h" +#include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/kernel/opencl/cl/fp32/leaky_relu.cl.inc" + +using mindspore::kernel::KERNEL_ARCH::kGPU; +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_LeakyReLU; + +namespace mindspore::kernel { + + int LeakyReluOpenCLKernel::Init() { + if (inputs_[0]->shape().size() != 4) { + MS_LOG(ERROR) << "leaky_relu only support dim=4, but your dim=" << inputs_[0]->shape().size(); + } + std::set build_options; + std::string source = leaky_relu_source_fp32; + std::string program_name = "LeakyRelu"; + std::string kernel_name = "LeakyRelu"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->LoadSource(program_name, source); + ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); + + MS_LOG(DEBUG) << kernel_name << " Init Done!"; + return RET_OK; + } + + + int LeakyReluOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size) { + int H = inputs_[0]->shape()[1]; + int W = inputs_[0]->shape()[2]; + int C = inputs_[0]->shape()[3]; + +#ifdef ENABLE_FP16 + size_t img_dtype = CL_HALF_FLOAT; +#else + size_t img_dtype = CL_FLOAT; +#endif + + img_size->clear(); + img_size->push_back(W * UP_DIV(C, C4NUM)); + img_size->push_back(H); + img_size->push_back(img_dtype); + return RET_OK; + } + + int LeakyReluOpenCLKernel::Run() { + auto param = reinterpret_cast(this->opParameter); + MS_LOG(DEBUG) << this->Name() << " Running!"; + int N = inputs_[0]->shape()[0]; + int H = inputs_[0]->shape()[1]; + int W = inputs_[0]->shape()[2]; + int C = inputs_[0]->shape()[3]; + cl_int4 input_shape = {N, H, W, C}; + + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + int arg_idx = 0; + ocl_runtime->SetKernelArg(kernel_, arg_idx++, inputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, outputs_[0]->Data()); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape); + ocl_runtime->SetKernelArg(kernel_, arg_idx++, param->alpha); + + std::vector local = {1, 1}; + std::vector global = {static_cast(H), static_cast(W)}; + ocl_runtime->RunKernel(kernel_, global, local, nullptr); + return 0; + } + + kernel::LiteKernel *OpenCLLeakyReluKernelCreator(const std::vector &inputs, + const std::vector &outputs, + OpParameter *opParameter, const lite::Context *ctx, + const kernel::KernelKey &desc, const lite::Primitive *primitive) { + auto *kernel = new LeakyReluOpenCLKernel(reinterpret_cast(opParameter), inputs, outputs); + if (inputs.size() == 0) { + MS_LOG(ERROR) << "Input data size must must be greater than 0, but your size is " << inputs.size(); + } + if (inputs[0]->shape()[0] > 1) { + MS_LOG(ERROR) << "Init `leaky relu` kernel failed: Unsupported multi-batch."; + } + auto ret = kernel->Init(); + if (0 != ret) { + MS_LOG(ERROR) << "Init `Leaky Relu` kernel failed!"; + delete kernel; + return nullptr; + } + return kernel; + } + + REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_LeakyReLU, OpenCLLeakyReluKernelCreator) +} // namespace mindspore::kernel + diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.h new file mode 100644 index 000000000..8ad56bba1 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_LEAKYRELU_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_LEAKYRELU_H_ + +#include + +#include "src/runtime/opencl/opencl_runtime.h" +#include "src/runtime/kernel/opencl/opencl_kernel.h" + +struct LeakyReluParameter { + OpParameter op_parameter_; + cl_float alpha; +}; + +namespace mindspore::kernel { + +class LeakyReluOpenCLKernel : public OpenCLKernel { + public: + explicit LeakyReluOpenCLKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs) + : OpenCLKernel(parameter, inputs, outputs) {} + ~LeakyReluOpenCLKernel() override{}; + + int Init() override; + int Run() override; + int GetImageSize(size_t idx, std::vector *img_size) override; + + private: + cl::Kernel kernel_; +}; + +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_LEAKYRELU_H_ diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index bd5a436e1..e50e6123b 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -142,6 +142,7 @@ if (SUPPORT_GPU) ${LITE_DIR}/src/runtime/kernel/opencl/kernel/matmul.cc ${LITE_DIR}/src/runtime/kernel/opencl/kernel/softmax.cc ${LITE_DIR}/src/runtime/kernel/opencl/kernel/concat.cc + ${LITE_DIR}/src/runtime/kernel/opencl/kernel/leaky_relu.cc ${LITE_DIR}/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc ${LITE_DIR}/src/runtime/kernel/opencl/kernel/transpose.cc ) @@ -318,6 +319,7 @@ if (SUPPORT_GPU) ${TEST_DIR}/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc ${TEST_DIR}/ut/src/runtime/kernel/opencl/transpose_tests.cc ${TEST_DIR}/ut/src/runtime/kernel/opencl/convolution_tests.cc + ${TEST_DIR}/ut/src/runtime/kernel/opencl/leakyrelu_tests.cc ) endif() diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/leakyrelu_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/leakyrelu_tests.cc new file mode 100644 index 000000000..8fa9a75d5 --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/leakyrelu_tests.cc @@ -0,0 +1,109 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include "utils/log_adapter.h" +#include "common/common_test.h" +#include "mindspore/lite/src/common/file_utils.h" +#include "src/runtime/kernel/arm/nnacl/pack.h" +#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h" +#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" +#include "mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.h" + +namespace mindspore { +class TestLeakyReluOpenCL : public mindspore::Common { + public: + TestLeakyReluOpenCL() {} +}; + +void LoadDataLeakyRelu(void *dst, size_t dst_size, const std::string &file_path) { + if (file_path.empty()) { + memset(dst, 0x00, dst_size); + } else { + auto src_data = reinterpret_cast(mindspore::lite::ReadFile(file_path.c_str(), &dst_size)); + memcpy(dst, src_data, dst_size); + } +} + +void CompareOutLeakyRelu(lite::tensor::Tensor *output_tensor, const std::string &standard_answer_file) { + auto *output_data = reinterpret_cast(output_tensor->Data()); + size_t output_size = output_tensor->Size(); + auto expect_data = reinterpret_cast(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size)); + constexpr float atol = 0.0002; + for (int i = 0; i < output_tensor->ElementsNum(); ++i) { + if (std::fabs(output_data[i] - expect_data[i]) > atol) { + printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]); + printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]); + printf("error at idx[%d] expect=%.3f output=%.3f\n\n\n", i, expect_data[i], output_data[i]); + return; + } + } + printf("compare success!\n"); + printf("compare success!\n"); + printf("compare success!\n\n\n"); +} + +void printf_tensor(mindspore::lite::tensor::Tensor *in_data) { + auto input_data = reinterpret_cast(in_data->Data()); + for (int i = 0; i < in_data->ElementsNum(); ++i) { + printf("%f ", input_data[i]); + } + printf("\n"); + MS_LOG(INFO) << "Print tensor done"; +} + +TEST_F(TestLeakyReluOpenCL, LeakyReluFp32_dim4) { + std::string in_file = "/data/local/tmp/in_data.bin"; + std::string standard_answer_file = "/data/local/tmp/out_data.bin"; + MS_LOG(INFO) << "Begin test:"; + auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); + ocl_runtime->Init(); + auto allocator = ocl_runtime->GetAllocator(); + + MS_LOG(INFO) << "Init tensors."; + std::vector input_shape = {1, 4, 3, 8}; + + auto data_type = kNumberTypeFloat32; + auto tensor_type = schema::NodeType_ValueNode; + auto *input_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensor_type); + auto *output_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensor_type); + std::vector inputs{input_tensor}; + std::vector outputs{output_tensor}; + + // freamework to do!!! allocate memory by hand + inputs[0]->MallocData(allocator); + + auto param = new LeakyReluParameter(); + param->alpha = 0.3; + auto *leakyrelu_kernel = new kernel::LeakyReluOpenCLKernel(reinterpret_cast(param), inputs, outputs); + leakyrelu_kernel->Init(); + + MS_LOG(INFO) << "initialize sub_graph"; + std::vector kernels{leakyrelu_kernel}; + auto *sub_graph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); + sub_graph->Init(); + + MS_LOG(INFO) << "initialize input data"; + LoadDataLeakyRelu(input_tensor->Data(), input_tensor->Size(), in_file); + MS_LOG(INFO) << "==================input data================"; + printf_tensor(inputs[0]); + + sub_graph->Run(); + + MS_LOG(INFO) << "==================output data================"; + printf_tensor(outputs[0]); + CompareOutLeakyRelu(output_tensor, standard_answer_file); +} +} // namespace mindspore -- GitLab