Commit 866f6be5 authored by H Huie, committed by Yanzhan Yang

1. Add density_prior_box for GPU. (#1877)

2. Add flatten2 for GPU.
3. Add concat with 4 inputs for GPU.
4. Fix pool.
5. Fix transpose2.
test=develop
Parent a2b6978b
......@@ -13,12 +13,28 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FLATTEN2_OP
#include "operators/flatten2_op.h"
#include <operators/kernel/reshape_kernel.h>
namespace paddle_mobile {
namespace operators {
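// Flatten2 on GPU supports axis == 1 only: a 4-D [N, C, H, W] input is
// folded into a 2-D [N, C * H * W] output.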
template <typename DeviceType, typename T>
void Flatten2Op<DeviceType, T>::InferShape() const {
const auto* input = this->param_.InputX();
auto* output = this->param_.Out();
auto input_x_dims = input->dims();
if (input->dims().size() == 4) {
PADDLE_MOBILE_ENFORCE(this->param_.Axis() == 1,
"flatten 2 only support axis == 1");
if (this->param_.Axis() == 1) {
std::vector<int> temp_output_dims(2);
temp_output_dims[0] = input->dims()[0];
temp_output_dims[1] =
input->dims()[1] * input->dims()[2] * input->dims()[3];
output->Resize(framework::make_ddim(temp_output_dims));
}
}
}
} // namespace operators
} // namespace paddle_mobile
......
......@@ -151,6 +151,87 @@ __kernel void concatByCWith3Inputs(__read_only image2d_t input_image_0,
write_imageh(output_image, output_pos, output_data);
}
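// Concatenates four NCHW images along the channel axis. Channels are packed
// four per texel, so one output texel may straddle input boundaries; each of
// its four lanes is fetched from whichever input owns that channel.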
__kernel void concatByCWith4Inputs(__read_only image2d_t input_image_0,
__read_only image2d_t input_image_1,
__read_only image2d_t input_image_2,
__read_only image2d_t input_image_3,
__private const int C_0,
__private const int C_1,
__private const int C_2,
__private const int C_3,
__write_only image2d_t output_image,
__private const int out_C,
__private const int out_W) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
int2 output_pos;
output_pos.x = out_c * out_W + out_w;
output_pos.y = out_nh;
half4 output_data;
for (int i = 0; i < 4; i++) {
int c = out_c * 4 + i;
if (c >= out_C) {
break;
}
int c_in;
half4 input_data;
if (c < C_0) {
c_in = c;
int2 input_pos;
input_pos.x = (c_in / 4) * out_W + out_w;
input_pos.y = out_nh;
input_data = read_imageh(input_image_0, sampler, input_pos);
} else if (c < C_0 + C_1) {
c_in = c - C_0;
int2 input_pos;
input_pos.x = (c_in / 4) * out_W + out_w;
input_pos.y = out_nh;
input_data = read_imageh(input_image_1, sampler, input_pos);
} else if (c < C_0 + C_1 + C_2) {
c_in = c - C_0 - C_1;
int2 input_pos;
input_pos.x = (c_in / 4) * out_W + out_w;
input_pos.y = out_nh;
input_data = read_imageh(input_image_2, sampler, input_pos);
} else if (c < C_0 + C_1 + C_2 + C_3) {
c_in = c - C_0 - C_1 - C_2;
int2 input_pos;
input_pos.x = (c_in / 4) * out_W + out_w;
input_pos.y = out_nh;
input_data = read_imageh(input_image_3, sampler, input_pos);
}
int value_offset = c_in % 4;
float value;
if (value_offset == 0) {
value = input_data.x;
} else if (value_offset == 1) {
value = input_data.y;
} else if (value_offset == 2) {
value = input_data.z;
} else if (value_offset == 3) {
value = input_data.w;
}
if (i == 0) {
output_data.x = value;
} else if (i == 1) {
output_data.y = value;
} else if (i == 2) {
output_data.z = value;
} else if (i == 3) {
output_data.w = value;
}
}
write_imageh(output_image, output_pos, output_data);
}
__kernel void concatByH(__read_only image2d_t input_image,
__write_only image2d_t output_image,
__private const int out_W,
......@@ -178,4 +259,33 @@ __kernel void concatByH(__read_only image2d_t input_image,
}
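// Width-axis concat: each input texel is shifted right by pre_Width, the
// summed width of the inputs already written.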
__kernel void concatByW(__read_only image2d_t input_image,
__write_only image2d_t output_image,
__private const int in_W,
__private const int pre_Width,
__private const int out_Width) {
const int in_c = get_global_id(0);
const int in_w = get_global_id(1);
const int in_nh = get_global_id(2);
int2 input_pos;
input_pos.x = in_c * in_W + in_w;
input_pos.y = in_nh;
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
half4 input;
input = read_imageh(input_image, sampler, input_pos);
int2 output_pos;
output_pos.x = input_pos.x + pre_Width + out_Width * in_c;
output_pos.y = input_pos.y;
write_imageh(output_image, output_pos, input);
}
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define MIN_VALUE -FLT_MAX
__kernel void density_prior_box(__write_only image2d_t output_boxes,
__write_only image2d_t output_variances,
__global float *densities,
__private const float step_h,
__private const float step_w,
__private float variances0,
__private float variances1,
__private float variances2,
__private float variances3,
__private float offset,
__private int den_and_fix_size,
__private int img_width,
__private int img_height,
__private int C,
__private int num_density,
__private int step_average,
__private int input_width,
__private int wid,
__private int fix_ratio_size
) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
int2 output_pos;
output_pos.x = out_c * 4 + out_w;
output_pos.y = out_nh;
half4 output;
half4 variances;
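// One work-item per output texel: out_nh packs (feature-map row, prior
// index), out_c * 4 + c is the feature-map column, and out_w selects one of
// the four box coordinates. The loop below decodes the flat prior index into
// a (density entry, ratio, grid offset) triple via the densities table.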
for (int c = 0; c < 4; c++) {
int idx = out_nh % num_density;
int input_h = out_nh / num_density;
int input_w = out_c * 4 + c;
int density_idx;
int density;
int ratio_idx;
int density_i;
int density_j;
int sum = 0;
int pre_sum = 0;
for (int i = 0; i < den_and_fix_size; i++) {
pre_sum = sum;
density = densities[i];
sum += density * density * fix_ratio_size;
if (idx < sum) {
density_idx = i;
break;
}
}
idx = idx - pre_sum;
ratio_idx = idx / (density * density);
idx = idx % (density * density);
density_i = idx / density;
density_j = idx % density;
half fixed_size = densities[den_and_fix_size + density_idx];
half ratio = densities[2 * den_and_fix_size + ratio_idx];
half box_width = fixed_size * ratio;
half box_height = fixed_size / ratio;
int shift = step_average / density;
half center_x;
half center_y;
center_x = (input_w + offset) * step_w;
center_x = center_x - step_average / 2.0f + shift / 2.0f;
center_x = center_x + density_j * shift;
center_y = (input_h + offset) * step_h;
center_y = center_y - step_average / 2.0f + shift / 2.0f;
center_y = center_y + density_i * shift;
half4 box;
box.x = (center_x - box_width / 2.0f) / img_width;
box.y = (center_y - box_height / 2.0f) / img_height;
box.z = (center_x + box_width / 2.0f) / img_width;
box.w = (center_y + box_height / 2.0f) / img_height;
box.x = max((float)box.x, 0.0f);
box.y = max((float)box.y, 0.0f);
box.z = min((float)box.z, 1.0f);
box.w = min((float)box.w, 1.0f);
half res;
half var;
if (out_w == 0) {
res = box.x;
var = convert_half(variances0);
} else if (out_w == 1) {
res = box.y;
var = convert_half(variances1);
} else if (out_w == 2) {
res = box.z;
var = convert_half(variances2);
} else if (out_w == 3) {
res = box.w;
var = convert_half(variances3);
}
variances.x = var;
variances.y = var;
variances.z = var;
variances.w = var;
if (c == 0) {
output.x = res;
} else if (c == 1) {
output.y = res;
} else if (c == 2) {
output.z = res;
} else if (c == 3) {
output.w = res;
}
}
write_imageh(output_boxes, (int2)(output_pos.x, output_pos.y), output);
write_imageh(output_variances, (int2)(output_pos.x, output_pos.y), variances);
}
\ No newline at end of file
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void flatten2(__read_only image2d_t input_img,
__write_only image2d_t output_img,
__private int out_width,
__private int in_width,
__private int in_height,
__private int in_C
) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
int2 output_pos;
output_pos.x = out_c * out_width + out_w;
output_pos.y = out_nh;
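// Map the flattened output coordinate back to the NCHW image layout: the
// channel block (in_c) and lane (in_c_offset) come from dividing by the
// per-channel size; the remainders give the (w, h) position in that channel.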
int channel_size = in_width * in_height;
int in_c = output_pos.x / channel_size / 4;
int2 input_pos;
input_pos.x = (output_pos.x % in_width) + (in_c * in_width);
input_pos.y = (output_pos.x % channel_size) / in_width + out_nh * in_height;
half4 input_data = read_imageh(input_img, sampler, input_pos);
half4 output_data;
int in_c_offset = output_pos.x / channel_size % 4;
if(in_c_offset == 0){
output_data.x = input_data.x;
} else if(in_c_offset == 1){
output_data.x = input_data.y;
} else if(in_c_offset == 2){
output_data.x = input_data.z;
} else if(in_c_offset == 3){
output_data.x = input_data.w;
}
write_imageh(output_img, output_pos, output_data);
}
\ No newline at end of file
......@@ -69,23 +69,27 @@ __kernel void pool_avg(
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int start_h = out_h * stride_h - pad_top;
int end_h = min(start_h + ksize_h, in_height);
start_h = max(start_h, 0);
int start_w = out_w * stride_w - pad_left;
int end_w = min(start_w + ksize_w, in_width);
start_w = max(start_w, 0);
const int pos_in_x = out_c * in_width;
const int pos_in_y = out_n * in_height;
half4 sum = (half4)(0.0f);
int num = 0;
for (int y = start_h; y < end_h; ++y) {
for (int x = start_w; x < end_w; ++x) {
sum += read_imageh(input, sampler, (int2)(pos_in_x + x, pos_in_y + y));
num++;
}
}
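// Average over the full kernel window (count-include-pad semantics); this
// supersedes the per-element count accumulated above.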
num = ksize_w * ksize_h;
half4 avg = sum / num;
const int pos_out_x = mad24(out_c, out_width, out_w);
write_imageh(output, (int2)(pos_out_x, out_nh), avg);
}
......@@ -22,12 +22,18 @@ namespace operators {
template <>
bool ConcatKernel<GPU_CL, float>::Init(ConcatParam<GPU_CL> *param) {
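// Rank < 4: concatenate along W when the axis is the last dimension,
// otherwise along H. Rank >= 4: concatenate along C, with one specialized
// kernel per supported input count (2, 3 or 4).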
if (param->Out()->dims().size() < 4) {
this->cl_helper_.AddKernel("concatByH", "concat_kernel.cl");
if (param->Out()->dims().size() - param->axis_ == 1) {
this->cl_helper_.AddKernel("concatByW", "concat_kernel.cl");
} else {
this->cl_helper_.AddKernel("concatByH", "concat_kernel.cl");
}
} else if (param->Out()->dims().size() >= 4) {
if (param->Inputs().size() == 2) {
this->cl_helper_.AddKernel("concatByCWith2Inputs", "concat_kernel.cl");
} else if (param->Inputs().size() == 3) {
this->cl_helper_.AddKernel("concatByCWith3Inputs", "concat_kernel.cl");
} else if (param->Inputs().size() == 4) {
this->cl_helper_.AddKernel("concatByCWith4Inputs", "concat_kernel.cl");
} else {
return false;
}
......@@ -37,7 +43,6 @@ bool ConcatKernel<GPU_CL, float>::Init(ConcatParam<GPU_CL> *param) {
template <>
void ConcatKernel<GPU_CL, float>::Compute(const ConcatParam<GPU_CL> &param) {
if (param.Out()->dims().size() < 4) {
auto kernel = this->cl_helper_.KernelAt(0);
auto inputs = param.Inputs();
......@@ -49,28 +54,57 @@ void ConcatKernel<GPU_CL, float>::Compute(const ConcatParam<GPU_CL> &param) {
out_W = param.Out()->dims()[1];
}
int out_H_Start = 0;
if (param.Out()->dims().size() - param.axis_ == 1) {
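// Width-axis concat: pre_Width accumulates the widths of the inputs already
// written so each input lands after its predecessors.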
for (int i = 0; i < inputs.size(); i++) {
int pre_Width = 0;
for (int k = 0; k < i; ++k) {
pre_Width += inputs[k]->dims()[inputs[k]->dims().size() - 1];
}
int in_w = inputs[i]->dims()[param.Out()->dims().size() - 2];
auto input_image = inputs[i]->GetCLImage();
auto default_work_size = this->cl_helper_.DefaultWorkSize(*inputs[i]);
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &input_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &output_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &in_w);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(int), &pre_Width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(int), &out_W);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(),
NULL, default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
}
} else {
for (int i = 0; i < inputs.size(); i++) {
auto input_image = inputs[i]->GetCLImage();
auto default_work_size = this->cl_helper_.DefaultWorkSize(*inputs[i]);
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &input_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &output_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &out_W);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(int), &out_H_Start);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(),
NULL, default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
if (param.Out()->dims().size() == 3) {
out_H_Start += inputs[i]->dims()[1];
} else if (param.Out()->dims().size() == 2) {
out_H_Start += inputs[i]->dims()[0];
}
}
}
} else {
auto kernel0 = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Out());
......@@ -111,6 +145,32 @@ void ConcatKernel<GPU_CL, float>::Compute(const ConcatParam<GPU_CL> &param) {
status = clSetKernelArg(kernel0, 5, sizeof(int), &C_2);
CL_CHECK_ERRORS(status);
arg_offset = 6;
} else if (inputs.size() == 4) {
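// Four-input channel concat: bind the four input images plus their channel
// counts so the kernel can route each packed output channel.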
auto input_image_0 = inputs[0]->GetCLImage();
status = clSetKernelArg(kernel0, 0, sizeof(cl_mem), &input_image_0);
CL_CHECK_ERRORS(status);
auto input_image_1 = inputs[1]->GetCLImage();
status = clSetKernelArg(kernel0, 1, sizeof(cl_mem), &input_image_1);
CL_CHECK_ERRORS(status);
auto input_image_2 = inputs[2]->GetCLImage();
status = clSetKernelArg(kernel0, 2, sizeof(cl_mem), &input_image_2);
CL_CHECK_ERRORS(status);
auto input_image_3 = inputs[3]->GetCLImage();
status = clSetKernelArg(kernel0, 3, sizeof(cl_mem), &input_image_3);
CL_CHECK_ERRORS(status);
int C_0 = inputs[0]->dims()[1];
status = clSetKernelArg(kernel0, 4, sizeof(int), &C_0);
CL_CHECK_ERRORS(status);
int C_1 = inputs[1]->dims()[1];
status = clSetKernelArg(kernel0, 5, sizeof(int), &C_1);
CL_CHECK_ERRORS(status);
int C_2 = inputs[2]->dims()[1];
status = clSetKernelArg(kernel0, 6, sizeof(int), &C_2);
CL_CHECK_ERRORS(status);
int C_3 = inputs[3]->dims()[1];
status = clSetKernelArg(kernel0, 7, sizeof(int), &C_3);
CL_CHECK_ERRORS(status);
arg_offset = 8;
}
auto *output_image = param.Out()->GetCLImage();
status =
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#ifdef DENSITY_PRIORBOX_OP
#include <operators/kernel/prior_box_kernel.h>
#include "framework/cl/cl_tensor.h"
namespace paddle_mobile {
namespace operators {
......@@ -31,7 +31,124 @@ bool DensityPriorBoxKernel<GPU_CL, float>::Init(
template <>
void DensityPriorBoxKernel<GPU_CL, float>::Compute(
const paddle_mobile::operators::DensityPriorBoxParam<paddle_mobile::GPU_CL>
&param) {
auto kernel = this->cl_helper_.KernelAt(0);
const auto *input = param.Input();
const auto input_dims = input->dims();
const auto input_image_dims = param.InputImage()->dims();
auto output_boxes = param.OutputBoxes()->GetCLImage();
auto output_var = param.OutputVariances()->GetCLImage();
float step_w = param.StepW();
float step_h = param.StepH();
float offset = param.Offset();
vector<float> fixed_sizes = param.FixedSizes();
vector<float> fixed_ratios = param.FixedRatios();
vector<int> densities = param.Densities();
vector<float> variances = param.Variances();
// feature map
auto input_height = input_dims[2];
auto input_width = input_dims[3];
auto image_height = input_image_dims[2];
auto image_width = input_image_dims[3];
const int C = param.OutputBoxes()->dims()[1];
if (step_w == 0 || step_h == 0) {
step_h = static_cast<float>(image_height) / input_height;
step_w = static_cast<float>(image_width) / input_width;
}
int num_density = 0;
for (int l = 0; l < densities.size(); ++l) {
num_density += densities[l] * densities[l] * fixed_ratios.size();
}
param.OutputBoxes()->Resize({input_height, input_width, num_density, 4});
int step_average = static_cast<int>((step_w + step_h) * 0.5);
int densities_and_fixedsize_size = densities.size();
int fix_ratio_size = fixed_ratios.size();
auto default_work = this->cl_helper_.DefaultWorkSize(*param.OutputBoxes());
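// Pack densities, fixed sizes, and the square roots of the fixed ratios into
// one flat float buffer; the kernel reads it as three consecutive sections
// and assumes densities and fixed sizes have equal length (den_and_fix_size).
// The write below is blocking so the host vector stays valid until the copy
// completes.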
std::vector<float> densities_data(densities.size() + fixed_sizes.size() +
                                  fix_ratio_size);
cl_int status;
for (int i = 0; i < densities.size(); ++i) {
densities_data[i] = static_cast<float>(densities[i]);
}
for (int k = 0; k < fixed_sizes.size(); ++k) {
densities_data[k + densities.size()] = fixed_sizes[k];
}
for (int j = 0; j < fixed_ratios.size(); ++j) {
densities_data[j + densities.size() + fixed_sizes.size()] =
sqrt(fixed_ratios[j]);
}
cl_mem densities_memobj = clCreateBuffer(
this->cl_helper_.CLContext(), CL_MEM_READ_WRITE,
sizeof(float) * densities_data.size(), NULL, &status);
status = clEnqueueWriteBuffer(
this->cl_helper_.CLCommandQueue(), densities_memobj, CL_TRUE, 0,
densities_data.size() * sizeof(float), densities_data.data(), 0, NULL,
NULL);
CL_CHECK_ERRORS(status);
float variances0 = variances[0];
float variances1 = variances[1];
float variances2 = variances[2];
float variances3 = variances[3];
status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &output_boxes);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &output_var);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(cl_mem), &densities_memobj);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(float), &step_h);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(float), &step_w);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(float), &variances0);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(float), &variances1);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(float), &variances2);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 8, sizeof(float), &variances3);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 9, sizeof(float), &offset);
CL_CHECK_ERRORS(status);
status =
clSetKernelArg(kernel, 10, sizeof(int), &densities_and_fixedsize_size);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 11, sizeof(int), &image_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 12, sizeof(int), &image_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 13, sizeof(int), &C);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 14, sizeof(int), &num_density);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 15, sizeof(int), &step_average);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 16, sizeof(int), &input_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 17, sizeof(int), &default_work[0]);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 18, sizeof(int), &fix_ratio_size);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel,
default_work.size(), NULL,
default_work.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
}
} // namespace operators
} // namespace paddle_mobile
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#ifdef FLATTEN2_OP
#include "operators/kernel/flatten2_kernel.h"
#include <operators/kernel/reshape_kernel.h>
namespace paddle_mobile {
namespace operators {
......@@ -29,7 +29,49 @@ bool Flatten2Kernel<GPU_CL, float>::Init(
template <>
void Flatten2Kernel<GPU_CL, float>::Compute(
const paddle_mobile::operators::FlattenParam<paddle_mobile::GPU_CL>
&param) {
auto kernel = this->cl_helper_.KernelAt(0);
const auto *input = param.InputX();
auto *output = param.Out();
auto input_image = input->GetCLImage();
auto output_image = output->GetCLImage();
int in_width = input->dims()[3];
int in_height = input->dims()[2];
int in_c = input->dims()[1];
int out_width = output->dims()[1];
DLOG << "flatten2 dims :" << output->dims() << " in: " << input->dims();
auto default_work_size = this->cl_helper_.DefaultWorkSize(*output);
DLOG << "flatten2 work size :" << default_work_size.data()[0] << " "
<< default_work_size.data()[1] << " " << default_work_size.data()[2]
<< " " << default_work_size.size();
// const size_t work_size[2] = {output->ImageWidth(), output->ImageHeight()};
DLOG << "flatten2 work data :" << output->ImageWidth() << " "
<< output->ImageHeight();
DLOG << "flatten2 work data 4:" << out_width << " " << in_width << " "
<< in_height << " " << in_c;
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &input_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &output_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &out_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(int), &in_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(int), &in_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(int), &in_c);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
}
} // namespace operators
} // namespace paddle_mobile
......
......@@ -81,24 +81,31 @@ void ShuffleChannelCompute(const Transpose2Param<GPU_CL> &param,
auto output = param.Out();
Tensor *output_tensor = new Tensor();
framework::DDim out_dims(input->dims());
for (size_t i = 0; i < axis_size; i++) {
out_dims[i] = input->dims()[axis[i]];
}
output_tensor->Resize(out_dims);
output_tensor->mutable_data<float>();
Dtype *output_ptr = output_tensor->mutable_data<Dtype>();
// input and output shape dimensions must be >= 2 && <= 6.
const framework::DDim &in_dim = input->dims();
const framework::DDim &out_dim = output->dims();
size_t offset = 1;
for (int i = 3; i < axis.size(); ++i) {
offset *= in_dim[i];
}
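// A dedicated batch loop keeps the c1 <-> c2 swap inside each sample instead
// of mixing offsets across samples.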
#pragma omp parallel for collapse(3)
for (int batch = 0; batch < out_dim[0]; ++batch) {
for (int c1 = 0; c1 < out_dim[1]; ++c1) {
for (int c2 = 0; c2 < out_dim[2]; ++c2) {
size_t out_offset =
((batch * out_dim[1] + c1) * out_dim[2] + c2) * offset;
size_t in_offset = ((batch * in_dim[1] + c2) * in_dim[2] + c1) * offset;
memcpy(output_ptr + out_offset, input_ptr + in_offset,
offset * sizeof(Dtype));
}
}
}
......@@ -110,6 +117,75 @@ void ShuffleChannelCompute(const Transpose2Param<GPU_CL> &param,
delete (output_tensor);
}
template <typename Dtype>
void Transpose2Compute(const Transpose2Param<GPU_CL> &param, cl_context context,
cl_command_queue commandQueue, cl_kernel kernel0,
cl_kernel kernel1) {
const std::vector<int> &axis = param.Axis();
auto input = param.InputX();
Tensor *input_tensor = new Tensor();
input_tensor->Resize(input->dims());
input_tensor->mutable_data<float>();
framework::CLImageToTensor(input, input_tensor, context, commandQueue,
kernel0);
const Dtype *input_ptr = input_tensor->data<Dtype>();
auto output = param.Out();
Tensor *output_tensor = new Tensor();
output_tensor->Resize(output->dims());
output_tensor->mutable_data<float>();
Dtype *output_ptr = output_tensor->mutable_data<Dtype>();
// input and output shape dimensions must be >= 2 && <= 6.
const framework::DDim &in_dim = input->dims();
const framework::DDim &out_dim = output->dims();
// precompute inverted output dim and strides
size_t rout_dim[6], strides[6];
int permute = axis.size();  // permute must be >= 2 && <= 6.
for (int i = 0; i < permute; ++i) {
int k = permute - 1 - i;
strides[k] = 1;
for (int j = axis[i] + 1; j < permute; ++j) {
strides[k] *= in_dim[j];
}
rout_dim[k] = out_dim[i];
}
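// strides[k] is the input stride of output axis (permute - 1 - k); walking
// output positions in row-major order while advancing offset by these
// strides visits the matching input elements.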
// unroll the first 2 dimensions
int remain_dim = 1;
for (int i = 2; i < out_dim.size(); ++i) {
remain_dim *= out_dim[i];
}
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < out_dim[0]; ++batch) {
for (int j = 0; j < out_dim[1]; ++j) {
size_t offset = batch * strides[permute - 1] + j * strides[permute - 2];
Dtype *out_ptr = output_ptr + (batch * out_dim[1] + j) * remain_dim;
int indices[4] = {0, 0, 0, 0};
for (int k = 0; k < remain_dim; ++k) {
out_ptr[k] = input_ptr[offset];
indices[0] += 1;
offset += strides[0];
for (int p = 0; p < permute - 3; ++p) {
if (indices[p] == rout_dim[p]) {
indices[p + 1] += 1;
indices[p] = 0;
offset += strides[p + 1];
offset -= rout_dim[p] * strides[p];
} else {
break;
}
}
}
}
}
output->InitEmptyImage(context, commandQueue, output_tensor->dims());
framework::TensorToCLImage(output_tensor, output, context, commandQueue,
                           kernel1);
delete (input_tensor);
delete (output_tensor);
}
template <>
void Transpose2Kernel<GPU_CL, float>::Compute(
const Transpose2Param<GPU_CL> &param) {
......@@ -123,7 +199,9 @@ void Transpose2Kernel<GPU_CL, float>::Compute(
this->cl_helper_.CLCommandQueue(), kernel0,
kernel1);
} else {
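// General permutations fall back to a host-side path: copy the image to a
// tensor, permute on the CPU, then upload the result.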
PADDLE_MOBILE_THROW_EXCEPTION("axis not support");
Transpose2Compute<float>(param, this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue(), kernel0,
kernel1);
}
}
......