From 8a5cd68ce541e688b3833e051d36057993c7e378 Mon Sep 17 00:00:00 2001 From: yangfei Date: Fri, 23 Nov 2018 16:08:45 +0800 Subject: [PATCH] imp mobilelessd --- src/framework/cl/cl_image.cpp | 92 +++++- src/framework/cl/cl_image.h | 8 +- src/io/paddle_mobile.cpp | 68 ++-- src/operators/kernel/cl/box_coder_kernel.cpp | 48 ++- .../kernel/cl/cl_kernel/box_coder_kernel.cl | 147 +++++++++ .../kernel/cl/cl_kernel/feed_kernel.cl | 68 ++-- .../kernel/cl/cl_kernel/prior_box_kernel.cl | 2 + src/operators/kernel/cl/cl_kernel/softmax.cl | 47 ++- .../kernel/cl/cl_kernel/transpose_kernel.cl | 42 ++- src/operators/kernel/cl/feed_kernel.cpp | 34 +- src/operators/kernel/cl/fetch_kernel.cpp | 1 + .../kernel/cl/multiclass_nms_kernel.cpp | 310 +++++++++++++++++- src/operators/kernel/cl/softmax_kernel.cpp | 31 +- src/operators/kernel/cl/transpose_kernel.cpp | 65 ++++ src/operators/op_param.h | 4 +- 15 files changed, 858 insertions(+), 109 deletions(-) create mode 100644 src/operators/kernel/cl/cl_kernel/box_coder_kernel.cl diff --git a/src/framework/cl/cl_image.cpp b/src/framework/cl/cl_image.cpp index f32de0a614..d6cc52d69c 100644 --- a/src/framework/cl/cl_image.cpp +++ b/src/framework/cl/cl_image.cpp @@ -13,18 +13,98 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "framework/cl/cl_image.h" +#include "framework/cl/cl_tensor.h" namespace paddle_mobile { namespace framework { -void CLImageToTensor(CLImage *cl_image, Tensor *tensor, - cl_command_queue commandQueue) { - // TODO(yangfei): need imp +void CLImageToTensor(CLImage *cl_image, Tensor *tensor, cl_context context, + cl_command_queue commandQueue, cl_kernel kernel) { + tensor->mutable_data(); + const auto &dim = cl_image->dims(); + size_t new_dims[] = {1, 1, 1, 1}; + for (int j = 0; j < dim.size(); ++j) { + new_dims[4 - dim.size() + j] = dim[j]; + } + size_t C, in_height, in_width; + + C = new_dims[1]; + in_height = new_dims[2]; + in_width = new_dims[3]; + + CLTensor out_cl_tensor(context, commandQueue); + out_cl_tensor.Resize(tensor->dims()); + cl_mem outBuffer = out_cl_tensor.mutable_data(); + + auto input_image = cl_image->GetCLImage(); + + clSetKernelArg(kernel, 0, sizeof(int), &in_height); + clSetKernelArg(kernel, 1, sizeof(int), &in_width); + clSetKernelArg(kernel, 2, sizeof(cl_mem), &input_image); + clSetKernelArg(kernel, 3, sizeof(cl_mem), &outBuffer); + int size_ch = in_height * in_width; + int size_block = size_ch * 4; + int size_batch = size_ch * C; + clSetKernelArg(kernel, 4, sizeof(int), &size_ch); + clSetKernelArg(kernel, 5, sizeof(int), &size_block); + clSetKernelArg(kernel, 6, sizeof(int), &size_batch); + clSetKernelArg(kernel, 7, sizeof(int), &C); + size_t global_work_size[3] = {(new_dims[1] + 3) / 4, new_dims[3], + new_dims[0] * new_dims[2]}; + clEnqueueNDRangeKernel(commandQueue, kernel, 3, NULL, global_work_size, NULL, + 0, NULL, NULL); + memcpy(tensor->data(), out_cl_tensor.Data(), + tensor->memory_size()); } -void TensorToCLImage(const Tensor *tensor, CLImage *cl_image, - cl_command_queue commandQueue) { - // TODO(yangfei): need imp +void TensorToCLImage(Tensor *tensor, CLImage *cl_image, cl_context context, + cl_command_queue commandQueue, cl_kernel kernel) { + const auto &dim = cl_image->dims(); + size_t new_dims[] = {1, 1, 1, 1}; + for (int j = 0; j < dim.size(); ++j) { + new_dims[4 - dim.size() + j] = dim[j]; + } + cl_int status; + auto output = cl_image; + const Tensor *input = tensor; + const float *input_data = input->data(); + auto output_image = output->GetCLImage(); + const int out_C = new_dims[1]; + const int out_H = new_dims[2]; + const int out_W = new_dims[3]; + const int Stride2 = out_C * out_H * out_W; + const int Stride1 = out_H * out_W; + const int Stride0 = out_W; + DLOG << out_C; + DLOG << out_H; + DLOG << out_W; + CLTensor input_cl_tensor(context, commandQueue); + input_cl_tensor.Resize(input->dims()); + cl_mem inputBuffer = input_cl_tensor.mutable_with_data(input_data); + + status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputBuffer); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &output_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 2, sizeof(cl_int), &out_H); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 3, sizeof(cl_int), &out_W); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 4, sizeof(cl_int), &out_C); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 5, sizeof(cl_int), &Stride0); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 6, sizeof(cl_int), &Stride1); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 7, sizeof(cl_int), &Stride2); + CL_CHECK_ERRORS(status); + + size_t global_work_size[3] = {(new_dims[1] + 3) / 4, new_dims[3], + new_dims[0] * new_dims[2]}; + status = clEnqueueNDRangeKernel(commandQueue, kernel, 3, NULL, + global_work_size, NULL, 0, NULL, NULL); + + CL_CHECK_ERRORS(status); } #ifdef PADDLE_MOBILE_DEBUG diff --git a/src/framework/cl/cl_image.h b/src/framework/cl/cl_image.h index 0c19661ede..f94eba187f 100644 --- a/src/framework/cl/cl_image.h +++ b/src/framework/cl/cl_image.h @@ -222,11 +222,11 @@ class CLImage { CLImageConverterBase *image_converter_ = nullptr; }; -void TensorToCLImage(Tensor *tensor, CLImage *image, - cl_command_queue commandQueue); +void TensorToCLImage(Tensor *tensor, CLImage *image, cl_context context, + cl_command_queue commandQueue, cl_kernel kernel); -void CLImageToTensor(CLImage *image, Tensor *tensor, - cl_command_queue commandQueue); +void CLImageToTensor(CLImage *image, Tensor *tensor, cl_context context, + cl_command_queue commandQueue, cl_kernel kernel); #ifdef PADDLE_MOBILE_DEBUG Print &operator<<(Print &printer, const CLImage &image); diff --git a/src/io/paddle_mobile.cpp b/src/io/paddle_mobile.cpp index 6a773da00f..4b50f15a86 100644 --- a/src/io/paddle_mobile.cpp +++ b/src/io/paddle_mobile.cpp @@ -143,10 +143,12 @@ double PaddleMobile::GetPredictTime() { int t1 = 1; int t2 = 1; for (int i = 0; i < m * k; ++i) { - a[i] = t1 + rand() % t2; + unsigned int seed = 100; + a[i] = t1 + rand_r(&seed) % t2; } for (int i = 0; i < k * n; ++i) { - b[i] = t1 + rand() % t2; + unsigned int seed = 200; + b[i] = t1 + rand_r(&seed) % t2; } paddle_mobile::operators::math::Gemm gemm; auto time1 = paddle_mobile::time(); @@ -215,13 +217,13 @@ double PaddleMobile::GetPredictTime() { cl_int status; cl_uint nPlatform; clGetPlatformIDs(0, NULL, &nPlatform); - cl_platform_id *listPlatform = - (cl_platform_id *)malloc(nPlatform * sizeof(cl_platform_id)); + cl_platform_id *listPlatform = reinterpret_cast( + malloc(nPlatform * sizeof(cl_platform_id))); clGetPlatformIDs(nPlatform, listPlatform, NULL); cl_uint nDevice = 0; clGetDeviceIDs(listPlatform[0], CL_DEVICE_TYPE_GPU, 0, NULL, &nDevice); cl_device_id *listDevice = - (cl_device_id *)malloc(nDevice * sizeof(cl_device_id)); + reinterpret_cast(malloc(nDevice * sizeof(cl_device_id))); clGetDeviceIDs(listPlatform[0], CL_DEVICE_TYPE_GPU, nDevice, listDevice, NULL); cl_context context = @@ -277,41 +279,66 @@ double PaddleMobile::GetPredictTime() { clBuildProgram(program, 0, 0, path1.c_str(), NULL, NULL); cl_kernel kernel = clCreateKernel(program, "feed", &status); + int out_H = 224; + int out_W = 224; + int out_C = 3; + int Stride2 = out_C * out_H * out_W; + int Stride1 = out_H * out_W; + int Stride0 = out_W; status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputBuffer); CL_CHECK_ERRORS(status); status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_input_image); CL_CHECK_ERRORS(status); - status = clSetKernelArg(kernel, 2, sizeof(cl_int), &input_w); + status = clSetKernelArg(kernel, 2, sizeof(cl_int), &out_H); CL_CHECK_ERRORS(status); - status = clSetKernelArg(kernel, 3, sizeof(cl_int), &input_h); + status = clSetKernelArg(kernel, 3, sizeof(cl_int), &out_W); CL_CHECK_ERRORS(status); - status = clSetKernelArg(kernel, 4, sizeof(cl_int), &c); + status = clSetKernelArg(kernel, 4, sizeof(cl_int), &out_C); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 5, sizeof(cl_int), &Stride0); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 6, sizeof(cl_int), &Stride1); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 7, sizeof(cl_int), &Stride2); CL_CHECK_ERRORS(status); - size_t global_work_size[2] = {input_w, input_h}; + size_t global_work_size[3] = {1, 224, 224}; // cl_event out_event = param.Out()->GetClEvent(); - status = clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_work_size, + status = clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, NULL, 0, NULL, NULL); CL_CHECK_ERRORS(status); + out_H = 3; + out_W = 3; + out_C = 3; + Stride2 = out_C * out_H * out_W; + Stride1 = out_H * out_W; + Stride0 = out_W; + status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &filterBuffer); CL_CHECK_ERRORS(status); status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_filter_image); CL_CHECK_ERRORS(status); - status = clSetKernelArg(kernel, 2, sizeof(cl_int), &filter_w); + status = clSetKernelArg(kernel, 2, sizeof(cl_int), &out_H); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 3, sizeof(cl_int), &out_W); CL_CHECK_ERRORS(status); - status = clSetKernelArg(kernel, 3, sizeof(cl_int), &filter_h); + status = clSetKernelArg(kernel, 4, sizeof(cl_int), &out_C); CL_CHECK_ERRORS(status); - status = clSetKernelArg(kernel, 4, sizeof(cl_int), &c); + status = clSetKernelArg(kernel, 5, sizeof(cl_int), &Stride0); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 6, sizeof(cl_int), &Stride1); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 7, sizeof(cl_int), &Stride2); CL_CHECK_ERRORS(status); - size_t global_work_size1[2] = {filter_w, filter_h}; + size_t global_work_size1[3] = {1, 3, 96}; // cl_event out_event = param.Out()->GetClEvent(); - status = clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_work_size1, + status = clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size1, NULL, 0, NULL, NULL); CL_CHECK_ERRORS(status); @@ -378,13 +405,16 @@ double PaddleMobile::GetPredictTime() { auto time2 = paddle_mobile::time(); paddle_mobile::memory::Free(input); paddle_mobile::memory::Free(filter); - return paddle_mobile::time_diff(time1, time2); + if (status == CL_SUCCESS) { + return paddle_mobile::time_diff(time1, time2); + } else { + return -1; + } } template int PaddleMobile::readText( const char *kernelPath, - char **pcode) // 读取文本文件放入 pcode,返回字符串长度 -{ + char **pcode) { // 读取文本文件放入 pcode,返回字符串长度 FILE *fp; int size; // printf(" File: %s\n", kernelPath); @@ -402,7 +432,7 @@ int PaddleMobile::readText( return -1; } rewind(fp); - if ((*pcode = (char *)malloc(size + 1)) == NULL) { + if ((*pcode = reinterpret_cast(malloc(size + 1))) == NULL) { printf(" Allocate space failed\n"); return -1; } diff --git a/src/operators/kernel/cl/box_coder_kernel.cpp b/src/operators/kernel/cl/box_coder_kernel.cpp index 582f6131bf..b98435f9b0 100644 --- a/src/operators/kernel/cl/box_coder_kernel.cpp +++ b/src/operators/kernel/cl/box_coder_kernel.cpp @@ -20,13 +20,57 @@ namespace paddle_mobile { namespace operators { template <> -bool BoxCoderKernel::Init(BoxCoderParam *param) { +bool BoxCoderKernel::Init(BoxCoderParam* param) { + if (param->CodeType() == "decode_center_size") { + this->cl_helper_.AddKernel("box_decoder", "box_coder_kernel.cl"); + } return true; } template <> void BoxCoderKernel::Compute( - const BoxCoderParam ¶m) {} + const BoxCoderParam& param) { + auto kernel = this->cl_helper_.KernelAt(0); + auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.OutputBox()); + const auto* input_priorbox = param.InputPriorBox(); + const auto* input_priorboxvar = param.InputPriorBoxVar(); + const auto* input_targetbox = param.InputTargetBox(); + const auto& code_type = param.CodeType(); + if (code_type == "decode_center_size") { + auto prior_box_image = input_priorbox->GetCLImage(); + auto prior_box_var_image = input_priorboxvar->GetCLImage(); + auto target_box_image = input_targetbox->GetCLImage(); + auto output_image = param.OutputBox()->GetCLImage(); + auto& outputDim = param.OutputBox()->dims(); + int new_dims[4] = {1, 1, 1, 1}; + for (int i = 0; i < outputDim.size(); i++) { + new_dims[4 - outputDim.size() + i] = outputDim[i]; + } + int out_C = new_dims[1]; + int out_H = new_dims[2]; + DLOG << "out_C=" << out_C; + DLOG << "out_H=" << out_H; + DLOG << "default_work_size=" << default_work_size; + cl_int status; + status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &prior_box_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &prior_box_var_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 2, sizeof(cl_mem), &target_box_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &output_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 4, sizeof(int), &out_C); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 5, sizeof(int), &out_H); + CL_CHECK_ERRORS(status); + size_t global_work_size[2] = {default_work_size[0], default_work_size[2]}; + status = + clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2, + NULL, global_work_size, NULL, 0, NULL, NULL); + CL_CHECK_ERRORS(status); + } +} } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/cl/cl_kernel/box_coder_kernel.cl b/src/operators/kernel/cl/cl_kernel/box_coder_kernel.cl new file mode 100644 index 0000000000..60000c994e --- /dev/null +++ b/src/operators/kernel/cl/cl_kernel/box_coder_kernel.cl @@ -0,0 +1,147 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +__kernel void box_decoder(__read_only image2d_t prior_box_image, + __read_only image2d_t prior_box_var_image, + __read_only image2d_t target_box_image, + __write_only image2d_t output_image, + __private const int out_C, + __private const int out_H + ){ + const int out_c = get_global_id(0); + const int out_nh = get_global_id(1); + const int out_h = out_nh%out_H; + const int out_n = 1; + + const int prior_box_n = 1; + const int prior_box_c = 0; + const int prior_box_h = out_h; + + + const int prior_box_var_n = 1; + const int prior_box_var_c = 0; + const int prior_box_var_h = out_h; + + const int target_box_n = 1; + const int target_box_c = out_c; + const int target_box_h = out_h; + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + int2 prior_box_pos; + int2 prior_box_var_pos; + int2 target_box_pos; + int2 output_pos; + + prior_box_pos.x = prior_box_c * 4; + prior_box_pos.y = prior_box_n * prior_box_h; + + prior_box_var_pos.x = prior_box_var_c * 4; + prior_box_var_pos.y = prior_box_var_n * prior_box_var_h; + + target_box_pos.x = target_box_c * 4; + target_box_pos.y = target_box_n * target_box_h; + + output_pos.x = out_c * 4; + output_pos.y = out_n * out_h; + + half4 prior_box_input[4]; + half4 prior_box_var_input[4]; + half4 target_box_input[4]; + + prior_box_input[0] = read_imageh(prior_box_image, sampler,(int2)(prior_box_pos.x + 0,prior_box_pos.y)); + prior_box_input[1] = read_imageh(prior_box_image, sampler,(int2)(prior_box_pos.x + 1,prior_box_pos.y)); + prior_box_input[2] = read_imageh(prior_box_image, sampler,(int2)(prior_box_pos.x + 2,prior_box_pos.y)); + prior_box_input[3] = read_imageh(prior_box_image, sampler,(int2)(prior_box_pos.x + 3,prior_box_pos.y)); + + prior_box_var_input[0] = read_imageh(prior_box_var_image, sampler,(int2)(prior_box_var_pos.x + 0,prior_box_var_pos.y)); + prior_box_var_input[1] = read_imageh(prior_box_var_image, sampler,(int2)(prior_box_var_pos.x + 1,prior_box_var_pos.y)); + prior_box_var_input[2] = read_imageh(prior_box_var_image, sampler,(int2)(prior_box_var_pos.x + 2,prior_box_var_pos.y)); + prior_box_var_input[3] = read_imageh(prior_box_var_image, sampler,(int2)(prior_box_var_pos.x + 3,prior_box_var_pos.y)); + + + + target_box_input[0] = read_imageh(target_box_image, sampler,(int2)(target_box_pos.x + 0,target_box_pos.y)); + target_box_input[1] = read_imageh(target_box_image, sampler,(int2)(target_box_pos.x + 1,target_box_pos.y)); + target_box_input[2] = read_imageh(target_box_image, sampler,(int2)(target_box_pos.x + 2,target_box_pos.y)); + target_box_input[3] = read_imageh(target_box_image, sampler,(int2)(target_box_pos.x + 3,target_box_pos.y)); + + half prior_box_width = prior_box_input[2].x - prior_box_input[0].x; + half prior_box_height = prior_box_input[3].x - prior_box_input[1].x; + half prior_box_center_x = (prior_box_input[2].x + prior_box_input[0].x)/(half)2; + half prior_box_center_y = (prior_box_input[3].x + prior_box_input[1].x)/(half)2; + + half4 target_box_center_x; + half4 target_box_center_y; + half4 target_box_width; + half4 target_box_height; + half4 output[4]; + + output[0] = 0.0f; + output[1] = 0.0f; + output[2] = 0.0f; + output[3] = 0.0f; + + target_box_center_x.x = prior_box_var_input[0].x * target_box_input[0].x * prior_box_width + prior_box_center_x; + target_box_center_y.x = prior_box_var_input[1].x * target_box_input[1].x * prior_box_height + prior_box_center_y; + target_box_width.x = exp(prior_box_var_input[2].x * target_box_input[2].x) * prior_box_width; + target_box_height.x = exp(prior_box_var_input[3].x * target_box_input[3].x) * prior_box_height; + + output[0].x = target_box_center_x.x - target_box_width.x/(half)2; + output[1].x = target_box_center_y.x - target_box_height.x/(half)2; + output[2].x = target_box_center_x.x + target_box_width.x/(half)2; + output[3].x = target_box_center_y.x + target_box_height.x/(half)2; + + if(out_C - out_c * 4 >= 2){ + target_box_center_x.y = prior_box_var_input[0].x * target_box_input[0].y * prior_box_width + prior_box_center_x; + target_box_center_y.y = prior_box_var_input[1].x * target_box_input[1].y * prior_box_height + prior_box_center_y; + target_box_width.y = exp(prior_box_var_input[2].x * target_box_input[2].y) * prior_box_width; + target_box_height.y = exp(prior_box_var_input[3].x * target_box_input[3].y) * prior_box_height; + output[0].y = target_box_center_x.y - target_box_width.y/(half)2; + output[1].y = target_box_center_y.y - target_box_height.y/(half)2; + output[2].y = target_box_center_x.y + target_box_width.y/(half)2; + output[3].y = target_box_center_y.y + target_box_height.y/(half)2; + + } + if(out_C - out_c * 4 >= 3){ + target_box_center_x.z = prior_box_var_input[0].x * target_box_input[0].z * prior_box_width + prior_box_center_x; + target_box_center_y.z = prior_box_var_input[1].x * target_box_input[1].z * prior_box_height + prior_box_center_y; + target_box_width.z = exp(prior_box_var_input[2].x * target_box_input[2].z) * prior_box_width; + target_box_height.z = exp(prior_box_var_input[3].x * target_box_input[3].z) * prior_box_height; + output[0].z = target_box_center_x.z - target_box_width.z/(half)2; + output[1].z = target_box_center_y.z - target_box_height.z/(half)2; + output[2].z = target_box_center_x.z + target_box_width.z/(half)2; + output[3].z = target_box_center_y.z + target_box_height.z/(half)2; + } + if(out_C - out_c * 4 >= 4){ + target_box_center_x.w = prior_box_var_input[0].x * target_box_input[0].w * prior_box_width + prior_box_center_x; + target_box_center_y.w = prior_box_var_input[1].x * target_box_input[1].w * prior_box_height + prior_box_center_y; + target_box_width.w = exp(prior_box_var_input[2].x * target_box_input[2].w) * prior_box_width; + target_box_height.w = exp(prior_box_var_input[3].x * target_box_input[3].w) * prior_box_height; + output[0].w = target_box_center_x.w - target_box_width.w/(half)2; + output[1].w = target_box_center_y.w - target_box_height.w/(half)2; + output[2].w = target_box_center_x.w + target_box_width.w/(half)2; + output[3].w = target_box_center_y.w + target_box_height.w/(half)2; + } + + + write_imageh(output_image, (int2)(output_pos.x + 0, output_pos.y), output[0]); + write_imageh(output_image, (int2)(output_pos.x + 1, output_pos.y), output[1]); + write_imageh(output_image, (int2)(output_pos.x + 2, output_pos.y), output[2]); + write_imageh(output_image, (int2)(output_pos.x + 3, output_pos.y), output[3]); + +} \ No newline at end of file diff --git a/src/operators/kernel/cl/cl_kernel/feed_kernel.cl b/src/operators/kernel/cl/cl_kernel/feed_kernel.cl index 200a221c9b..bb661f3cf7 100644 --- a/src/operators/kernel/cl/cl_kernel/feed_kernel.cl +++ b/src/operators/kernel/cl/cl_kernel/feed_kernel.cl @@ -13,26 +13,50 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma OPENCL EXTENSION cl_khr_fp16 : enable -__kernel void feed(__global float *in, __write_only image2d_t outputImage,int h,int w,int c) - { - int i = get_global_id(0); - int j = get_global_id(1); - half4 pixel; - pixel.x = convert_half(in[(i * w + j)]); - if(c>=2){ - pixel.y = convert_half(in[h * w + (i * w + j)]); - }else{ - pixel.y = 0.0; - } - if(c>=3){ - pixel.z = convert_half(in[2 * h * w + (i * w + j)]); - }else{ - pixel.z = 0.0; - } - pixel.w = 0.0; - int2 coords; - coords.x = j; - coords.y = i; - - write_imageh(outputImage,coords,pixel); +__kernel void feed(__global float *in, + __write_only image2d_t output_image, + __private const int out_H, + __private const int out_W, + __private const int out_C, + __private const int Stride0, + __private const int Stride1, + __private const int Stride2){ + + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + const int out_n = out_nh/out_H; + const int out_h = out_nh%out_H; + + const int in_n = out_n; + const int in_c0 = out_c * 4 + 0; + const int in_c1 = out_c * 4 + 1; + const int in_c2 = out_c * 4 + 2; + const int in_c3 = out_c * 4 + 3; + const int in_h = out_h; + const int in_w = out_w; + + + int input_pos0 = in_n * Stride2 + in_c0 * Stride1 + in_h * Stride0 + in_w; + int input_pos1 = in_n * Stride2 + in_c1 * Stride1 + in_h * Stride0 + in_w; + int input_pos2 = in_n * Stride2 + in_c2 * Stride1 + in_h * Stride0 + in_w; + int input_pos3 = in_n * Stride2 + in_c3 * Stride1 + in_h * Stride0 + in_w; + + int2 output_pos; + output_pos.x = out_c * out_W + out_w; + output_pos.y = out_nh; + + half4 output = (half4)0.0f; + output.x = convert_half(in[input_pos0]); + if(out_C - 4 * out_c>=2){ + output.y = convert_half(in[input_pos1]); + } + if(out_C - 4 * out_c>=3){ + output.z = convert_half(in[input_pos2]); + } + if(out_C - 4 * out_c>=4){ + output.w = convert_half(in[input_pos3]); + } + write_imageh(output_image, output_pos, output); + } diff --git a/src/operators/kernel/cl/cl_kernel/prior_box_kernel.cl b/src/operators/kernel/cl/cl_kernel/prior_box_kernel.cl index 699d381ce6..886f62df68 100644 --- a/src/operators/kernel/cl/cl_kernel/prior_box_kernel.cl +++ b/src/operators/kernel/cl/cl_kernel/prior_box_kernel.cl @@ -107,11 +107,13 @@ __kernel void prior_box(__private const int global_size_dim0, output[2] = min(max((half4)(0.0f, 0.0f, 0.0f, 0.0f), output[2]),(half4)(1.0f, 1.0f, 1.0f, 1.0f)); output[3] = min(max((half4)(0.0f, 0.0f, 0.0f, 0.0f), output[3]),(half4)(1.0f, 1.0f, 1.0f, 1.0f)); } + /* if(output_pos.x == 0 && output_pos.y == 1){ float4 out = (float4)(output[0].x, output[1].x, output[2].x, output[3].x); printf("output = %v4hlf \n", out); } + */ write_imageh(output_boxes, (int2)(output_pos.x + 0, output_pos.y), output[0]); write_imageh(output_boxes, (int2)(output_pos.x + 1, output_pos.y), output[1]); diff --git a/src/operators/kernel/cl/cl_kernel/softmax.cl b/src/operators/kernel/cl/cl_kernel/softmax.cl index 215ec69fc2..a1fa014e00 100644 --- a/src/operators/kernel/cl/cl_kernel/softmax.cl +++ b/src/operators/kernel/cl/cl_kernel/softmax.cl @@ -16,35 +16,46 @@ limitations under the License. */ __kernel void softmax(__read_only image2d_t input_image, __write_only image2d_t output_image, - __private const int group + __private const int out_W ) { const int out_c = get_global_id(0); // block index const int out_w = get_global_id(1); // index in one block const int out_nh = get_global_id(2); + const int in_c = out_c; + const int in_w = out_w; + const int in_nh = out_nh; - const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | - CLK_ADDRESS_CLAMP | - CLK_FILTER_NEAREST; + int2 input_pos; + int2 output_pos; - half maxv = 0.0f; - for (int i = 0; i < group; ++i) { - half4 temp = read_imageh(input_image, sampler, (int2)(i, 0)); - maxv = max(maxv, max(temp.x, max(temp.y, max(temp.z, temp.w)))); - } + input_pos.x = in_c * out_W + in_w; + input_pos.y = in_nh; + output_pos.x = out_c * out_W + out_w; + output_pos.y = out_nh; - half4 rsum = (half4)(0.0f); - for (int i = 0; i < group; ++i) { - half4 r = read_imageh(input_image, sampler, (int2)(i, 0)); - rsum += convert_half4(exp(convert_float4(r - maxv))); - } + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + + half4 input_max = 0.0f; + half4 input_tmp; + for(int i=0;i::Init(FeedParam *param) { template <> void FeedKernel::Compute(const FeedParam ¶m) { auto kernel = this->cl_helper_.KernelAt(0); + auto default_work_size = this->cl_helper_.DefaultWorkSize(*(param.Out())); cl_int status; param.Out()->InitEmptyImage(cl_helper_.CLContext(), cl_helper_.CLCommandQueue(), param.Out()->dims()); @@ -35,10 +36,13 @@ void FeedKernel::Compute(const FeedParam ¶m) { // DLOG << *input; const float *input_data = input->data(); int numel = input->numel(); - cl_mem cl_image = output->GetCLImage(); - int c = input->dims()[1]; - int height = output->dims()[2]; - int width = output->dims()[3]; + cl_mem output_image = output->GetCLImage(); + const int out_C = output->dims()[1]; + const int out_H = output->dims()[2]; + const int out_W = output->dims()[3]; + const int Stride2 = out_C * out_H * out_W; + const int Stride1 = out_H * out_W; + const int Stride0 = out_W; CLTensor input_cl_tensor(this->cl_helper_.CLContext(), this->cl_helper_.CLCommandQueue()); input_cl_tensor.Resize(input->dims()); @@ -46,21 +50,25 @@ void FeedKernel::Compute(const FeedParam ¶m) { status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputBuffer); CL_CHECK_ERRORS(status); - status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &cl_image); + status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &output_image); CL_CHECK_ERRORS(status); - status = clSetKernelArg(kernel, 2, sizeof(cl_int), &width); + status = clSetKernelArg(kernel, 2, sizeof(cl_int), &out_H); CL_CHECK_ERRORS(status); - status = clSetKernelArg(kernel, 3, sizeof(cl_int), &height); + status = clSetKernelArg(kernel, 3, sizeof(cl_int), &out_W); CL_CHECK_ERRORS(status); - status = clSetKernelArg(kernel, 4, sizeof(cl_int), &c); + status = clSetKernelArg(kernel, 4, sizeof(cl_int), &out_C); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 5, sizeof(cl_int), &Stride0); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 6, sizeof(cl_int), &Stride1); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 7, sizeof(cl_int), &Stride2); CL_CHECK_ERRORS(status); - size_t global_work_size[2] = {width, height}; - - // cl_event out_event = param.Out()->GetClEvent(); + status = clEnqueueNDRangeKernel( + this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL, + default_work_size.data(), NULL, 0, NULL, NULL); - status = clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2, - NULL, global_work_size, NULL, 0, NULL, NULL); CL_CHECK_ERRORS(status); } diff --git a/src/operators/kernel/cl/fetch_kernel.cpp b/src/operators/kernel/cl/fetch_kernel.cpp index ded90ff43f..e13fbcaed6 100644 --- a/src/operators/kernel/cl/fetch_kernel.cpp +++ b/src/operators/kernel/cl/fetch_kernel.cpp @@ -37,6 +37,7 @@ void FetchKernel::Compute(const FetchParam ¶m) { auto input = param.InputX()->GetCLImage(); auto *out = param.Out(); + out->Resize(param.InputX()->dims()); out->mutable_data(); const auto &dim = param.InputX()->dims(); size_t new_dims[] = {1, 1, 1, 1}; diff --git a/src/operators/kernel/cl/multiclass_nms_kernel.cpp b/src/operators/kernel/cl/multiclass_nms_kernel.cpp index e7bf02cde4..31ccdc0df5 100644 --- a/src/operators/kernel/cl/multiclass_nms_kernel.cpp +++ b/src/operators/kernel/cl/multiclass_nms_kernel.cpp @@ -15,19 +15,323 @@ limitations under the License. */ #ifdef MULTICLASSNMS_OP #include "operators/kernel/multiclass_nms_kernel.h" - +#include "operators/math/poly_util.h" namespace paddle_mobile { namespace operators { template <> bool MultiClassNMSKernel::Init( - MultiClassNMSParam *param) { + MultiClassNMSParam* param) { + this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl"); + this->cl_helper_.AddKernel("feed", "feed_kernel.cl"); return true; } +template +bool SortScorePairDescend(const std::pair& pair1, + const std::pair& pair2) { + return pair1.first > pair2.first; +} + +template +static inline void GetMaxScoreIndex( + const std::vector& scores, const T threshold, int top_k, + std::vector>* sorted_indices) { + for (size_t i = 0; i < scores.size(); ++i) { + if (scores[i] > threshold) { + sorted_indices->push_back(std::make_pair(scores[i], i)); + } + } + // Sort the score pair according to the scores in descending order + std::stable_sort(sorted_indices->begin(), sorted_indices->end(), + SortScorePairDescend); + // Keep top_k scores if needed. + if (top_k > -1 && top_k < static_cast(sorted_indices->size())) { + sorted_indices->resize(top_k); + } +} + +template +static inline T BBoxArea(const T* box, const bool normalized) { + if (box[2] < box[0] || box[3] < box[1]) { + // If coordinate values are is invalid + // (e.g. xmax < xmin or ymax < ymin), return 0. + return static_cast(0.); + } else { + const T w = box[2] - box[0]; + const T h = box[3] - box[1]; + if (normalized) { + return w * h; + } else { + // If coordinate values are not within range [0, 1]. + return (w + 1) * (h + 1); + } + } +} + +template +static inline T JaccardOverlap(const T* box1, const T* box2, + const bool normalized) { + if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] || + box2[3] < box1[1]) { + return static_cast(0.); + } else { + const T inter_xmin = std::max(box1[0], box2[0]); + const T inter_ymin = std::max(box1[1], box2[1]); + const T inter_xmax = std::min(box1[2], box2[2]); + const T inter_ymax = std::min(box1[3], box2[3]); + const T inter_w = inter_xmax - inter_xmin; + const T inter_h = inter_ymax - inter_ymin; + const T inter_area = inter_w * inter_h; + const T bbox1_area = BBoxArea(box1, normalized); + const T bbox2_area = BBoxArea(box2, normalized); + return inter_area / (bbox1_area + bbox2_area - inter_area); + } +} + +template +static inline T PolyIoU(const T* box1, const T* box2, const size_t box_size, + const bool normalized) { + T bbox1_area = math::PolyArea(box1, box_size, normalized); + T bbox2_area = math::PolyArea(box2, box_size, normalized); + T inter_area = math::PolyOverlapArea(box1, box2, box_size, normalized); + if (bbox1_area == 0 || bbox2_area == 0 || inter_area == 0) { + // If coordinate values are is invalid + // if area size <= 0, return 0. + return static_cast(0.); + } else { + return inter_area / (bbox1_area + bbox2_area - inter_area); + } +} + +template +static inline void NMSFast(const framework::Tensor& bbox, + const framework::Tensor& scores, + const T score_threshold, const T nms_threshold, + const T eta, const int64_t top_k, + std::vector* selected_indices) { + // The total boxes for each instance. + int64_t num_boxes = bbox.dims()[0]; + // 4: [xmin ymin xmax ymax] + int64_t box_size = bbox.dims()[1]; + + std::vector scores_data(num_boxes); + std::copy_n(scores.data(), num_boxes, scores_data.begin()); + std::vector> sorted_indices; + GetMaxScoreIndex(scores_data, score_threshold, top_k, &sorted_indices); + + selected_indices->clear(); + T adaptive_threshold = nms_threshold; + const T* bbox_data = bbox.data(); + + while (sorted_indices.size() != 0) { + const int idx = sorted_indices.front().second; + bool keep = true; + for (size_t k = 0; k < selected_indices->size(); ++k) { + if (keep) { + const int kept_idx = (*selected_indices)[k]; + T overlap = T(0.); + if (box_size == 4) { + overlap = JaccardOverlap(bbox_data + idx * box_size, + bbox_data + kept_idx * box_size, true); + } else { + overlap = PolyIoU(bbox_data + idx * box_size, + bbox_data + kept_idx * box_size, box_size, true); + } + keep = overlap <= adaptive_threshold; + } else { + break; + } + } + if (keep) { + selected_indices->push_back(idx); + } + sorted_indices.erase(sorted_indices.begin()); + if (keep && eta < 1 && adaptive_threshold > 0.5) { + adaptive_threshold *= eta; + } + } +} + +template +void MultiClassNMS(const framework::Tensor& scores, + const framework::Tensor& bboxes, + std::map>* indices, int* num_nmsed_out, + const int& background_label, const int& nms_top_k, + const int& keep_top_k, const T& nms_threshold, + const T& nms_eta, const T& score_threshold) { + int64_t class_num = scores.dims()[0]; + int64_t predict_dim = scores.dims()[1]; + int num_det = 0; + for (int64_t c = 0; c < class_num; ++c) { + if (c == background_label) continue; + framework::Tensor score = scores.Slice(c, c + 1); + /// [c] is key + NMSFast(bboxes, score, score_threshold, nms_threshold, nms_eta, + nms_top_k, &((*indices)[c])); + num_det += (*indices)[c].size(); + } + + *num_nmsed_out = num_det; + const T* scores_data = scores.data(); + if (keep_top_k > -1 && num_det > keep_top_k) { + std::vector>> score_index_pairs; + for (const auto& it : *indices) { + int label = it.first; + const T* sdata = scores_data + label * predict_dim; + const std::vector& label_indices = it.second; + for (size_t j = 0; j < label_indices.size(); ++j) { + int idx = label_indices[j]; + // PADDLE_ENFORCE_LT(idx, predict_dim); + score_index_pairs.push_back( + std::make_pair(sdata[idx], std::make_pair(label, idx))); + } + } + // Keep top k results per image. + std::stable_sort(score_index_pairs.begin(), score_index_pairs.end(), + SortScorePairDescend>); + score_index_pairs.resize(keep_top_k); + + // Store the new indices. + std::map> new_indices; + for (size_t j = 0; j < score_index_pairs.size(); ++j) { + int label = score_index_pairs[j].second.first; + int idx = score_index_pairs[j].second.second; + new_indices[label].push_back(idx); + } + new_indices.swap(*indices); + *num_nmsed_out = keep_top_k; + } +} + +template +void MultiClassOutput(const framework::Tensor& scores, + const framework::Tensor& bboxes, + const std::map>& selected_indices, + framework::Tensor* outs) { + int predict_dim = scores.dims()[1]; + int box_size = bboxes.dims()[1]; + int out_dim = bboxes.dims()[1] + 2; + auto* scores_data = scores.data(); + auto* bboxes_data = bboxes.data(); + auto* odata = outs->data(); + + int count = 0; + for (const auto& it : selected_indices) { + /// one batch + int label = it.first; + const T* sdata = scores_data + label * predict_dim; + const std::vector& indices = it.second; + for (size_t j = 0; j < indices.size(); ++j) { + int idx = indices[j]; + const T* bdata = bboxes_data + idx * box_size; + odata[count * out_dim] = label; // label + odata[count * out_dim + 1] = sdata[idx]; // score + // xmin, ymin, xmax, ymax + std::memcpy(odata + count * out_dim + 2, bdata, box_size * sizeof(T)); + count++; + } + } +} + +template +void MultiClassNMSCompute(const MultiClassNMSParam& param, + cl_context context, cl_command_queue commandQueue, + cl_kernel kernel0, cl_kernel kernel1) { + auto* input_bboxes_image = param.InputBBoxes(); + auto& input_bboxes_dims = input_bboxes_image->dims(); + Tensor* input_bboxes = new Tensor(); + input_bboxes->Resize(input_bboxes_dims); + input_bboxes->mutable_data(); + DLOG << "yangfei20"; + framework::CLImageToTensor(input_bboxes_image, input_bboxes, context, + commandQueue, kernel0); + DLOG << "yangfei20"; + auto* input_scores_image = param.InputScores(); + auto& input_scores_dims = input_scores_image->dims(); + + Tensor* input_scores = new Tensor(); + input_scores->Resize(input_scores_dims); + input_scores->mutable_data(); + framework::CLImageToTensor(input_scores_image, input_scores, context, + commandQueue, kernel0); + DLOG << "yangfei20"; + auto outs_image = param.Out(); + Tensor* outs = new Tensor(); + outs->Resize(outs_image->dims()); + outs->mutable_data(); + DLOG << *input_bboxes; + DLOG << *input_scores; + DLOG << *outs; + auto background_label = param.BackGroundLabel(); + auto nms_top_k = param.NMSTopK(); + auto keep_top_k = param.KeepTopK(); + auto nms_threshold = param.NMSThreshold(); + auto nms_eta = param.NMSEta(); + auto score_threshold = param.ScoreThreshold(); + int64_t batch_size = input_scores_dims[0]; + int64_t class_num = input_scores_dims[1]; + int64_t predict_dim = input_scores_dims[2]; + int64_t box_dim = input_bboxes_dims[2]; + + std::vector>> all_indices; + std::vector batch_starts = {0}; + for (int64_t i = 0; i < batch_size; ++i) { + framework::Tensor ins_score = input_scores->Slice(i, i + 1); + ins_score.Resize({class_num, predict_dim}); + + framework::Tensor ins_boxes = input_bboxes->Slice(i, i + 1); + ins_boxes.Resize({predict_dim, box_dim}); + + std::map> indices; + int num_nmsed_out = 0; + MultiClassNMS(ins_score, ins_boxes, &indices, &num_nmsed_out, + background_label, nms_top_k, keep_top_k, nms_threshold, + nms_eta, score_threshold); + all_indices.push_back(indices); + batch_starts.push_back(batch_starts.back() + num_nmsed_out); + } + + int num_kept = batch_starts.back(); + if (num_kept == 0) { + float* od = outs->mutable_data({1}); + od[0] = -1; + } else { + int64_t out_dim = box_dim + 2; + outs->mutable_data({num_kept, out_dim}); + for (int64_t i = 0; i < batch_size; ++i) { + framework::Tensor ins_score = input_scores->Slice(i, i + 1); + ins_score.Resize({class_num, predict_dim}); + + framework::Tensor ins_boxes = input_bboxes->Slice(i, i + 1); + ins_boxes.Resize({predict_dim, box_dim}); + + int64_t s = batch_starts[i]; + int64_t e = batch_starts[i + 1]; + if (e > s) { + framework::Tensor out = outs->Slice(s, e); + MultiClassOutput(ins_score, ins_boxes, all_indices[i], &out); + } + } + } + DLOG << "yangfei20"; + outs_image->InitEmptyImage(context, commandQueue, outs->dims()); + framework::TensorToCLImage(outs, outs_image, context, commandQueue, kernel1); + DLOG << *outs; + delete (input_bboxes); + delete (input_scores); + delete (outs); + DLOG << "yangfei20"; +} template <> void MultiClassNMSKernel::Compute( - const MultiClassNMSParam ¶m) {} + const MultiClassNMSParam& param) { + auto kernel0 = this->cl_helper_.KernelAt(0); + auto kernel1 = this->cl_helper_.KernelAt(1); + MultiClassNMSCompute(param, this->cl_helper_.CLContext(), + this->cl_helper_.CLCommandQueue(), kernel0, + kernel1); +} } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/cl/softmax_kernel.cpp b/src/operators/kernel/cl/softmax_kernel.cpp index 22e6672ee4..6447b68d33 100644 --- a/src/operators/kernel/cl/softmax_kernel.cpp +++ b/src/operators/kernel/cl/softmax_kernel.cpp @@ -33,31 +33,24 @@ void SoftmaxKernel::Compute(const SoftmaxParam ¶m) { auto *output = param.Out(); auto inputImage = input->GetCLImage(); auto outputImage = output->GetCLImage(); + const auto &outputDim = output->dims(); - int group = output->ImageWidth(); + int dims[4] = {1, 1, 1, 1}; + + for (int i = 0; i < outputDim.size(); i++) { + dims[4 - outputDim.size() + i] = outputDim[i]; + } + + const int out_W = dims[3]; cl_int status; status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputImage); + CL_CHECK_ERRORS(status); status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &outputImage); - status = clSetKernelArg(kernel, 2, sizeof(int), &group); - - // const auto &inputDim = input->dims(); - // - // int dims[4] = {1, 1, 1, 1}; - // - // for (int i = 0; i < inputDim.size(); i++) { - // dims[4 - inputDim.size() + i] = inputDim[i]; - // } - // - // clSetKernelArg(kernel, 2, sizeof(int), &dims); - // clSetKernelArg(kernel, 3, sizeof(int), &dims[1]); - // clSetKernelArg(kernel, 4, sizeof(int), &dims[2]); - // clSetKernelArg(kernel, 5, sizeof(int), &dims[3]); - - // cl_event out_event = param.Out()->GetClEvent(); - // cl_event wait_event = param.InputX()->GetClEvent(); - + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 2, sizeof(int), &out_W); + CL_CHECK_ERRORS(status); status = clEnqueueNDRangeKernel( this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL, default_work_size.data(), NULL, 0, NULL, NULL); diff --git a/src/operators/kernel/cl/transpose_kernel.cpp b/src/operators/kernel/cl/transpose_kernel.cpp index b5be025a15..d3133449b9 100644 --- a/src/operators/kernel/cl/transpose_kernel.cpp +++ b/src/operators/kernel/cl/transpose_kernel.cpp @@ -22,6 +22,8 @@ template <> bool TransposeKernel::Init(TransposeParam *param) { if (param->Out()->dims().size() == 4) { this->cl_helper_.AddKernel("transpose_4d", "transpose_kernel.cl"); + } else if (param->Out()->dims().size() < 4) { + this->cl_helper_.AddKernel("transpose", "transpose_kernel.cl"); } return true; } @@ -60,6 +62,69 @@ void TransposeKernel::Compute( this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL, default_work_size.data(), NULL, 0, NULL, NULL); CL_CHECK_ERRORS(status); + } else if (param.Out()->dims().size() == 3) { + auto kernel = this->cl_helper_.KernelAt(0); + auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Out()); + int out_C = param.Out()->dims()[0]; + int out_H = param.Out()->dims()[1]; + int out_W = param.Out()->dims()[2]; + int in_W = param.InputX()->dims()[2]; + auto output_image = param.Out()->GetCLImage(); + auto input_image = param.InputX()->GetCLImage(); + DLOG << "out_C=" << out_C; + DLOG << "out_H=" << out_H; + DLOG << "out_W=" << out_W; + DLOG << "in_C=" << in_W; + DLOG << "default_work_size=" << default_work_size; + cl_int status; + status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &input_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &output_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 2, sizeof(int), &out_C); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 3, sizeof(int), &out_H); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 4, sizeof(int), &out_W); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 5, sizeof(int), &in_W); + CL_CHECK_ERRORS(status); + status = clEnqueueNDRangeKernel( + this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), + NULL, default_work_size.data(), NULL, 0, NULL, NULL); + CL_CHECK_ERRORS(status); + + } else if (param.Out()->dims().size() == 2) { + auto kernel = this->cl_helper_.KernelAt(0); + auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Out()); + int out_C = 1; + int out_H = param.Out()->dims()[0]; + int out_W = param.Out()->dims()[1]; + int in_W = param.InputX()->dims()[1]; + auto output_image = param.Out()->GetCLImage(); + auto input_image = param.InputX()->GetCLImage(); + DLOG << "out_C=" << out_C; + DLOG << "out_H=" << out_H; + DLOG << "out_W=" << out_W; + DLOG << "in_C=" << in_W; + DLOG << "default_work_size=" << default_work_size; + cl_int status; + status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &input_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &output_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 2, sizeof(int), &out_C); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 3, sizeof(int), &out_H); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 4, sizeof(int), &out_W); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 5, sizeof(int), &in_W); + CL_CHECK_ERRORS(status); + status = clEnqueueNDRangeKernel( + this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), + NULL, default_work_size.data(), NULL, 0, NULL, NULL); + CL_CHECK_ERRORS(status); } } diff --git a/src/operators/op_param.h b/src/operators/op_param.h index a4d29a0f3a..5a2305876b 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -1018,9 +1018,9 @@ class MultiClassNMSParam : public OpParam { score_threshold_ = GetAttr("score_threshold", attrs); } - const RType *InputBBoxes() const { return input_bboxes_; } + RType *InputBBoxes() const { return input_bboxes_; } - const RType *InputScores() const { return input_scores_; } + RType *InputScores() const { return input_scores_; } RType *Out() const { return out_; } -- GitLab