diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp
index c7ef09ed5a1466a7396ec9c177eb3c48abd91ad7..80a990d5550ded3a5cc049fef366ba7e90938c00 100644
--- a/src/framework/executor.cpp
+++ b/src/framework/executor.cpp
@@ -281,6 +281,9 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
     clock_gettime(CLOCK_MONOTONIC, &ts);
     profile[i].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
 #endif
+    if (loddable_) {
+      ops[i]->InferShape();
+    }
     // to Run
     ops[i]->Run();
 #ifdef PADDLE_MOBILE_PROFILE
diff --git a/src/framework/loader.cpp b/src/framework/loader.cpp
index 5587d0698fa2b9a04532deae618545d15ecd631f..3a8541e2139945f15605643092d2e9df4ecbefe7 100644
--- a/src/framework/loader.cpp
+++ b/src/framework/loader.cpp
@@ -43,15 +43,21 @@ void Loader<Dtype, P>::InitMemoryFromProgram(
           tensor->Resize(make_ddim(dim));
         } else {
           auto dim = var_desc->Tensor_desc().Dims();
-          PADDLE_MOBILE_ENFORCE(dim.size() > 0, "dim size is 0");
+          // PADDLE_MOBILE_ENFORCE(dim.size() > 0, "dim size is 0");
           // dim[0] = 1;
-          for (auto &d : dim) {
-            if (d < 0) {
-              d *= -1;
+          if (dim.size() == 0) {
+            auto tensor = var->GetMutable<framework::LoDTensor>();
+            framework::DDim dDim = {0};
+            tensor->Resize(dDim);
+          } else {
+            for (auto &d : dim) {
+              if (d < 0) {
+                d *= -1;
+              }
             }
+            auto tensor = var->GetMutable<framework::LoDTensor>();
+            tensor->Resize(make_ddim(dim));
           }
-          auto tensor = var->GetMutable<framework::LoDTensor>();
-          tensor->Resize(make_ddim(dim));
         }
       } else {
         // TODO(codeWorm): some.
diff --git a/src/operators/feed_op.cpp b/src/operators/feed_op.cpp
index c3211b9fa9cc4b973788af4104c7ebe7bea2f54f..ac707d22696dd0a62902137607fb64c141341d77 100644
--- a/src/operators/feed_op.cpp
+++ b/src/operators/feed_op.cpp
@@ -21,7 +21,13 @@ template <typename DeviceType, typename T>
 void FeedOp<DeviceType, T>::InferShape() const {
   auto out_dims = this->param_.Out()->dims();
   out_dims[0] = this->param_.BatchSize();
-  this->param_.Out()->Resize(out_dims);
+  auto input_dims = this->param_.InputX()->dims();
+  DLOG << input_dims.size();
+  if (input_dims.size() == 4) {
+    this->param_.Out()->Resize(input_dims);
+  } else {
+    this->param_.Out()->Resize(out_dims);
+  }
 }
 
 }  // namespace operators
diff --git a/src/operators/kernel/central-arm-func/conv_add_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_arm_func.h
index d71bc235977236fbd0dd332df556ea4bd41eacf4..bacaa866b12957cfc300049c56bb9648fd360770 100644
--- a/src/operators/kernel/central-arm-func/conv_add_arm_func.h
+++ b/src/operators/kernel/central-arm-func/conv_add_arm_func.h
@@ -115,6 +115,7 @@ void ConvAddBasic(const FusionConvAddParam<CPU> &param) {
 
 template <typename P>
 void ConvAddCompute(const FusionConvAddParam<CPU> &param) {
+  param.Output()->mutable_data<float>();
   if (param.Groups() == param.Input()->dims()[1] &&
       param.Input()->dims()[1] == param.Output()->dims()[1] &&
       param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
diff --git a/src/operators/kernel/cl/cl_kernel/conv_bn_relu_kernel.cl b/src/operators/kernel/cl/cl_kernel/conv_bn_relu_kernel.cl
new file mode 100644
index 0000000000000000000000000000000000000000..3aba8c2606d96812cdad72f570b27af60f632ae5
--- /dev/null
+++ b/src/operators/kernel/cl/cl_kernel/conv_bn_relu_kernel.cl
@@ -0,0 +1,18 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#define BATCH_NORM
+#define RELU
+
+#include "conv_kernel.inc.cl"
diff --git a/src/operators/kernel/cl/cl_kernel/fetch_kernel.cl b/src/operators/kernel/cl/cl_kernel/fetch_kernel.cl
index 64bb1845b0bd2c04c8761845b90dbed9e391a77b..f6014b732398cccd025a39cfb4a824b3154fcd66 100644
--- a/src/operators/kernel/cl/cl_kernel/fetch_kernel.cl
+++ b/src/operators/kernel/cl/cl_kernel/fetch_kernel.cl
@@ -20,7 +20,8 @@ __kernel void fetch(__private const int in_height,
                     __global float* out,
                     __private const int size_ch,
                     __private const int size_block,
-                    __private const int size_batch) {
+                    __private const int size_batch,
+                    __private const int C) {
   const int in_c = get_global_id(0);
   const int in_w = get_global_id(1);
   const int in_nh = get_global_id(2);
@@ -35,9 +36,17 @@ __kernel void fetch(__private const int in_height,
   const int index = in_n * size_batch + in_c * size_block + in_h * in_width + in_w;
 
   out[index] = convert_float(in.x);
-  out[index + size_ch] = convert_float(in.y);
+  if(C - 4 * in_c>=2){
+    out[index + size_ch] = convert_float(in.y);
+  }
+  if(C - 4 * in_c>=3){
   out[index + size_ch * 2] = convert_float(in.z);
-  out[index + size_ch * 3] = convert_float(in.w);
+  }
+
+  if(C - 4 * in_c>=4){
+    out[index + size_ch * 3] = convert_float(in.w);
+  }
+
 }
 
 __kernel void fetch_2d(__private const int in_height,
diff --git a/src/operators/kernel/cl/cl_kernel/prior_box_kernel.cl b/src/operators/kernel/cl/cl_kernel/prior_box_kernel.cl
new file mode 100644
index 0000000000000000000000000000000000000000..053c7b3f06249f426c4a2203ba9c89362ded6a08
--- /dev/null
+++ b/src/operators/kernel/cl/cl_kernel/prior_box_kernel.cl
@@ -0,0 +1,100 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+__kernel void prior_box(__private const int global_size_dim0,
+                        __private const int global_size_dim1,
+                        __private const int global_size_dim2,
+                        __global float *box_width,
+                        __global float *box_height,
+                        __write_only image2d_t output_image,
+                        __private const float step_width,
+                        __private const float step_height,
+                        __private const float offset,
+                        __private const int img_width,
+                        __private const int img_height,
+                        __private const int num_priors,
+                        __private const int C){
+
+
+  const int out_c = get_global_id(0);
+  const int out_nh = get_global_id(1);
+  const int out_n = out_nh/num_priors;
+  const int out_h = out_nh%num_priors;
+
+  if (out_c >= global_size_dim0 ||out_nh >= global_size_dim2) {
+    return;
+  }
+  const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
+                            CLK_ADDRESS_CLAMP |
+                            CLK_FILTER_NEAREST;
+  int2 output_pos;
+  output_pos.x = out_c * 4;
+  output_pos.y = out_nh;
+  float center_x0 = (offset + out_c * 4) * step_width;
+  float center_x1 = (offset + out_c * 4 + 1) * step_width;
+  float center_x2 = (offset + out_c * 4 + 2) * step_width;
+  float center_x3 = (offset + out_c * 4 + 3) * step_width;
+  float center_y = (out_n + offset) * step_height;
+
+  half4 output[4];
+  output[0].x = convert_half((center_x0 - box_width[out_h]) / img_width);
+  output[1].x = convert_half((center_y - box_height[out_h]) / img_height);
+  output[2].x = convert_half((center_x0 + box_width[out_h]) / img_width);
+  output[3].x = convert_half((center_y + box_height[out_h]) / img_height);
+
+  if(C - 4 * out_c>=2){
+    output[0].y = convert_half((center_x1 - box_width[out_h]) / img_width);
+    output[1].y = convert_half((center_y - box_height[out_h]) / img_height);
+    output[2].y = convert_half((center_x1 + box_width[out_h]) / img_width);
+    output[3].y = convert_half((center_y + box_height[out_h]) / img_height);
+  }else{
+    output[0].y = 0.0f;
+    output[1].y = 0.0f;
+    output[2].y = 0.0f;
+    output[3].y = 0.0f;
+  }
+  if(C - 4 * out_c>=3){
+    output[0].z = convert_half((center_x2 - box_width[out_h]) / img_width);
+    output[1].z = convert_half((center_y - box_height[out_h]) / img_height);
+    output[2].z = convert_half((center_x2 + box_width[out_h]) / img_width);
+    output[3].z = convert_half((center_y + box_height[out_h]) / img_height);
+  }else{
+    output[0].z = 0.0f;
+    output[1].z = 0.0f;
+    output[2].z = 0.0f;
+    output[3].z = 0.0f;
+  }
+  if(C - 4 * out_c>=4){
+    output[0].w = convert_half((center_x3 - box_width[out_h]) / img_width);
+    output[1].w = convert_half((center_y - box_height[out_h]) / img_height);
+    output[2].w = convert_half((center_x3 + box_width[out_h]) / img_width);
+    output[3].w = convert_half((center_y + box_height[out_h]) / img_height);
+  }else{
+    output[0].w = 0.0f;
+    output[1].w = 0.0f;
+    output[2].w = 0.0f;
+    output[3].w = 0.0f;
+  }
+  output[0] = min(max((half4)(0.0f, 0.0f, 0.0f, 0.0f), output[0]),(half4)(1.0f, 1.0f, 1.0f, 1.0f));
+  output[1] = min(max((half4)(0.0f, 0.0f, 0.0f, 0.0f), output[1]),(half4)(1.0f, 1.0f, 1.0f, 1.0f));
+  output[2] = min(max((half4)(0.0f, 0.0f, 0.0f, 0.0f), output[2]),(half4)(1.0f, 1.0f, 1.0f, 1.0f));
+  output[3] = min(max((half4)(0.0f, 0.0f, 0.0f, 0.0f), output[3]),(half4)(1.0f, 1.0f, 1.0f, 1.0f));
+  write_imageh(output_image, (int2)(output_pos.x + 1, output_pos.y), output[0]);
+  write_imageh(output_image, (int2)(output_pos.x + 2, output_pos.y), output[1]);
+  write_imageh(output_image, (int2)(output_pos.x + 3, output_pos.y), output[2]);
+  write_imageh(output_image, (int2)(output_pos.x + 4, output_pos.y), output[3]);
+
+}
\ No newline at end of file
diff --git a/src/operators/kernel/cl/conv_add_kernel.cpp b/src/operators/kernel/cl/conv_add_kernel.cpp
index 7e30c6d31db645fb5d18bf70ef5b6876a5f683da..3292cc7ccd2febc4d1e5b8f5e4991f8348b25196 100644
--- a/src/operators/kernel/cl/conv_add_kernel.cpp
+++ b/src/operators/kernel/cl/conv_add_kernel.cpp
@@ -68,10 +68,10 @@ void ConvAddKernel<GPU_CL, float>::Compute(
   int nh = default_work_size[2];
   auto input = param.Input()->GetCLImage();
   auto filter = param.Filter()->GetCLImage();
-  DLOG << "---yangfei30---";
-  DLOG << *param.Filter();
-  DLOG << param.Paddings();
   auto biase = param.Bias()->GetCLImage();
+  param.Output()->InitEmptyImage(cl_helper_.CLContext(),
+                                 cl_helper_.CLCommandQueue(),
+                                 param.Output()->dims());
   auto output = param.Output()->GetCLImage();
   int stride = param.Strides()[0];
   int offset = param.Offset();
diff --git a/src/operators/kernel/cl/conv_bn_relu_kernel.cpp b/src/operators/kernel/cl/conv_bn_relu_kernel.cpp
index be2da60de0bf429656978d696f8c0067b74559ad..945f84358ff2f5a61ece0c8f96d39b3c10c01c1e 100644
--- a/src/operators/kernel/cl/conv_bn_relu_kernel.cpp
+++ b/src/operators/kernel/cl/conv_bn_relu_kernel.cpp
@@ -22,12 +22,185 @@ namespace operators {
 template <>
 bool ConvBNReluKernel<GPU_CL, float>::Init(
     FusionConvBNReluParam<GPU_CL> *param) {
+  PADDLE_MOBILE_ENFORCE(
+      param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
+          param->Paddings()[0] == param->Paddings()[1],
+      "need equal");
+  const framework::CLImage *mean = param->InputMean();
+  const framework::CLImage *variance = param->InputVariance();
+  const framework::CLImage *scale = param->InputScale();
+  const framework::CLImage *bias = param->InputBias();
+  const float epsilon = param->Epsilon();
+
+  const int C = mean->numel();
+
+  auto mean_ptr = mean->data<float>();
+  auto variance_ptr = variance->data<float>();
+  auto scale_ptr = scale->data<float>();
+  auto bias_ptr = bias->data<float>();
+
+  float inv_std_ptr[C];
+  for (int i = 0; i < C; i++) {
+    inv_std_ptr[i] =
+        1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
+  }
+  float *new_scale_ptr = new float[C];
+  float *new_bias_ptr = new float[C];
+
+  for (int i = 0; i < C; i++) {
+    new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
+    new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
+  }
+
+  framework::CLImage *new_scale = new framework::CLImage();
+
+  //  for (int j = 0; j < C; ++j) {
+  //    DLOG << " new scale - " << j << new_scale_ptr[j];
+  //  }
+  //
+  //  for (int j = 0; j < C; ++j) {
+  //    DLOG << " new bias - " << j << new_bias_ptr[j];
+  //  }
+
+  new_scale->SetTensorData(new_scale_ptr, variance->dims());
+  new_scale->InitCLImage(this->cl_helper_.CLContext(),
+                         cl_helper_.CLCommandQueue());
+
+  //  DLOG << " climage - y bias: " << *(param->Bias());
+  //
+  //  DLOG << " climage - new scale: " << *new_scale;
+
+  framework::CLImage *new_bias = new framework::CLImage();
+
+  new_bias->SetTensorData(new_bias_ptr, variance->dims());
+  new_bias->InitCLImage(this->cl_helper_.CLContext(),
+                        cl_helper_.CLCommandQueue());
+
+  //  DLOG << " climage - new bias: " << *new_bias;
+  //
+  //  DLOG << " climage - filter: " << *(param->Filter());
+
+  param->SetNewScale(new_scale);
+  param->SetNewBias(new_bias);
+
+  delete[](new_scale_ptr);
+  delete[](new_bias_ptr);
+
+  PADDLE_MOBILE_ENFORCE(
+      param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
+          param->Paddings()[0] == param->Paddings()[1],
+      "need equal");
+
+  int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
+               static_cast<int>(param->Paddings()[1]);
+
+  param->SetOffset(offset);
+
+  if (param->Filter()->dims()[2] == 1 &&
+      param->Filter()->dims()[3] == 1) {
+    param->Filter()->InitNImage(cl_helper_.CLContext(),
+                                cl_helper_.CLCommandQueue());
+    this->cl_helper_.AddKernel("conv_1x1", "conv_bn_relu_kernel.cl");
+    DLOG << " conv bn relu conv 1x1";
+  } else if (param->Filter()->dims()[1] == 1 &&
+             param->Input()->dims()[1] == param->Output()->dims()[1] &&
+             param->Filter()->dims()[2] == 3) {
+    param->Filter()->InitDWImage(cl_helper_.CLContext(),
+                                 cl_helper_.CLCommandQueue());
+    this->cl_helper_.AddKernel("depth_conv_3x3", "conv_bn_relu_kernel.cl");
+    DLOG << " conv bn relu depth_conv_3x3";
+
+  } else if (param->Filter()->dims()[2] == 3 &&
+             param->Filter()->dims()[3] == 3) {
+    param->Filter()->InitCLImage(cl_helper_.CLContext(),
+                                 cl_helper_.CLCommandQueue());
+
+    this->cl_helper_.AddKernel("conv_3x3", "conv_bn_relu_kernel.cl");
+    DLOG << " conv bn relu conv_3x3";
+  } else {
+    PADDLE_MOBILE_THROW_EXCEPTION(" not support ");
+  }
   return true;
 }
 
 template <>
 void ConvBNReluKernel<GPU_CL, float>::Compute(
-    const FusionConvBNReluParam<GPU_CL> &param) {}
+    const FusionConvBNReluParam<GPU_CL> &param) {
+  auto kernel = this->cl_helper_.KernelAt(0);
+  auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Output());
+  int c_block = default_work_size[0];
+  int w = default_work_size[1];
+  int nh = default_work_size[2];
+  auto input = param.Input()->GetCLImage();
+  auto filter = param.Filter()->GetCLImage();
+  auto new_scale = param.NewScale()->GetCLImage();
+  auto new_bias = param.NewBias()->GetCLImage();
+  auto output = param.Output()->GetCLImage();
+  int stride = param.Strides()[0];
+  int offset = param.Offset();
+  int input_c = reinterpret_cast<framework::CLImageConverterFolder *>(
+                    param.Input()->Converter())
+                    ->GetCBlock();
+  int dilation = param.Dilations()[0];
+  int input_width = param.Input()->dims()[3];
+  int input_height = param.Input()->dims()[2];
+  int output_width = param.Output()->dims()[3];
+  int output_height = param.Output()->dims()[2];
+
+  cl_int status;
+
+  status = clSetKernelArg(kernel, 0, sizeof(int), &c_block);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 1, sizeof(int), &w);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 2, sizeof(int), &nh);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &input);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &new_scale);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 6, sizeof(cl_mem), &new_bias);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 7, sizeof(cl_mem), &output);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 8, sizeof(int), &stride);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 9, sizeof(int), &offset);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 10, sizeof(int), &input_c);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 11, sizeof(int), &dilation);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 12, sizeof(int), &input_width);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 13, sizeof(int), &input_height);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 14, sizeof(int), &output_width);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 15, sizeof(int), &output_height);
+  CL_CHECK_ERRORS(status);
+
+  status = clEnqueueNDRangeKernel(
+      this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
+      default_work_size.data(), NULL, 0, NULL, NULL);
+  CL_CHECK_ERRORS(status);
+}
 template class ConvBNReluKernel<GPU_CL, float>;
 
 }  // namespace operators
diff --git a/src/operators/kernel/cl/dwconv_bn_relu_kernel.cpp b/src/operators/kernel/cl/dwconv_bn_relu_kernel.cpp
index 48e08d69ec22f128885bb7aa9165e0898ca67b7a..31cb654d75c9c83e132a52d740b0b2db1c3a2ccf 100644
--- a/src/operators/kernel/cl/dwconv_bn_relu_kernel.cpp
+++ b/src/operators/kernel/cl/dwconv_bn_relu_kernel.cpp
@@ -22,12 +22,151 @@ namespace operators {
 template <>
 bool DWConvBNReluKernel<GPU_CL, float>::Init(
     FusionDWConvBNReluParam<GPU_CL> *param) {
+  PADDLE_MOBILE_ENFORCE(
+      param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
+          param->Paddings()[0] == param->Paddings()[1],
+      "need equal");
+  const framework::CLImage *mean = param->InputMean();
+  const framework::CLImage *variance = param->InputVariance();
+  const framework::CLImage *scale = param->InputScale();
+  const framework::CLImage *bias = param->InputBias();
+  const float epsilon = param->Epsilon();
+
+  const int C = mean->numel();
+
+  auto mean_ptr = mean->data<float>();
+  auto variance_ptr = variance->data<float>();
+  auto scale_ptr = scale->data<float>();
+  auto bias_ptr = bias->data<float>();
+
+  float inv_std_ptr[C];
+  for (int i = 0; i < C; i++) {
+    inv_std_ptr[i] =
+        1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
+  }
+  float *new_scale_ptr = new float[C];
+  float *new_bias_ptr = new float[C];
+
+  for (int i = 0; i < C; i++) {
+    new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
+    new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
+  }
+
+  framework::CLImage *new_scale = new framework::CLImage();
+
+  new_scale->SetTensorData(new_scale_ptr, variance->dims());
+  new_scale->InitCLImage(this->cl_helper_.CLContext(),
+                         cl_helper_.CLCommandQueue());
+
+  framework::CLImage *new_bias = new framework::CLImage();
+
+  new_bias->SetTensorData(new_bias_ptr, variance->dims());
+  new_bias->InitCLImage(this->cl_helper_.CLContext(),
+                        cl_helper_.CLCommandQueue());
+
+  param->SetNewScale(new_scale);
+  param->SetNewBias(new_bias);
+
+  delete[](new_scale_ptr);
+  delete[](new_bias_ptr);
+
+  PADDLE_MOBILE_ENFORCE(
+      param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
+          param->Paddings()[0] == param->Paddings()[1],
+      "need equal");
+
+  int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
+               static_cast<int>(param->Paddings()[1]);
+
+  param->SetOffset(offset);
+
+  param->Filter()->InitDWImage(cl_helper_.CLContext(),
+                               cl_helper_.CLCommandQueue());
+  this->cl_helper_.AddKernel("depth_conv_3x3", "conv_bn_relu_kernel.cl");
+  DLOG << " conv bn relu depth_conv_3x3";
+
   return true;
 }
 
 template <>
 void DWConvBNReluKernel<GPU_CL, float>::Compute(
-    const FusionDWConvBNReluParam<GPU_CL> &param) {}
+    const FusionDWConvBNReluParam<GPU_CL> &param) {
+  auto kernel = this->cl_helper_.KernelAt(0);
+  auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Output());
+  int c_block = default_work_size[0];
+  int w = default_work_size[1];
+  int nh = default_work_size[2];
+  auto input = param.Input()->GetCLImage();
+  auto filter = param.Filter()->GetCLImage();
+  auto new_scale = param.NewScale()->GetCLImage();
+  auto new_bias = param.NewBias()->GetCLImage();
+  auto output = param.Output()->GetCLImage();
+  int stride = param.Strides()[0];
+  int offset = param.Offset();
+  int input_c = reinterpret_cast<framework::CLImageConverterFolder *>(
+                    param.Input()->Converter())
+                    ->GetCBlock();
+  int dilation = param.Dilations()[0];
+  int input_width = param.Input()->dims()[3];
+  int input_height = param.Input()->dims()[2];
+  int output_width = param.Output()->dims()[3];
+  int output_height = param.Output()->dims()[2];
+
+  cl_int status;
+
+  status = clSetKernelArg(kernel, 0, sizeof(int), &c_block);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 1, sizeof(int), &w);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 2, sizeof(int), &nh);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &input);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &new_scale);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 6, sizeof(cl_mem), &new_bias);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 7, sizeof(cl_mem), &output);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 8, sizeof(int), &stride);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 9, sizeof(int), &offset);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 10, sizeof(int), &input_c);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 11, sizeof(int), &dilation);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 12, sizeof(int), &input_width);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 13, sizeof(int), &input_height);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 14, sizeof(int), &output_width);
+  CL_CHECK_ERRORS(status);
+
+  status = clSetKernelArg(kernel, 15, sizeof(int), &output_height);
+  CL_CHECK_ERRORS(status);
+
+  status = clEnqueueNDRangeKernel(
+      this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
+      default_work_size.data(), NULL, 0, NULL, NULL);
+  CL_CHECK_ERRORS(status);
+}
 template class DWConvBNReluKernel<GPU_CL, float>;
 
 }  // namespace operators
diff --git a/src/operators/kernel/cl/feed_kernel.cpp b/src/operators/kernel/cl/feed_kernel.cpp
index 78f04357a23c70595595cc24489fd96e994162fb..941a6cb815541d1eca30ccc193161838ce28da80 100644
--- a/src/operators/kernel/cl/feed_kernel.cpp
+++ b/src/operators/kernel/cl/feed_kernel.cpp
@@ -28,6 +28,8 @@ template <>
 void FeedKernel<GPU_CL, float>::Compute(const FeedParam<GPU_CL> &param) {
   auto kernel = this->cl_helper_.KernelAt(0);
   cl_int status;
+  param.Out()->InitEmptyImage(cl_helper_.CLContext(),
+                              cl_helper_.CLCommandQueue(), param.Out()->dims());
   auto output = param.Out();
   const Tensor *input = param.InputX();
   // DLOG << *input;
diff --git a/src/operators/kernel/cl/fetch_kernel.cpp b/src/operators/kernel/cl/fetch_kernel.cpp
index 31c1d4179cbdfc8145d90bee2353be821e65b40b..8ea0b3ad3d33f0352fba7697fc08ad7a2039e6ab 100644
--- a/src/operators/kernel/cl/fetch_kernel.cpp
+++ b/src/operators/kernel/cl/fetch_kernel.cpp
@@ -27,8 +27,6 @@ bool FetchKernel<GPU_CL, float>::Init(FetchParam<GPU_CL> *param) {
   } else {
     this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");
   }
-  auto *out = param->Out();
-  out->mutable_data<float>();
   return true;
 }
 
@@ -39,7 +37,7 @@ void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
   auto input = param.InputX()->GetCLImage();
 
   auto *out = param.Out();
-
+  out->mutable_data<float>();
   const auto &dim = param.InputX()->dims();
   size_t new_dims[] = {1, 1, 1, 1};
 
@@ -70,9 +68,11 @@ void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
     int size_ch = in_height * in_width;
     int size_block = size_ch * 4;
     int size_batch = size_ch * C;
+    int out_c = new_dims[1];
     clSetKernelArg(kernel, 4, sizeof(int), &size_ch);
     clSetKernelArg(kernel, 5, sizeof(int), &size_block);
     clSetKernelArg(kernel, 6, sizeof(int), &size_batch);
+    clSetKernelArg(kernel, 7, sizeof(int), &out_c);
   }
 
   // cl_event wait_event = param.InpdutX()->GetClEvent();
@@ -93,6 +93,8 @@ void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
¶m) { // << "ms" << std::endl; memcpy(out->data(), out_cl_tensor.Data(), out->memory_size()); + DLOG << *param.InputX(); + DLOG << *out; } template class FetchKernel; diff --git a/src/operators/kernel/cl/prior_box_kernel.cpp b/src/operators/kernel/cl/prior_box_kernel.cpp index f8d8c51fcaca214e6248b83d9e135670fa28fe94..1f8843787bc68c8be681e4c2a79053714b76dc4e 100644 --- a/src/operators/kernel/cl/prior_box_kernel.cpp +++ b/src/operators/kernel/cl/prior_box_kernel.cpp @@ -15,18 +15,165 @@ limitations under the License. */ #ifdef PRIORBOX_OP #include "operators/kernel/prior_box_kernel.h" - +#include "framework/cl/cl_tensor.h" namespace paddle_mobile { namespace operators { template <> bool PriorBoxKernel::Init(PriorBoxParam *param) { + this->cl_helper_.AddKernel("prior_box", "prior_box_kernel.cl"); return true; } template <> void PriorBoxKernel::Compute( - const PriorBoxParam ¶m) {} + const PriorBoxParam ¶m) { + const auto *input_ = param.Input(); + const auto &input_dims = input_->dims(); + + const auto &input_image_dims = param.InputImage()->dims(); + + const auto &min_sizes = param.MinSizes(); + const auto &max_sizes = param.MaxSizes(); + const auto &variances = param.Variances(); + const auto &input_aspect_ratio = param.AspectRatios(); + const bool &flip = param.Flip(); + const bool &clip = param.Clip(); + const float &step_w = param.StepW(); + const float &step_h = param.StepH(); + const float &offset = param.Offset(); + const int C = param.OutputBoxes()->dims()[1]; + + auto output_boxes = param.OutputBoxes()->GetCLImage(); + auto output_variances = param.OutputVariances()->GetCLImage(); + + std::vector aspect_ratios; + ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios); + + auto img_width = input_image_dims[3]; + auto img_height = input_image_dims[2]; + + auto feature_width = input_dims[3]; + auto feature_height = input_dims[2]; + + float step_width, step_height; + /// 300 / 19 + if (step_w == 0 || step_h == 0) { + step_width = static_cast(img_width) / feature_width; + step_height = static_cast(img_height) / feature_height; + } else { + step_width = step_w; + step_height = step_h; + } + + int num_priors = aspect_ratios.size() * min_sizes.size(); + if (!max_sizes.empty()) { + num_priors += max_sizes.size(); + } + + float *box_width = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * num_priors)); + float *box_height = static_cast( + paddle_mobile::memory::Alloc(sizeof(float) * num_priors)); + int idx = 0; + for (size_t s = 0; s < min_sizes.size(); ++s) { + auto min_size = min_sizes[s]; + if (param.MinMaxAspectRatiosOrder()) { + box_width[idx] = box_height[idx] = min_size / 2.; + idx++; + if (max_sizes.size() > 0) { + auto max_size = max_sizes[s]; + box_width[idx] = box_height[idx] = sqrt(min_size * max_size) / 2.; + idx++; + } + for (float ar : aspect_ratios) { + if (fabs(ar - 1.) 
+          continue;
+        }
+        box_width[idx] = min_size * sqrt(ar) / 2.;
+        box_height[idx] = min_size / sqrt(ar) / 2.;
+        idx++;
+      }
+
+    } else {
+      for (float ar : aspect_ratios) {
+        box_width[idx] = min_size * sqrt(ar) / 2.;
+        box_height[idx] = min_size / sqrt(ar) / 2.;
+        idx++;
+      }
+      if (!max_sizes.empty()) {
+        auto max_size = max_sizes[s];
+        box_width[idx] = box_height[idx] = sqrt(min_size * max_size) / 2.;
+        idx++;
+      }
+    }
+  }
+  cl_int status;
+  auto kernel = this->cl_helper_.KernelAt(0);
+  auto default_work_size =
+      this->cl_helper_.DefaultWorkSize(*param.OutputBoxes());
+  int c_block = default_work_size[0];
+  int w = default_work_size[1];
+  int nh = default_work_size[2];
+
+  std::vector<int64_t> box_shape({1, 1, 1, num_priors});
+  framework::DDim ddim = framework::make_ddim(box_shape);
+
+  framework::CLTensor box_width_cl_tensor(this->cl_helper_.CLContext(),
+                                          this->cl_helper_.CLCommandQueue());
+  box_width_cl_tensor.Resize(ddim);
+  cl_mem box_width_Buffer =
+      box_width_cl_tensor.mutable_with_data<float>(box_width);
+
+  framework::CLTensor box_height_cl_tensor(this->cl_helper_.CLContext(),
+                                           this->cl_helper_.CLCommandQueue());
+  box_height_cl_tensor.Resize(ddim);
+  cl_mem box_height_Buffer =
+      box_height_cl_tensor.mutable_with_data<float>(box_height);
+
+  DLOG << "c_block:" << c_block;
+  DLOG << "w:" << w;
+  DLOG << "nh:" << nh;
+  DLOG << "step_width:" << step_width;
+  DLOG << "step_height:" << step_height;
+  DLOG << "offset:" << offset;
+  DLOG << "img_width:" << img_width;
+  DLOG << "img_height:" << img_height;
+  DLOG << "num_priors:" << num_priors;
+  DLOG << "C:" << C;
+  status = clSetKernelArg(kernel, 0, sizeof(int), &c_block);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 1, sizeof(int), &w);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 2, sizeof(int), &nh);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &box_width_Buffer);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &box_height_Buffer);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &output_boxes);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 6, sizeof(float), &step_width);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 7, sizeof(float), &step_height);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 8, sizeof(float), &offset);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 9, sizeof(int), &img_width);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 10, sizeof(int), &img_height);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 11, sizeof(int), &num_priors);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 12, sizeof(int), &C);
+  CL_CHECK_ERRORS(status);
+  size_t global_work_size[2] = {c_block, nh};
+  status = clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2,
+                                  NULL, global_work_size, NULL, 0, NULL, NULL);
+  CL_CHECK_ERRORS(status);
+  paddle_mobile::memory::Free(box_width);
+  paddle_mobile::memory::Free(box_height);
+}
 template class PriorBoxKernel<GPU_CL, float>;
 
 }  // namespace operators