diff --git a/src/operators/kernel/cl/cl_kernel/fetch_kernel.cl b/src/operators/kernel/cl/cl_kernel/fetch_kernel.cl
index 10f39f9cf9549a6c1a5abe2af905f94f7355220e..8fba62e91f8f60f8d71c486b69a65cd61a192a5a 100644
--- a/src/operators/kernel/cl/cl_kernel/fetch_kernel.cl
+++ b/src/operators/kernel/cl/cl_kernel/fetch_kernel.cl
@@ -2,11 +2,11 @@
 
 __kernel void fetch(__private const int in_height,
                     __private const int in_width,
+                    __read_only image2d_t input,
+                    __global float* out,
                     __private const int size_ch,
                     __private const int size_block,
-                    __private const int size_batch,
-                    __read_only image2d_t input,
-                    __global float* out) {
+                    __private const int size_batch) {
   const int in_c = get_global_id(0);
   const int in_w = get_global_id(1);
   const int in_nh = get_global_id(2);
@@ -25,3 +25,28 @@ __kernel void fetch(__private const int in_height,
   out[index + size_ch * 2] = convert_float(in.z);
   out[index + size_ch * 3] = convert_float(in.w);
 }
+
+// Copies a 2-D image into a flat float buffer, one work-item per pixel.
+// Pixel (in_w, in_h) is widened from half4 to four consecutive floats at
+// out[(in_h * in_width + in_w) * 4 .. + 3]. Work-items are indexed by global
+// ids 1 (column) and 2 (row); global id 0 and in_height are not used.
+__kernel void fetch_2d(__private const int in_height,
+                       __private const int in_width,
+                       __read_only image2d_t input,
+                       __global float* out) {
+  const int in_w = get_global_id(1);
+  const int in_h = get_global_id(2);
+
+  // read_imageh is called with integer coordinates, for which OpenCL requires
+  // an unnormalized sampler; a normalized one yields undefined values.
+  const sampler_t sampler =
+      CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+
+  half4 in = read_imageh(input, sampler, (int2)(in_w, in_h));
+
+  const int index = (in_h * in_width + in_w) * 4;
+  out[index] = convert_float(in.x);
+  out[index + 1] = convert_float(in.y);
+  out[index + 2] = convert_float(in.z);
+  out[index + 3] = convert_float(in.w);
+}
diff --git a/src/operators/kernel/cl/fetch_kernel.cpp b/src/operators/kernel/cl/fetch_kernel.cpp
index 807f8833440529fd6be4198112edc8a9a2223823..23c5e8de4616c0fb014bd8c696c8eeb1b8ea06c1 100644
--- a/src/operators/kernel/cl/fetch_kernel.cpp
+++ b/src/operators/kernel/cl/fetch_kernel.cpp
@@ -20,7 +20,12 @@ namespace operators {
 
 template <>
 bool FetchKernel::Init(FetchParam *param) {
-  this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");
+  // Inputs of rank <= 2 are read back with the fetch_2d kernel.
+  if (param->InputX()->dims().size() <= 2) {
+    this->cl_helper_.AddKernel("fetch_2d", "fetch_kernel.cl");
+  } else {
+    this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");
+  }
   auto *out = param->Out();
   out->mutable_data();
   return true;
@@ -41,16 +46,16 @@ void FetchKernel::Compute(const FetchParam &param) {
     new_dims[4 - dim.size() + j] = dim[j];
   }
 
-  size_t N, C, in_height, in_width;
+  size_t C, in_height, in_width;
 
-  N = new_dims[0];
   C = new_dims[1];
   in_height = new_dims[2];
-  in_width = new_dims[3];
-
-  int size_ch = in_height * in_width;
-  int size_block = size_ch * 4;
-  int size_batch = size_ch * C;
+  // Inputs of rank <= 2 take the row width from the input image itself.
+  if (dim.size() <= 2) {
+    in_width = param.InputX()->ImageWidth();
+  } else {
+    in_width = new_dims[3];
+  }
 
   CLTensor out_cl_tensor(this->cl_helper_.CLContext(),
                          this->cl_helper_.CLCommandQueue());
@@ -59,11 +64,17 @@ void FetchKernel::Compute(const FetchParam &param) {
 
   clSetKernelArg(kernel, 0, sizeof(int), &in_height);
   clSetKernelArg(kernel, 1, sizeof(int), &in_width);
-  clSetKernelArg(kernel, 2, sizeof(int), &size_ch);
-  clSetKernelArg(kernel, 3, sizeof(int), &size_block);
-  clSetKernelArg(kernel, 4, sizeof(int), &size_batch);
-  clSetKernelArg(kernel, 5, sizeof(cl_mem), &input);
-  clSetKernelArg(kernel, 6, sizeof(cl_mem), &outBuffer);
+  clSetKernelArg(kernel, 2, sizeof(cl_mem), &input);
+  clSetKernelArg(kernel, 3, sizeof(cl_mem), &outBuffer);
+  // Only the fetch kernel declares the size_* arguments (indices 4-6).
+  if (dim.size() > 2) {
+    int size_ch = in_height * in_width;
+    int size_block = size_ch * 4;
+    int size_batch = size_ch * C;
+    clSetKernelArg(kernel, 4, sizeof(int), &size_ch);
+    clSetKernelArg(kernel, 5, sizeof(int), &size_block);
+    clSetKernelArg(kernel, 6, sizeof(int), &size_batch);
+  }
 
   clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
                          default_work_size.data(), NULL, 0, NULL, NULL);