提交 2aae87de 编写于 作者: R root 提交者: liuzhongkai

add opencl leakyrelu kernel

add opencl prelu kernel
上级 e73e9a9a
#pragma OPENCL EXTENSION cl_arm_printf : enable
#define SLICES 4
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
#define FLT4 float4
#define READ_FLT4 read_imagef
#define WRITE_FLT4 write_imagef
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void LeakyRelu(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape,
const float alpha) {
// int B = input_shape.x; // size
// int H = input_shape.y; //
// int W = input_shape.z;
int C = input_shape.w;
int Y = get_global_id(0); // height id
int X = get_global_id(1); // weight id
for (int num = 0; num < UP_DIV(C, SLICES); ++num) {
FLT4 in_c4 = READ_FLT4(input, smp_zero, (int2)(X * UP_DIV(C, SLICES) + num, Y)); // NHWC4: H WC
FLT4 tmp;
tmp.x = in_c4.x >= 0 ? in_c4.x : in_c4.x * alpha;
tmp.y = in_c4.y >= 0 ? in_c4.y : in_c4.y * alpha;
tmp.z = in_c4.z >= 0 ? in_c4.z : in_c4.z * alpha;
tmp.w = in_c4.w >= 0 ? in_c4.w : in_c4.w * alpha;
WRITE_FLT4(output, (int2)(X * UP_DIV(C, SLICES) + num, Y), tmp); // NHWC4: H WC
}
}
......@@ -29,10 +29,12 @@ using mindspore::schema::PrimitiveType_Conv2D;
namespace mindspore::kernel {
int ConvolutionOpenCLKernel::Init() {
static int count = 0;
std::cout << "ConvolutionOpenCLKernel::Init()\n";
std::set<std::string> build_options;
std::string source = CodeGen();
std::string program_name = "convolution";
std::string program_name = "convolution" + std::to_string(count);
count++;
std::string kernel_name = "convolution";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
......@@ -151,7 +153,11 @@ std::string ConvolutionOpenCLKernel::CodeGen() {
" }\n"
" }\n\n";
code += " FLT4 out0_c4_bias = out0_c4 + bias[co_slice];\n";
if (param->is_relu_) {
code += " out0_c4_bias = max(out0_c4_bias, (FLT4)(0.0f));\n";
} else if (param->is_relu6_) {
code += " out0_c4_bias = clamp(out0_c4_bias, (FLT4)(0.0f), (FLT4)(6.0f));\n";
}
// NHWC4 NHC4W4 NC4HW4
if (OW * CO_SLICES < 65536) {
code += " WRITE_FLT4(output, (int2)(ow * CO_SLICES + co_slice, oh), out0_c4_bias);// NHWC4: H WC\n}";
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <set>
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/kernel/opencl/kernel/leaky_relu.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/opencl/cl/fp32/leaky_relu.cl.inc"
using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_LeakyReLU;
namespace mindspore::kernel {
int LeakyReluOpenCLKernel::Init() {
if (inputs_[0]->shape().size() != 4) {
MS_LOG(ERROR) << "leaky_relu only support dim=4, but your dim=" << inputs_[0]->shape().size();
}
std::set<std::string> build_options;
std::string source = leaky_relu_source_fp32;
std::string program_name = "LeakyRelu";
std::string kernel_name = "LeakyRelu";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
MS_LOG(DEBUG) << kernel_name << " Init Done!";
return RET_OK;
}
int LeakyReluOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
int H = inputs_[0]->shape()[1];
int W = inputs_[0]->shape()[2];
int C = inputs_[0]->shape()[3];
#ifdef ENABLE_FP16
size_t img_dtype = CL_HALF_FLOAT;
#else
size_t img_dtype = CL_FLOAT;
#endif
img_size->clear();
img_size->push_back(W * UP_DIV(C, C4NUM));
img_size->push_back(H);
img_size->push_back(img_dtype);
return RET_OK;
}
int LeakyReluOpenCLKernel::Run() {
auto param = reinterpret_cast<LeakyReluParameter *>(this->opParameter);
MS_LOG(DEBUG) << this->Name() << " Running!";
int N = inputs_[0]->shape()[0];
int H = inputs_[0]->shape()[1];
int W = inputs_[0]->shape()[2];
int C = inputs_[0]->shape()[3];
cl_int4 input_shape = {N, H, W, C};
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
int arg_idx = 0;
ocl_runtime->SetKernelArg(kernel_, arg_idx++, inputs_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_idx++, outputs_[0]->Data());
ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_shape);
ocl_runtime->SetKernelArg(kernel_, arg_idx++, param->alpha);
std::vector<size_t> local = {1, 1};
std::vector<size_t> global = {static_cast<size_t>(H), static_cast<size_t>(W)};
ocl_runtime->RunKernel(kernel_, global, local, nullptr);
return 0;
}
kernel::LiteKernel *OpenCLLeakyReluKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
auto *kernel = new LeakyReluOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
if (inputs.size() == 0) {
MS_LOG(ERROR) << "Input data size must must be greater than 0, but your size is " << inputs.size();
}
if (inputs[0]->shape()[0] > 1) {
MS_LOG(ERROR) << "Init `leaky relu` kernel failed: Unsupported multi-batch.";
}
auto ret = kernel->Init();
if (0 != ret) {
MS_LOG(ERROR) << "Init `Leaky Relu` kernel failed!";
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_LeakyReLU, OpenCLLeakyReluKernelCreator)
} // namespace mindspore::kernel
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_LEAKYRELU_H_
#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_LEAKYRELU_H_
#include <vector>
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/opencl/opencl_kernel.h"
struct LeakyReluParameter {
OpParameter op_parameter_;
cl_float alpha;
};
namespace mindspore::kernel {
class LeakyReluOpenCLKernel : public OpenCLKernel {
public:
explicit LeakyReluOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs)
: OpenCLKernel(parameter, inputs, outputs) {}
~LeakyReluOpenCLKernel() override{};
int Init() override;
int Run() override;
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
private:
cl::Kernel kernel_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_LEAKYRELU_H_
......@@ -142,6 +142,7 @@ if (SUPPORT_GPU)
${LITE_DIR}/src/runtime/kernel/opencl/kernel/matmul.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/softmax.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/concat.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/leaky_relu.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/conv2d_transpose.cc
${LITE_DIR}/src/runtime/kernel/opencl/kernel/transpose.cc
)
......@@ -318,6 +319,7 @@ if (SUPPORT_GPU)
${TEST_DIR}/ut/src/runtime/kernel/opencl/conv2d_transpose_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/transpose_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/convolution_tests.cc
${TEST_DIR}/ut/src/runtime/kernel/opencl/leakyrelu_tests.cc
)
endif()
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/common/file_utils.h"
#include "src/runtime/kernel/arm/nnacl/pack.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/leaky_relu.h"
namespace mindspore {
class TestLeakyReluOpenCL : public mindspore::Common {
public:
TestLeakyReluOpenCL() {}
};
void LoadDataLeakyRelu(void *dst, size_t dst_size, const std::string &file_path) {
if (file_path.empty()) {
memset(dst, 0x00, dst_size);
} else {
auto src_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &dst_size));
memcpy(dst, src_data, dst_size);
}
}
void CompareOutLeakyRelu(lite::tensor::Tensor *output_tensor, const std::string &standard_answer_file) {
auto *output_data = reinterpret_cast<float *>(output_tensor->Data());
size_t output_size = output_tensor->Size();
auto expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(standard_answer_file.c_str(), &output_size));
constexpr float atol = 0.0002;
for (int i = 0; i < output_tensor->ElementsNum(); ++i) {
if (std::fabs(output_data[i] - expect_data[i]) > atol) {
printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]);
printf("error at idx[%d] expect=%.3f output=%.3f\n", i, expect_data[i], output_data[i]);
printf("error at idx[%d] expect=%.3f output=%.3f\n\n\n", i, expect_data[i], output_data[i]);
return;
}
}
printf("compare success!\n");
printf("compare success!\n");
printf("compare success!\n\n\n");
}
void printf_tensor(mindspore::lite::tensor::Tensor *in_data) {
auto input_data = reinterpret_cast<float *>(in_data->Data());
for (int i = 0; i < in_data->ElementsNum(); ++i) {
printf("%f ", input_data[i]);
}
printf("\n");
MS_LOG(INFO) << "Print tensor done";
}
TEST_F(TestLeakyReluOpenCL, LeakyReluFp32_dim4) {
std::string in_file = "/data/local/tmp/in_data.bin";
std::string standard_answer_file = "/data/local/tmp/out_data.bin";
MS_LOG(INFO) << "Begin test:";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();
MS_LOG(INFO) << "Init tensors.";
std::vector<int> input_shape = {1, 4, 3, 8};
auto data_type = kNumberTypeFloat32;
auto tensor_type = schema::NodeType_ValueNode;
auto *input_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensor_type);
auto *output_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensor_type);
std::vector<lite::tensor::Tensor *> inputs{input_tensor};
std::vector<lite::tensor::Tensor *> outputs{output_tensor};
// freamework to do!!! allocate memory by hand
inputs[0]->MallocData(allocator);
auto param = new LeakyReluParameter();
param->alpha = 0.3;
auto *leakyrelu_kernel = new kernel::LeakyReluOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
leakyrelu_kernel->Init();
MS_LOG(INFO) << "initialize sub_graph";
std::vector<kernel::LiteKernel *> kernels{leakyrelu_kernel};
auto *sub_graph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
sub_graph->Init();
MS_LOG(INFO) << "initialize input data";
LoadDataLeakyRelu(input_tensor->Data(), input_tensor->Size(), in_file);
MS_LOG(INFO) << "==================input data================";
printf_tensor(inputs[0]);
sub_graph->Run();
MS_LOG(INFO) << "==================output data================";
printf_tensor(outputs[0]);
CompareOutLeakyRelu(output_tensor, standard_answer_file);
}
} // namespace mindspore
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册