提交 57ad25ef 编写于 作者: Z zhaojiaying01

update fetch kernel

上级 4f07b27a
......@@ -2,11 +2,11 @@
__kernel void fetch(__private const int in_height,
__private const int in_width,
__read_only image2d_t input,
__global float* out,
__private const int size_ch,
__private const int size_block,
__private const int size_batch,
__read_only image2d_t input,
__global float* out) {
__private const int size_batch) {
const int in_c = get_global_id(0);
const int in_w = get_global_id(1);
const int in_nh = get_global_id(2);
......@@ -25,3 +25,22 @@ __kernel void fetch(__private const int in_height,
out[index + size_ch * 2] = convert_float(in.z);
out[index + size_ch * 3] = convert_float(in.w);
}
__kernel void fetch_2d(__private const int in_height,
__private const int in_width,
__read_only image2d_t input,
__global float* out) {
const int in_w = get_global_id(1);
const int in_h = get_global_id(2);
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
half4 in = read_imageh(input, sampler, (int2)(in_w, in_h));
const int index = (in_h * in_width + in_w) * 4;
out[index] = convert_float(in.x);
out[index + 1] = convert_float(in.y);
out[index + 2] = convert_float(in.z);
out[index + 3] = convert_float(in.w);
}
......@@ -20,7 +20,11 @@ namespace operators {
template <>
bool FetchKernel<GPU_CL, float>::Init(FetchParam<GPU_CL> *param) {
this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");
if (param->InputX()->dims().size() <= 2) {
this->cl_helper_.AddKernel("fetch_2d", "fetch_kernel.cl");
} else {
this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");
}
auto *out = param->Out();
out->mutable_data<float>();
return true;
......@@ -41,16 +45,15 @@ void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
new_dims[4 - dim.size() + j] = dim[j];
}
size_t N, C, in_height, in_width;
size_t C, in_height, in_width;
N = new_dims[0];
C = new_dims[1];
in_height = new_dims[2];
in_width = new_dims[3];
int size_ch = in_height * in_width;
int size_block = size_ch * 4;
int size_batch = size_ch * C;
if (dim.size() <= 2) {
in_width = param.InputX()->ImageWidth();
} else {
in_width = new_dims[3];
}
CLTensor out_cl_tensor(this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
......@@ -59,11 +62,16 @@ void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
clSetKernelArg(kernel, 0, sizeof(int), &in_height);
clSetKernelArg(kernel, 1, sizeof(int), &in_width);
clSetKernelArg(kernel, 2, sizeof(int), &size_ch);
clSetKernelArg(kernel, 3, sizeof(int), &size_block);
clSetKernelArg(kernel, 4, sizeof(int), &size_batch);
clSetKernelArg(kernel, 5, sizeof(cl_mem), &input);
clSetKernelArg(kernel, 6, sizeof(cl_mem), &outBuffer);
clSetKernelArg(kernel, 2, sizeof(cl_mem), &input);
clSetKernelArg(kernel, 3, sizeof(cl_mem), &outBuffer);
if (dim.size() > 2) {
int size_ch = in_height * in_width;
int size_block = size_ch * 4;
int size_batch = size_ch * C;
clSetKernelArg(kernel, 4, sizeof(int), &size_ch);
clSetKernelArg(kernel, 5, sizeof(int), &size_block);
clSetKernelArg(kernel, 6, sizeof(int), &size_batch);
}
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册