From 0ac4de9276cf280a51f5df19c1c797f31276a3da Mon Sep 17 00:00:00 2001 From: dolphin8 Date: Tue, 16 Oct 2018 17:27:33 +0800 Subject: [PATCH] cl_image --- src/framework/cl/cl_image.h | 97 +++++++++++++--------- src/operators/kernel/cl/reshape_kernel.cpp | 10 ++- 2 files changed, 66 insertions(+), 41 deletions(-) diff --git a/src/framework/cl/cl_image.h b/src/framework/cl/cl_image.h index 9bad4bb612..f7d86ec853 100644 --- a/src/framework/cl/cl_image.h +++ b/src/framework/cl/cl_image.h @@ -50,7 +50,11 @@ class CLImage { if (tensor_data_ == nullptr) { PADDLE_MOBILE_THROW_EXCEPTION(" need call SetTensorData first"); } - InitCLImage(context, tensor_data_, tensor_dims_); + if (tensor_dims_.size() <= 2) { + InitCLImage2C(context, tensor_data_, tensor_dims_); + } else { + InitCLImage(context, tensor_data_, tensor_dims_); + } delete[](tensor_data_); tensor_data_ = nullptr; initialized_ = true; @@ -118,6 +122,58 @@ class CLImage { const DDim &dims() const { return tensor_dims_; } private: + void InitCLImage2C(cl_context context, float *tensor_data, const DDim &dim) { + assert(dim.size() <= 2); + int tdim[2] = {1, 1}; + if (dim.size() == 1) { + tdim[1] = dim[0]; + } else { + tdim[0] = dim[0]; + tdim[1] = dim[1]; + } + int width = tdim[1] + 3 / 4; + int height = tdim[0]; + std::unique_ptr imageData{}; + if (tensor_data) { + imageData.reset(new half_t[width * height * 4]); + for (int h = 0; h < tdim[0]; h++) { + for (int w = 0; w < tdim[1]; w++) { + imageData[(h * width + w / 4) * 4 + (w % 4)] = Float2Half(tensor_data[h * tdim[1] + w]); + } + } + } + InitCLImage(context, width, height, imageData.get()); + } + + void InitCLImage(cl_context context, int width, int height, void *data) { + cl_image_format cf = {.image_channel_order = CL_RGBA, + .image_channel_data_type = CL_HALF_FLOAT}; + cl_image_desc cid = { + .image_type = CL_MEM_OBJECT_IMAGE2D, + .image_width = width, + .image_height = height, + .image_depth = 1, + .image_array_size = 1, + .image_row_pitch = 0, + .image_slice_pitch = 0, + .num_mip_levels = 0, + .num_samples = 0, + // .buffer = nullptr + }; + cid.buffer = nullptr; + cl_int err; + cl_image_ = clCreateImage( + context, CL_MEM_READ_WRITE | (data ? CL_MEM_COPY_HOST_PTR : 0), + &cf, // const cl_image_format *image_format + &cid, // const cl_image_desc *image_desc + data, // void *host_ptr + &err + ); + if (err != CL_SUCCESS) { + CL_CHECK_ERRORS(err); + PADDLE_MOBILE_THROW_EXCEPTION(" create image 2d error "); + } + } void InitCLImage(cl_context context, float *tensor_data, const DDim &dim) { DLOG << " tensor dim: " << dim; // NCHW -> [W * (C+3)/4, H * N] @@ -160,17 +216,9 @@ class CLImage { for (int h = 0; h < H; h++) { size_t i2 = (i1 << 2) + c % 4; for (int w = 0; w < W; w++) { - if (i2 >= width * height * 4) { - printf("%d > %d ----> %d, %d, %d, %d --- %d, %d, %d\n", i2, - width * height * 4, n, c, h, w, i0, i1, i2); - } - assert(i2 < width * height * 4); - imageData[i2] = Float2Half(*p); i2 += 4; p++; - // count++; - // DLOG<(imageData.get()), // void *host_ptr - &err); - - if (err != CL_SUCCESS) { - CL_CHECK_ERRORS(err); - PADDLE_MOBILE_THROW_EXCEPTION(" create image 2d error "); - } + InitCLImage(context, width, height, imageData.get()); } bool initialized_ = false; diff --git a/src/operators/kernel/cl/reshape_kernel.cpp b/src/operators/kernel/cl/reshape_kernel.cpp index 877a325636..210932337c 100644 --- a/src/operators/kernel/cl/reshape_kernel.cpp +++ b/src/operators/kernel/cl/reshape_kernel.cpp @@ -34,8 +34,14 @@ void ReshapeKernel::Compute(const ReshapeParam ¶m) { clSetKernelArg(kernel, 1, sizeof(cl_mem), &outputImage); const auto &inputDim = input->dims(); const auto &outputDim = output->dims(); - int dims[4] = {inputDim[0], inputDim[1], inputDim[2], inputDim[3]}; - int odims[4] = {outputDim[0], outputDim[1], outputDim[2], outputDim[3]}; + int dims[4] = {1, 1, 1, 1}; + int odims[4] = {1, 1, 1, 1}; + for (int i = 0; i < inputDim.size(); i++) { + dims[4-inputDim.size()+i] = inputDim[i]; + } + for (int i = 0; i < outputDim.size(); i++) { + odims[4-outputDim.size()+i] = outputDim[i]; + } clSetKernelArg(kernel, 2, sizeof(int), dims); clSetKernelArg(kernel, 3, sizeof(int), dims + 1); clSetKernelArg(kernel, 4, sizeof(int), dims + 2); -- GitLab