提交 0ac4de92 编写于 作者: D dolphin8

cl_image

上级 9a5af4ba
......@@ -50,7 +50,11 @@ class CLImage {
if (tensor_data_ == nullptr) {
PADDLE_MOBILE_THROW_EXCEPTION(" need call SetTensorData first");
}
InitCLImage(context, tensor_data_, tensor_dims_);
if (tensor_dims_.size() <= 2) {
InitCLImage2C(context, tensor_data_, tensor_dims_);
} else {
InitCLImage(context, tensor_data_, tensor_dims_);
}
delete[](tensor_data_);
tensor_data_ = nullptr;
initialized_ = true;
......@@ -118,6 +122,58 @@ class CLImage {
const DDim &dims() const { return tensor_dims_; }
private:
void InitCLImage2C(cl_context context, float *tensor_data, const DDim &dim) {
assert(dim.size() <= 2);
int tdim[2] = {1, 1};
if (dim.size() == 1) {
tdim[1] = dim[0];
} else {
tdim[0] = dim[0];
tdim[1] = dim[1];
}
int width = tdim[1] + 3 / 4;
int height = tdim[0];
std::unique_ptr<half_t[]> imageData{};
if (tensor_data) {
imageData.reset(new half_t[width * height * 4]);
for (int h = 0; h < tdim[0]; h++) {
for (int w = 0; w < tdim[1]; w++) {
imageData[(h * width + w / 4) * 4 + (w % 4)] = Float2Half(tensor_data[h * tdim[1] + w]);
}
}
}
InitCLImage(context, width, height, imageData.get());
}
void InitCLImage(cl_context context, int width, int height, void *data) {
cl_image_format cf = {.image_channel_order = CL_RGBA,
.image_channel_data_type = CL_HALF_FLOAT};
cl_image_desc cid = {
.image_type = CL_MEM_OBJECT_IMAGE2D,
.image_width = width,
.image_height = height,
.image_depth = 1,
.image_array_size = 1,
.image_row_pitch = 0,
.image_slice_pitch = 0,
.num_mip_levels = 0,
.num_samples = 0,
// .buffer = nullptr
};
cid.buffer = nullptr;
cl_int err;
cl_image_ = clCreateImage(
context, CL_MEM_READ_WRITE | (data ? CL_MEM_COPY_HOST_PTR : 0),
&cf, // const cl_image_format *image_format
&cid, // const cl_image_desc *image_desc
data, // void *host_ptr
&err
);
if (err != CL_SUCCESS) {
CL_CHECK_ERRORS(err);
PADDLE_MOBILE_THROW_EXCEPTION(" create image 2d error ");
}
}
void InitCLImage(cl_context context, float *tensor_data, const DDim &dim) {
DLOG << " tensor dim: " << dim;
// NCHW -> [W * (C+3)/4, H * N]
......@@ -160,17 +216,9 @@ class CLImage {
for (int h = 0; h < H; h++) {
size_t i2 = (i1 << 2) + c % 4;
for (int w = 0; w < W; w++) {
if (i2 >= width * height * 4) {
printf("%d > %d ----> %d, %d, %d, %d --- %d, %d, %d\n", i2,
width * height * 4, n, c, h, w, i0, i1, i2);
}
assert(i2 < width * height * 4);
imageData[i2] = Float2Half(*p);
i2 += 4;
p++;
// count++;
// DLOG<<count;
}
i1 += width;
}
......@@ -178,36 +226,7 @@ class CLImage {
i0 += width * H;
}
}
cl_int err;
DLOG << " image width: " << width;
DLOG << " image height: " << height;
cl_image_format cf = {.image_channel_order = CL_RGBA,
.image_channel_data_type = CL_HALF_FLOAT};
cl_image_desc cid = {
.image_type = CL_MEM_OBJECT_IMAGE2D,
.image_width = width,
.image_height = height,
.image_depth = 1,
.image_array_size = 1,
.image_row_pitch = 0,
.image_slice_pitch = 0,
.num_mip_levels = 0,
.num_samples = 0,
// .buffer = nullptr
};
cid.buffer = nullptr;
cl_image_ = clCreateImage(
context, CL_MEM_READ_WRITE | (imageData ? CL_MEM_COPY_HOST_PTR : 0),
&cf, // const cl_image_format *image_format
&cid, // const cl_image_desc *image_desc
reinterpret_cast<void *>(imageData.get()), // void *host_ptr
&err);
if (err != CL_SUCCESS) {
CL_CHECK_ERRORS(err);
PADDLE_MOBILE_THROW_EXCEPTION(" create image 2d error ");
}
InitCLImage(context, width, height, imageData.get());
}
bool initialized_ = false;
......
......@@ -34,8 +34,14 @@ void ReshapeKernel<GPU_CL, float>::Compute(const ReshapeParam<GPU_CL> &param) {
clSetKernelArg(kernel, 1, sizeof(cl_mem), &outputImage);
const auto &inputDim = input->dims();
const auto &outputDim = output->dims();
int dims[4] = {inputDim[0], inputDim[1], inputDim[2], inputDim[3]};
int odims[4] = {outputDim[0], outputDim[1], outputDim[2], outputDim[3]};
int dims[4] = {1, 1, 1, 1};
int odims[4] = {1, 1, 1, 1};
for (int i = 0; i < inputDim.size(); i++) {
dims[4-inputDim.size()+i] = inputDim[i];
}
for (int i = 0; i < outputDim.size(); i++) {
odims[4-outputDim.size()+i] = outputDim[i];
}
clSetKernelArg(kernel, 2, sizeof(int), dims);
clSetKernelArg(kernel, 3, sizeof(int), dims + 1);
clSetKernelArg(kernel, 4, sizeof(int), dims + 2);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册