From 89bb3a0bfdc0ab36cdd4bd0a69991b111db15223 Mon Sep 17 00:00:00 2001 From: liuruilong Date: Fri, 19 Oct 2018 16:01:19 +0800 Subject: [PATCH] update softmax reshape kernel --- src/framework/cl/cl_helper.h | 22 +++---- src/framework/cl/cl_image.cpp | 3 + src/framework/cl/cl_image.h | 19 ++++-- src/framework/executor.cpp | 2 +- src/operators/kernel/cl/cl_kernel/reshape.cl | 27 +++++++++ src/operators/kernel/cl/cl_kernel/softmax.cl | 37 ++++++++++++ .../kernel/cl/conv_add_bn_relu_kernel.cpp | 58 +++++++++---------- src/operators/kernel/cl/conv_add_kernel.cpp | 2 +- src/operators/kernel/cl/conv_kernel.cpp | 2 +- .../kernel/cl/depthwise_conv_kernel.cpp | 2 +- src/operators/kernel/cl/reshape_kernel.cpp | 3 + src/operators/kernel/cl/softmax_kernel.cpp | 47 ++++++++++----- 12 files changed, 159 insertions(+), 65 deletions(-) diff --git a/src/framework/cl/cl_helper.h b/src/framework/cl/cl_helper.h index 8640f6b1a4..6f3f83e272 100644 --- a/src/framework/cl/cl_helper.h +++ b/src/framework/cl/cl_helper.h @@ -49,34 +49,26 @@ class CLHelper { cl_context CLContext() { return scope_->Context(); } std::vector DefaultWorkSize(const CLImage &image) { + if (image.GetImageType() == Invalid) { + PADDLE_MOBILE_THROW_EXCEPTION(" not support image type"); + } // n c h w auto image_dim = image.dims(); if (image_dim.size() == 4) { auto n = image_dim[0]; auto h = image_dim[2]; auto w = image_dim[3]; - auto image_width = image.ImageWidth(); - auto work_size_0 = image_width / w; - auto work_size_1 = w; - auto work_size_2 = n * h; - return {work_size_0, work_size_1, work_size_2}; } else if (image_dim.size() == 2) { - auto image_width = image.ImageWidth(); - - auto work_size_0 = image_width / image_dim[1]; - - auto work_size_1 = image_dim[1]; - - auto work_size_2 = image_dim[0]; - - return {work_size_0, work_size_1, work_size_2}; + return {1, image.ImageWidth(), image.ImageHeight()}; + } else if (image_dim.size() == 1) { + return {1, image.ImageWidth(), 1}; } - PADDLE_MOBILE_THROW_EXCEPTION("not support this dim, need imp"); + PADDLE_MOBILE_THROW_EXCEPTION(" not support this dim, need imp "); } private: diff --git a/src/framework/cl/cl_image.cpp b/src/framework/cl/cl_image.cpp index fe83f1a6fe..2cb4f4ecea 100644 --- a/src/framework/cl/cl_image.cpp +++ b/src/framework/cl/cl_image.cpp @@ -119,6 +119,9 @@ void TensorToCLImage(const Tensor *tensor, CLImage *cl_image, } #ifdef PADDLE_MOBILE_DEBUG Print &operator<<(Print &printer, const CLImage &cl_image) { + if (cl_image.GetImageType() == Invalid) { + PADDLE_MOBILE_THROW_EXCEPTION(" not support image type"); + } printer << " dims: " << cl_image.dims() << "\n"; int stride = cl_image.numel() / 20; stride = stride > 0 ? stride : 1; diff --git a/src/framework/cl/cl_image.h b/src/framework/cl/cl_image.h index 70b5aab7cf..2162d303ac 100644 --- a/src/framework/cl/cl_image.h +++ b/src/framework/cl/cl_image.h @@ -20,6 +20,7 @@ limitations under the License. */ #include "framework/cl/cl_half.h" #include "framework/cl/cl_tool.h" +#include "framework/cl/cl_deleter.h" #include "framework/ddim.h" #include "framework/tensor.h" @@ -88,11 +89,20 @@ class CLImage { " empty image tensor data shouldn't have value"); } DLOG << " init empty image "; - InitCLImage(context, command_queue, nullptr, dim); + if (tensor_dims_.size() <= 2) { + DLOG << " dim <= 2 folder ~~~~~ "; + InitCLImage2C(context, command_queue, tensor_data_, tensor_dims_); + } else { + DLOG << " dim > 2 norm ~~~~~ "; + InitCLImage(context, command_queue, tensor_data_, tensor_dims_); + } + + +// InitCLImage(context, command_queue, nullptr, dim); initialized_ = true; } - cl_mem GetCLImage() const { return cl_image_; } + cl_mem GetCLImage() const { return cl_image_.get(); } const DDim &ImageDims() const { return image_dims_; } @@ -201,12 +211,13 @@ class CLImage { }; cid.buffer = nullptr; cl_int err; - cl_image_ = clCreateImage( + cl_mem cl_image = clCreateImage( context, CL_MEM_READ_WRITE | (data ? CL_MEM_COPY_HOST_PTR : 0), &cf, // const cl_image_format *image_format &cid, // const cl_image_desc *image_desc data, // void *host_ptr &err); + cl_image_.reset(cl_image); if (err != CL_SUCCESS) { CL_CHECK_ERRORS(err); PADDLE_MOBILE_THROW_EXCEPTION(" create image 2d error "); @@ -283,7 +294,7 @@ class CLImage { } bool initialized_ = false; - cl_mem cl_image_; + std::unique_ptr<_cl_mem, CLMemDeleter> cl_image_; size_t image_width_; size_t width_of_one_block_; size_t height_of_one_block_; diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp index 4e4fc5c0cb..6f09aaf443 100644 --- a/src/framework/executor.cpp +++ b/src/framework/executor.cpp @@ -37,7 +37,7 @@ limitations under the License. */ #include "framework/cl/cl_image.h" #endif -int debug_to = 33; +int debug_to = 32; namespace paddle_mobile { namespace framework { diff --git a/src/operators/kernel/cl/cl_kernel/reshape.cl b/src/operators/kernel/cl/cl_kernel/reshape.cl index 062ba55de0..0ffc64f15c 100644 --- a/src/operators/kernel/cl/cl_kernel/reshape.cl +++ b/src/operators/kernel/cl/cl_kernel/reshape.cl @@ -14,6 +14,31 @@ limitations under the License. */ #pragma OPENCL EXTENSION cl_khr_fp16 : enable +__kernel void reshape(__read_only image2d_t input, + __write_only image2d_t output, + __private const int d0, + __private const int d1, + __private const int d2, + __private const int d3, + __private const int x0, + __private const int x1, + __private const int x2, + __private const int x3) { + const int x = get_global_id(0); + const int y = get_global_id(1); + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + + half4 in = read_imageh(input, sampler, (int2)(x, y)); + + write_imageh(output, (int2)(x, y), in); +} + + +/* + __kernel void reshape(__read_only image2d_t input, __write_only image2d_t output, __private const int d0, @@ -49,3 +74,5 @@ __kernel void reshape(__read_only image2d_t input, } write_imageh(output, (int2)(x, y), r); } + +*/ diff --git a/src/operators/kernel/cl/cl_kernel/softmax.cl b/src/operators/kernel/cl/cl_kernel/softmax.cl index ba5cee7358..24279e1494 100644 --- a/src/operators/kernel/cl/cl_kernel/softmax.cl +++ b/src/operators/kernel/cl/cl_kernel/softmax.cl @@ -14,6 +14,41 @@ limitations under the License. */ #pragma OPENCL EXTENSION cl_khr_fp16 : enable +__kernel void softmax(__read_only image2d_t input_image, + __write_only image2d_t output_image, + __private const int group + ) { + const int out_c = get_global_id(0); // block index + const int out_w = get_global_id(1); // index in one block + const int out_nh = get_global_id(2); + + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + + half maxv = 0.0f; + for (int i = 0; i < group; ++i) { + half4 temp = read_imageh(input_image, sampler, (int2)(i, 0)); + maxv = max(maxv, max(temp.x, max(temp.y, max(temp.z, temp.w)))); + } + + half4 rsum = (half4)(0.0f); + + for (int i = 0; i < group; ++i) { + half4 r = read_imageh(input_image, sampler, (int2)(i, 0)); + rsum += exp(r - maxv); + } + + half sum = rsum.x + rsum.y + rsum.z + rsum.w; + + half4 rr = read_imageh(input_image, sampler, (int2)(out_w, out_nh)); + half4 result = exp(rr - maxv) / sum; + write_imageh(output_image, (int2)(out_w, out_nh), result); +} + +/* + __kernel void softmax(__read_only image2d_t input, __write_only image2d_t output, __private const int d0, @@ -42,3 +77,5 @@ __kernel void softmax(__read_only image2d_t input, write_imageh(output, (int2)(z * d3 + x, y), r); } + +*/ diff --git a/src/operators/kernel/cl/conv_add_bn_relu_kernel.cpp b/src/operators/kernel/cl/conv_add_bn_relu_kernel.cpp index 061562d61d..79ccffb3ed 100644 --- a/src/operators/kernel/cl/conv_add_bn_relu_kernel.cpp +++ b/src/operators/kernel/cl/conv_add_bn_relu_kernel.cpp @@ -43,21 +43,21 @@ bool ConvAddBNReluKernel::Init( const int C = mean->numel(); - for (int j = 0; j < C; ++j) { - DLOG << " mean - " << j << mean->data()[j]; - } - - for (int j = 0; j < C; ++j) { - DLOG << " variance - " << j << variance->data()[j]; - } - - for (int j = 0; j < C; ++j) { - DLOG << " scale - " << j << scale->data()[j]; - } - - for (int j = 0; j < C; ++j) { - DLOG << " bias - " << j << bias->data()[j]; - } +// for (int j = 0; j < C; ++j) { +// DLOG << " mean - " << j << mean->data()[j]; +// } +// +// for (int j = 0; j < C; ++j) { +// DLOG << " variance - " << j << variance->data()[j]; +// } +// +// for (int j = 0; j < C; ++j) { +// DLOG << " scale - " << j << scale->data()[j]; +// } +// +// for (int j = 0; j < C; ++j) { +// DLOG << " bias - " << j << bias->data()[j]; +// } // // DLOG << " climage mean: " << *mean; @@ -85,21 +85,21 @@ bool ConvAddBNReluKernel::Init( framework::CLImage *new_scale = new framework::CLImage(); - for (int j = 0; j < C; ++j) { - DLOG << " new scale - " << j << new_scale_ptr[j]; - } - - for (int j = 0; j < C; ++j) { - DLOG << " new bias - " << j << new_bias_ptr[j]; - } +// for (int j = 0; j < C; ++j) { +// DLOG << " new scale - " << j << new_scale_ptr[j]; +// } +// +// for (int j = 0; j < C; ++j) { +// DLOG << " new bias - " << j << new_bias_ptr[j]; +// } new_scale->SetTensorData(new_scale_ptr, variance->dims()); new_scale->InitCLImage(this->cl_helper_.CLContext(), cl_helper_.CLCommandQueue()); - DLOG << " climage - y bias: " << *(param->Bias()); - - DLOG << " climage - new scale: " << *new_scale; +// DLOG << " climage - y bias: " << *(param->Bias()); +// +// DLOG << " climage - new scale: " << *new_scale; framework::CLImage *new_bias = new framework::CLImage(); @@ -107,9 +107,9 @@ bool ConvAddBNReluKernel::Init( new_bias->InitCLImage(this->cl_helper_.CLContext(), cl_helper_.CLCommandQueue()); - DLOG << " climage - new bias: " << *new_bias; - - DLOG << " climage - filter: " << *(param->Filter()); +// DLOG << " climage - new bias: " << *new_bias; +// +// DLOG << " climage - filter: " << *(param->Filter()); param->SetNewScale(new_scale); param->SetNewBias(new_bias); @@ -237,7 +237,7 @@ void ConvAddBNReluKernel::Compute( CL_CHECK_ERRORS(status); status = - clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, + clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL, default_work_size.data(), NULL, 0, NULL, NULL); CL_CHECK_ERRORS(status); } diff --git a/src/operators/kernel/cl/conv_add_kernel.cpp b/src/operators/kernel/cl/conv_add_kernel.cpp index ff1dbfd2a3..e835931873 100644 --- a/src/operators/kernel/cl/conv_add_kernel.cpp +++ b/src/operators/kernel/cl/conv_add_kernel.cpp @@ -118,7 +118,7 @@ void ConvAddKernel::Compute( CL_CHECK_ERRORS(status); status = - clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, + clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL, default_work_size.data(), NULL, 0, NULL, NULL); CL_CHECK_ERRORS(status); } diff --git a/src/operators/kernel/cl/conv_kernel.cpp b/src/operators/kernel/cl/conv_kernel.cpp index 27ebe18baf..0c5ab87d6d 100644 --- a/src/operators/kernel/cl/conv_kernel.cpp +++ b/src/operators/kernel/cl/conv_kernel.cpp @@ -155,7 +155,7 @@ void ConvKernel::Compute(const ConvParam ¶m) { DLOG << " begin enqueue "; status = - clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, + clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL, default_work_size.data(), NULL, 0, NULL, NULL); CL_CHECK_ERRORS(status); DLOG << " end enqueue "; diff --git a/src/operators/kernel/cl/depthwise_conv_kernel.cpp b/src/operators/kernel/cl/depthwise_conv_kernel.cpp index b84a56ca97..bbf4c07fc2 100644 --- a/src/operators/kernel/cl/depthwise_conv_kernel.cpp +++ b/src/operators/kernel/cl/depthwise_conv_kernel.cpp @@ -77,7 +77,7 @@ void DepthwiseConvKernel::Compute( CL_CHECK_ERRORS(status); status = - clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, + clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL, default_work_size.data(), NULL, 0, NULL, NULL); CL_CHECK_ERRORS(status); diff --git a/src/operators/kernel/cl/reshape_kernel.cpp b/src/operators/kernel/cl/reshape_kernel.cpp index b0d1537da2..98ac780dbd 100644 --- a/src/operators/kernel/cl/reshape_kernel.cpp +++ b/src/operators/kernel/cl/reshape_kernel.cpp @@ -36,9 +36,12 @@ void ReshapeKernel::Compute(const ReshapeParam ¶m) { const auto &outputDim = output->dims(); int dims[4] = {1, 1, 1, 1}; int odims[4] = {1, 1, 1, 1}; + // 1 1000 1 1 for (int i = 0; i < inputDim.size(); i++) { dims[4 - inputDim.size() + i] = inputDim[i]; } + + // 1 1 1 1000 for (int i = 0; i < outputDim.size(); i++) { odims[4 - outputDim.size() + i] = outputDim[i]; } diff --git a/src/operators/kernel/cl/softmax_kernel.cpp b/src/operators/kernel/cl/softmax_kernel.cpp index a8196cf376..d178b52318 100644 --- a/src/operators/kernel/cl/softmax_kernel.cpp +++ b/src/operators/kernel/cl/softmax_kernel.cpp @@ -33,20 +33,41 @@ void SoftmaxKernel::Compute(const SoftmaxParam ¶m) { auto *output = param.Out(); auto inputImage = input->GetCLImage(); auto outputImage = output->GetCLImage(); - clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputImage); - clSetKernelArg(kernel, 1, sizeof(cl_mem), &outputImage); - const auto &inputDim = input->dims(); - int dims[4] = {1, 1, 1, 1}; - for (int i = 0; i < inputDim.size(); i++) { - dims[4 - inputDim.size() + i] = inputDim[i]; - } - clSetKernelArg(kernel, 2, sizeof(int), &dims); - clSetKernelArg(kernel, 3, sizeof(int), &dims[1]); - clSetKernelArg(kernel, 4, sizeof(int), &dims[2]); - clSetKernelArg(kernel, 5, sizeof(int), &dims[3]); - - clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, + + DLOG << " softmax - output image dim " << output->ImageDims(); + DLOG << " softmax - output image tensor dim " << output->dims(); + + int group = output->ImageWidth(); + + cl_int status; + + status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputImage); + CL_CHECK_ERRORS(status); + + status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &outputImage); + CL_CHECK_ERRORS(status); + + status = clSetKernelArg(kernel, 2, sizeof(int), &group); + CL_CHECK_ERRORS(status); + +// const auto &inputDim = input->dims(); +// +// int dims[4] = {1, 1, 1, 1}; +// +// for (int i = 0; i < inputDim.size(); i++) { +// dims[4 - inputDim.size() + i] = inputDim[i]; +// } +// +// clSetKernelArg(kernel, 2, sizeof(int), &dims); +// clSetKernelArg(kernel, 3, sizeof(int), &dims[1]); +// clSetKernelArg(kernel, 4, sizeof(int), &dims[2]); +// clSetKernelArg(kernel, 5, sizeof(int), &dims[3]); + DLOG << "default_work_size: " << default_work_size; + + status = clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL, default_work_size.data(), NULL, 0, NULL, NULL); + CL_CHECK_ERRORS(status); + } template class SoftmaxKernel; -- GitLab