未验证 提交 ae8123d0 编写于 作者: X xiebaiyuan 提交者: GitHub

Merge pull request #1295 from yangfei963158659/develop

imp diffrent input size for superresolution and imp conv_bn_relu,dwco…
......@@ -281,6 +281,9 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
clock_gettime(CLOCK_MONOTONIC, &ts);
profile[i].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
#endif
if (loddable_) {
ops[i]->InferShape();
}
// to Run
ops[i]->Run();
#ifdef PADDLE_MOBILE_PROFILE
......
......@@ -43,15 +43,21 @@ void Loader<Dtype, P>::InitMemoryFromProgram(
tensor->Resize(make_ddim(dim));
} else {
auto dim = var_desc->Tensor_desc().Dims();
PADDLE_MOBILE_ENFORCE(dim.size() > 0, "dim size is 0");
// PADDLE_MOBILE_ENFORCE(dim.size() > 0, "dim size is 0");
// dim[0] = 1;
for (auto &d : dim) {
if (d < 0) {
d *= -1;
if (dim.size() == 0) {
auto tensor = var->GetMutable<LoDTensor>();
framework::DDim dDim = {0};
tensor->Resize(dDim);
} else {
for (auto &d : dim) {
if (d < 0) {
d *= -1;
}
}
auto tensor = var->GetMutable<LoDTensor>();
tensor->Resize(make_ddim(dim));
}
auto tensor = var->GetMutable<LoDTensor>();
tensor->Resize(make_ddim(dim));
}
} else {
// TODO(codeWorm): some.
......
......@@ -21,7 +21,13 @@ template <typename DeviceType, typename T>
void FeedOp<DeviceType, T>::InferShape() const {
auto out_dims = this->param_.Out()->dims();
out_dims[0] = this->param_.BatchSize();
this->param_.Out()->Resize(out_dims);
auto input_dims = this->param_.InputX()->dims();
DLOG << input_dims.size();
if (input_dims.size() == 4) {
this->param_.Out()->Resize(input_dims);
} else {
this->param_.Out()->Resize(out_dims);
}
}
} // namespace operators
......
......@@ -115,6 +115,7 @@ void ConvAddBasic(const FusionConvAddParam<CPU> &param) {
template <typename P>
void ConvAddCompute(const FusionConvAddParam<CPU> &param) {
param.Output()->mutable_data<float>();
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define BATCH_NORM
#define RELU
#include "conv_kernel.inc.cl"
......@@ -20,7 +20,8 @@ __kernel void fetch(__private const int in_height,
__global float* out,
__private const int size_ch,
__private const int size_block,
__private const int size_batch) {
__private const int size_batch,
__private const int C) {
const int in_c = get_global_id(0);
const int in_w = get_global_id(1);
const int in_nh = get_global_id(2);
......@@ -35,9 +36,17 @@ __kernel void fetch(__private const int in_height,
const int index = in_n * size_batch + in_c * size_block + in_h * in_width + in_w;
out[index] = convert_float(in.x);
out[index + size_ch] = convert_float(in.y);
if(C - 4 * in_c>=2){
out[index + size_ch] = convert_float(in.y);
}
if(C - 4 * in_c>=3){
out[index + size_ch * 2] = convert_float(in.z);
out[index + size_ch * 3] = convert_float(in.w);
}
if(C - 4 * in_c>=4){
out[index + size_ch * 3] = convert_float(in.w);
}
}
__kernel void fetch_2d(__private const int in_height,
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void prior_box(__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2,
__global float *box_width,
__global float *box_height,
__write_only image2d_t output_image,
__private const float step_width,
__private const float step_height,
__private const float offset,
__private const int img_width,
__private const int img_height,
__private const int num_priors,
__private const int C){
const int out_c = get_global_id(0);
const int out_nh = get_global_id(1);
const int out_n = out_nh/num_priors;
const int out_h = out_nh%num_priors;
if (out_c >= global_size_dim0 ||out_nh >= global_size_dim2) {
return;
}
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
int2 output_pos;
output_pos.x = out_c * 4;
output_pos.y = out_nh;
float center_x0 = (offset + out_c * 4) * step_width;
float center_x1 = (offset + out_c * 4 + 1) * step_width;
float center_x2 = (offset + out_c * 4 + 2) * step_width;
float center_x3 = (offset + out_c * 4 + 3) * step_width;
float center_y = (out_n + offset) * step_height;
half4 output[4];
output[0].x = convert_half((center_x0 - box_width[out_h]) / img_width);
output[1].x = convert_half((center_y - box_height[out_h]) / img_height);
output[2].x = convert_half((center_x0 + box_width[out_h]) / img_width);
output[3].x = convert_half((center_y + box_height[out_h]) / img_height);
if(C - 4 * out_c>=2){
output[0].y = convert_half((center_x1 - box_width[out_h]) / img_width);
output[1].y = convert_half((center_y - box_height[out_h]) / img_height);
output[2].y = convert_half((center_x1 + box_width[out_h]) / img_width);
output[3].y = convert_half((center_y + box_height[out_h]) / img_height);
}else{
output[0].y = 0.0f;
output[1].y = 0.0f;
output[2].y = 0.0f;
output[3].y = 0.0f;
}
if(C - 4 * out_c>=3){
output[0].z = convert_half((center_x2 - box_width[out_h]) / img_width);
output[1].z = convert_half((center_y - box_height[out_h]) / img_height);
output[2].z = convert_half((center_x2 + box_width[out_h]) / img_width);
output[3].z = convert_half((center_y + box_height[out_h]) / img_height);
}else{
output[0].z = 0.0f;
output[1].z = 0.0f;
output[2].z = 0.0f;
output[3].z = 0.0f;
}
if(C - 4 * out_c>=4){
output[0].w = convert_half((center_x3 - box_width[out_h]) / img_width);
output[1].w = convert_half((center_y - box_height[out_h]) / img_height);
output[2].w = convert_half((center_x3 + box_width[out_h]) / img_width);
output[3].w = convert_half((center_y + box_height[out_h]) / img_height);
}else{
output[0].z = 0.0f;
output[1].z = 0.0f;
output[2].z = 0.0f;
output[3].z = 0.0f;
}
output[0] = min(max((half4)(0.0f, 0.0f, 0.0f, 0.0f), output[0]),(half4)(1.0f, 1.0f, 1.0f, 1.0f));
output[1] = min(max((half4)(0.0f, 0.0f, 0.0f, 0.0f), output[1]),(half4)(1.0f, 1.0f, 1.0f, 1.0f));
output[2] = min(max((half4)(0.0f, 0.0f, 0.0f, 0.0f), output[2]),(half4)(1.0f, 1.0f, 1.0f, 1.0f));
output[3] = min(max((half4)(0.0f, 0.0f, 0.0f, 0.0f), output[3]),(half4)(1.0f, 1.0f, 1.0f, 1.0f));
write_imageh(output_image, (int2)(output_pos.x + 1, output_pos.y), output[0]);
write_imageh(output_image, (int2)(output_pos.x + 2, output_pos.y), output[1]);
write_imageh(output_image, (int2)(output_pos.x + 3, output_pos.y), output[2]);
write_imageh(output_image, (int2)(output_pos.x + 4, output_pos.y), output[3]);
}
\ No newline at end of file
......@@ -68,10 +68,10 @@ void ConvAddKernel<GPU_CL, float>::Compute(
int nh = default_work_size[2];
auto input = param.Input()->GetCLImage();
auto filter = param.Filter()->GetCLImage();
DLOG << "---yangfei30---";
DLOG << *param.Filter();
DLOG << param.Paddings();
auto biase = param.Bias()->GetCLImage();
param.Output()->InitEmptyImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue(),
param.Output()->dims());
auto output = param.Output()->GetCLImage();
int stride = param.Strides()[0];
int offset = param.Offset();
......
......@@ -22,12 +22,185 @@ namespace operators {
template <>
bool ConvBNReluKernel<GPU_CL, float>::Init(
FusionConvBNReluParam<GPU_CL> *param) {
PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
const framework::CLImage *mean = param->InputMean();
const framework::CLImage *variance = param->InputVariance();
const framework::CLImage *scale = param->InputScale();
const framework::CLImage *bias = param->InputBias();
const float epsilon = param->Epsilon();
const int C = mean->numel();
auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>();
auto scale_ptr = scale->data<float>();
auto bias_ptr = bias->data<float>();
float inv_std_ptr[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
float *new_scale_ptr = new float[C];
float *new_bias_ptr = new float[C];
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
}
framework::CLImage *new_scale = new framework::CLImage();
// for (int j = 0; j < C; ++j) {
// DLOG << " new scale - " << j << new_scale_ptr[j];
// }
//
// for (int j = 0; j < C; ++j) {
// DLOG << " new bias - " << j << new_bias_ptr[j];
// }
new_scale->SetTensorData(new_scale_ptr, variance->dims());
new_scale->InitCLImage(this->cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
// DLOG << " climage - y bias: " << *(param->Bias());
//
// DLOG << " climage - new scale: " << *new_scale;
framework::CLImage *new_bias = new framework::CLImage();
new_bias->SetTensorData(new_bias_ptr, variance->dims());
new_bias->InitCLImage(this->cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
// DLOG << " climage - new bias: " << *new_bias;
//
// DLOG << " climage - filter: " << *(param->Filter());
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
delete[](new_scale_ptr);
delete[](new_bias_ptr);
PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
static_cast<int>(param->Paddings()[1]);
param->SetOffset(offset);
if (param->Filter()->dims()[2] == 1 && param->Filter()->dims()[3] == 1) {
param->Filter()->InitNImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("conv_1x1", "conv_bn_relu_kernel.cl");
DLOG << " conv bn relu conv 1x1";
} else if (param->Filter()->dims()[1] == 1 &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == 3) {
param->Filter()->InitDWImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("depth_conv_3x3", "conv_bn_relu_kernel.cl");
DLOG << " conv bn relu depth_conv_3x3";
} else if (param->Filter()->dims()[2] == 3 &&
param->Filter()->dims()[3] == 3) {
param->Filter()->InitCLImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("conv_3x3", "conv_bn_relu_kernel.cl");
DLOG << " conv bn relu conv_3x3";
} else {
PADDLE_MOBILE_THROW_EXCEPTION(" not support ");
}
return true;
}
template <>
void ConvBNReluKernel<GPU_CL, float>::Compute(
const FusionConvBNReluParam<GPU_CL> &param) {}
const FusionConvBNReluParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Output());
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
auto input = param.Input()->GetCLImage();
auto filter = param.Filter()->GetCLImage();
auto new_scale = param.NewScale()->GetCLImage();
auto new_bias = param.NewBias()->GetCLImage();
auto output = param.Output()->GetCLImage();
int stride = param.Strides()[0];
int offset = param.Offset();
int input_c = reinterpret_cast<framework::CLImageConverterFolder *>(
param.Input()->Converter())
->GetCBlock();
int dilation = param.Dilations()[0];
int input_width = param.Input()->dims()[3];
int input_height = param.Input()->dims()[2];
int output_width = param.Output()->dims()[3];
int output_height = param.Output()->dims()[2];
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(int), &c_block);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(int), &w);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &nh);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &input);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &new_scale);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(cl_mem), &new_bias);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(cl_mem), &output);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 8, sizeof(int), &stride);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 9, sizeof(int), &offset);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 10, sizeof(int), &input_c);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 11, sizeof(int), &dilation);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 12, sizeof(int), &input_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 13, sizeof(int), &input_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 14, sizeof(int), &output_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 15, sizeof(int), &output_height);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
}
template class ConvBNReluKernel<GPU_CL, float>;
} // namespace operators
......
......@@ -22,12 +22,151 @@ namespace operators {
template <>
bool DWConvBNReluKernel<GPU_CL, float>::Init(
FusionDWConvBNReluParam<GPU_CL> *param) {
PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
const framework::CLImage *mean = param->InputMean();
const framework::CLImage *variance = param->InputVariance();
const framework::CLImage *scale = param->InputScale();
const framework::CLImage *bias = param->InputBias();
const float epsilon = param->Epsilon();
const int C = mean->numel();
auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>();
auto scale_ptr = scale->data<float>();
auto bias_ptr = bias->data<float>();
float inv_std_ptr[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
float *new_scale_ptr = new float[C];
float *new_bias_ptr = new float[C];
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
}
framework::CLImage *new_scale = new framework::CLImage();
new_scale->SetTensorData(new_scale_ptr, variance->dims());
new_scale->InitCLImage(this->cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
framework::CLImage *new_bias = new framework::CLImage();
new_bias->SetTensorData(new_bias_ptr, variance->dims());
new_bias->InitCLImage(this->cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
delete[](new_scale_ptr);
delete[](new_bias_ptr);
PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
static_cast<int>(param->Paddings()[1]);
param->SetOffset(offset);
param->Filter()->InitDWImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("depth_conv_3x3", "conv_bn_relu_kernel.cl");
DLOG << " conv bn relu depth_conv_3x3";
return true;
}
template <>
void DWConvBNReluKernel<GPU_CL, float>::Compute(
const FusionDWConvBNReluParam<GPU_CL> &param) {}
const FusionDWConvBNReluParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Output());
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
auto input = param.Input()->GetCLImage();
auto filter = param.Filter()->GetCLImage();
auto new_scale = param.NewScale()->GetCLImage();
auto new_bias = param.NewBias()->GetCLImage();
auto output = param.Output()->GetCLImage();
int stride = param.Strides()[0];
int offset = param.Offset();
int input_c = reinterpret_cast<framework::CLImageConverterFolder *>(
param.Input()->Converter())
->GetCBlock();
int dilation = param.Dilations()[0];
int input_width = param.Input()->dims()[3];
int input_height = param.Input()->dims()[2];
int output_width = param.Output()->dims()[3];
int output_height = param.Output()->dims()[2];
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(int), &c_block);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(int), &w);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &nh);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &input);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &new_scale);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(cl_mem), &new_bias);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(cl_mem), &output);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 8, sizeof(int), &stride);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 9, sizeof(int), &offset);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 10, sizeof(int), &input_c);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 11, sizeof(int), &dilation);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 12, sizeof(int), &input_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 13, sizeof(int), &input_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 14, sizeof(int), &output_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 15, sizeof(int), &output_height);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
}
template class DWConvBNReluKernel<GPU_CL, float>;
} // namespace operators
......
......@@ -28,6 +28,8 @@ template <>
void FeedKernel<GPU_CL, float>::Compute(const FeedParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
cl_int status;
param.Out()->InitEmptyImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue(), param.Out()->dims());
auto output = param.Out();
const Tensor *input = param.InputX();
// DLOG << *input;
......
......@@ -27,8 +27,6 @@ bool FetchKernel<GPU_CL, float>::Init(FetchParam<GPU_CL> *param) {
} else {
this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");
}
auto *out = param->Out();
out->mutable_data<float>();
return true;
}
......@@ -39,7 +37,7 @@ void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
auto input = param.InputX()->GetCLImage();
auto *out = param.Out();
out->mutable_data<float>();
const auto &dim = param.InputX()->dims();
size_t new_dims[] = {1, 1, 1, 1};
......@@ -70,9 +68,11 @@ void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
int size_ch = in_height * in_width;
int size_block = size_ch * 4;
int size_batch = size_ch * C;
int out_c = new_dims[1];
clSetKernelArg(kernel, 4, sizeof(int), &size_ch);
clSetKernelArg(kernel, 5, sizeof(int), &size_block);
clSetKernelArg(kernel, 6, sizeof(int), &size_batch);
clSetKernelArg(kernel, 7, sizeof(int), &out_c);
}
// cl_event wait_event = param.InpdutX()->GetClEvent();
......@@ -93,6 +93,8 @@ void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
// << "ms" << std::endl;
memcpy(out->data<float>(), out_cl_tensor.Data<float>(), out->memory_size());
DLOG << *param.InputX();
DLOG << *out;
}
template class FetchKernel<GPU_CL, float>;
......
......@@ -15,18 +15,165 @@ limitations under the License. */
#ifdef PRIORBOX_OP
#include "operators/kernel/prior_box_kernel.h"
#include "framework/cl/cl_tensor.h"
namespace paddle_mobile {
namespace operators {
template <>
bool PriorBoxKernel<GPU_CL, float>::Init(PriorBoxParam<GPU_CL> *param) {
this->cl_helper_.AddKernel("prior_box", "prior_box_kernel.cl");
return true;
}
template <>
void PriorBoxKernel<GPU_CL, float>::Compute(
const PriorBoxParam<GPU_CL> &param) {}
const PriorBoxParam<GPU_CL> &param) {
const auto *input_ = param.Input();
const auto &input_dims = input_->dims();
const auto &input_image_dims = param.InputImage()->dims();
const auto &min_sizes = param.MinSizes();
const auto &max_sizes = param.MaxSizes();
const auto &variances = param.Variances();
const auto &input_aspect_ratio = param.AspectRatios();
const bool &flip = param.Flip();
const bool &clip = param.Clip();
const float &step_w = param.StepW();
const float &step_h = param.StepH();
const float &offset = param.Offset();
const int C = param.OutputBoxes()->dims()[1];
auto output_boxes = param.OutputBoxes()->GetCLImage();
auto output_variances = param.OutputVariances()->GetCLImage();
std::vector<float> aspect_ratios;
ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios);
auto img_width = input_image_dims[3];
auto img_height = input_image_dims[2];
auto feature_width = input_dims[3];
auto feature_height = input_dims[2];
float step_width, step_height;
/// 300 / 19
if (step_w == 0 || step_h == 0) {
step_width = static_cast<float>(img_width) / feature_width;
step_height = static_cast<float>(img_height) / feature_height;
} else {
step_width = step_w;
step_height = step_h;
}
int num_priors = aspect_ratios.size() * min_sizes.size();
if (!max_sizes.empty()) {
num_priors += max_sizes.size();
}
float *box_width = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * num_priors));
float *box_height = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * num_priors));
int idx = 0;
for (size_t s = 0; s < min_sizes.size(); ++s) {
auto min_size = min_sizes[s];
if (param.MinMaxAspectRatiosOrder()) {
box_width[idx] = box_height[idx] = min_size / 2.;
idx++;
if (max_sizes.size() > 0) {
auto max_size = max_sizes[s];
box_width[idx] = box_height[idx] = sqrt(min_size * max_size) / 2.;
idx++;
}
for (float ar : aspect_ratios) {
if (fabs(ar - 1.) < 1e-6) {
continue;
}
box_width[idx] = min_size * sqrt(ar) / 2.;
box_height[idx] = min_size / sqrt(ar) / 2.;
idx++;
}
} else {
for (float ar : aspect_ratios) {
box_width[idx] = min_size * sqrt(ar) / 2.;
box_height[idx] = min_size / sqrt(ar) / 2.;
idx++;
}
if (!max_sizes.empty()) {
auto max_size = max_sizes[s];
box_width[idx] = box_height[idx] = sqrt(min_size * max_size) / 2.;
idx++;
}
}
}
cl_int status;
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size =
this->cl_helper_.DefaultWorkSize(*param.OutputBoxes());
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
std::vector<int64_t> box_shape({1, 1, 1, num_priors});
framework::DDim ddim = framework::make_ddim(box_shape);
framework::CLTensor box_width_cl_tensor(this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
box_width_cl_tensor.Resize(ddim);
cl_mem box_width_Buffer =
box_width_cl_tensor.mutable_with_data<float>(box_width);
framework::CLTensor box_height_cl_tensor(this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
box_height_cl_tensor.Resize(ddim);
cl_mem box_height_Buffer =
box_height_cl_tensor.mutable_with_data<float>(box_height);
DLOG << "c_block:" << c_block;
DLOG << "w:" << w;
DLOG << "nh:" << nh;
DLOG << "step_width:" << step_width;
DLOG << "step_height:" << step_height;
DLOG << "offset:" << offset;
DLOG << "img_width:" << img_width;
DLOG << "img_height:" << img_height;
DLOG << "num_priors:" << num_priors;
DLOG << "C:" << C;
status = clSetKernelArg(kernel, 0, sizeof(int), &c_block);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(int), &w);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &nh);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &box_width_Buffer);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &box_height_Buffer);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &output_boxes);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(float), &step_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(float), &step_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 8, sizeof(float), &offset);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 9, sizeof(int), &img_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 10, sizeof(int), &img_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 11, sizeof(int), &num_priors);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 12, sizeof(int), &C);
CL_CHECK_ERRORS(status);
size_t global_work_size[2] = {c_block, nh};
status = clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2,
NULL, global_work_size, NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
paddle_mobile::memory::Free(box_width);
paddle_mobile::memory::Free(box_height);
}
template class PriorBoxKernel<GPU_CL, float>;
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册