From 857b7116b3c27dd1d23b6d0f7cd9c66b03a36f36 Mon Sep 17 00:00:00 2001 From: xiebaiyuan Date: Wed, 26 Feb 2020 00:06:25 +0800 Subject: [PATCH] =?UTF-8?q?[LITE][OPENCL][Image]=20fix=20issue=20in=20conc?= =?UTF-8?q?at=20and=20nearest=5Finterp=20=20thx=20for=E2=80=A6=20(#3011)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [LITE][OPENCL][Image] fix issue in concat and nearest_interp thx for chenj and ys,test=develop * [LITE][OPENCL][Image] fix issue in concat and nearest_interp thx for chenj and ys,test=develop --- .../opencl/cl_kernel/image/concat_kernel.cl | 139 +++++++++--------- lite/kernels/opencl/concat_image_compute.cc | 51 +++++-- .../opencl/nearest_interp_image_compute.cc | 25 ++-- 3 files changed, 120 insertions(+), 95 deletions(-) diff --git a/lite/backends/opencl/cl_kernel/image/concat_kernel.cl b/lite/backends/opencl/cl_kernel/image/concat_kernel.cl index c097a866ba..3b22ee253f 100644 --- a/lite/backends/opencl/cl_kernel/image/concat_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/concat_kernel.cl @@ -1,11 +1,8 @@ /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -15,9 +12,9 @@ limitations under the License. */ #include __kernel void concat2(__read_only image2d_t input0, - __read_only image2d_t input1, - __write_only image2d_t output, - int flag, int C_0, int out_C, int out_W, int width) { + __read_only image2d_t input1, + __write_only image2d_t output, + int flag, int C_0, int out_C, int out_W, int width) { const int out_w = get_global_id(0); // image_width cxw/4 const int out_c = get_global_id(1); // image_width cxw/4 const int out_nh = get_global_id(2); // image_height nxh @@ -32,51 +29,51 @@ __kernel void concat2(__read_only image2d_t input0, output_pos.y = out_nh; CL_DTYPE4 output_data; for (int i = 0; i < 4; i++) { - int c = out_c * 4 + i; - if (c >= out_C) { - break; - } - int c_in; - CL_DTYPE4 input_data; - if (c < C_0) { - c_in = c; - int2 input_pos; - input_pos.x = (c_in / 4) * out_W + out_w; - input_pos.y = out_nh; - input_data = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, sampler, input_pos); - } else { - c_in = c - C_0; - int2 input_pos; - input_pos.x = (c_in / 4) * out_W + out_w; - input_pos.y = out_nh; - input_data = READ_IMG_TYPE(CL_DTYPE_CHAR, input1, sampler, input_pos); - } - int value_offset = c_in % 4; - float value; - if (value_offset == 0) { - value = input_data.x; - } else if (value_offset == 1) { - value = input_data.y; - } else if (value_offset == 2) { - value = input_data.z; - } else if (value_offset == 3) { - value = input_data.w; - } - if (i == 0) { - output_data.x = value; - } else if (i == 1) { - output_data.y = value; - } else if (i == 2) { - output_data.z = value; - } else if (i == 3) { - output_data.w = value; + int c = out_c * 4 + i; + if (c >= out_C) { + break; + } + int c_in; + CL_DTYPE4 input_data; + if (c < C_0) { + c_in = c; + int2 input_pos; + input_pos.x = (c_in / 4) * out_W + out_w; + input_pos.y = out_nh; + input_data = READ_IMG_TYPE(CL_DTYPE_CHAR, input0, sampler, input_pos); + } else { + c_in = c - C_0; + int2 input_pos; + input_pos.x = (c_in / 4) * out_W + out_w; + input_pos.y = out_nh; + input_data = READ_IMG_TYPE(CL_DTYPE_CHAR, input1, sampler, input_pos); + } + int value_offset = c_in % 4; + CL_DTYPE value; + if (value_offset == 0) { + value = input_data.x; + } else if (value_offset == 1) { + value = input_data.y; + } else if (value_offset == 2) { + value = input_data.z; + } else if (value_offset == 3) { + value = input_data.w; + } + if (i == 0) { + output_data.x = value; + } else if (i == 1) { + output_data.y = value; + } else if (i == 2) { + output_data.z = value; + } else if (i == 3) { + output_data.w = value; } } WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, output_pos, output_data); }else if (flag == 2){ // by height, width == n int2 input_pos; input_pos.x = out_c * out_W + out_w; - int h = out_nh / width; + int h = out_nh / width; CL_DTYPE4 input; if (h < C_0){ input_pos.y = out_nh; @@ -108,8 +105,8 @@ __kernel void concat2(__read_only image2d_t input0, } __kernel void concat_mul(__read_only image2d_t input, - __write_only image2d_t output, - int flag, int C_0, int out_C, int out_W, int in_W, int width) { + __write_only image2d_t output, + int flag, int C_0, int out_C, int out_W, int in_W, int width) { const int in_w = get_global_id(0); // image_width cxw/4 const int in_c = get_global_id(1); // image_width cxw/4 const int in_nh = get_global_id(2); // image_height nxh @@ -125,32 +122,32 @@ __kernel void concat_mul(__read_only image2d_t input, if (flag == 1){ // by channel CL_DTYPE4 output_data; for (int i = 0; i < 4; i++) { - int c_out = C_0 + in_c * 4 + i; - if (c_out >= out_C) { - break; - } - int2 output_pos; - output_pos.x = (c_out / 4) * in_W + in_w; - output_pos.y = in_nh; - float val; - if (i == 0) { - val = input_data.x; - } else if (i == 1) { - val = input_data.y; - } else if (i == 2) { - val = input_data.z; - } else if (i == 3) { - val = input_data.w; + int c_out = C_0 + in_c * 4 + i; + if (c_out >= out_C) { + break; + } + int2 output_pos; + output_pos.x = (c_out / 4) * in_W + in_w; + output_pos.y = in_nh; + CL_DTYPE val; + if (i == 0) { + val = input_data.x; + } else if (i == 1) { + val = input_data.y; + } else if (i == 2) { + val = input_data.z; + } else if (i == 3) { + val = input_data.w; } if (c_out % 4 == 0){ - output_data.x = val; + output_data.x = val; }else if (c_out % 4 == 1){ - output_data.y = val; - }else if (c_out % 4 == 2){ - output_data.z = val; - }else if (c_out % 4 == 3){ - output_data.w = val; - } + output_data.y = val; + }else if (c_out % 4 == 2){ + output_data.z = val; + }else if (c_out % 4 == 3){ + output_data.w = val; + } WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, output_pos, output_data); } }else if (flag == 2){ // by height, width == n @@ -164,4 +161,4 @@ __kernel void concat_mul(__read_only image2d_t input, output_pos.x = in_c * out_W + (in_w + C_0); WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, output_pos, input_data); } -} +} \ No newline at end of file diff --git a/lite/kernels/opencl/concat_image_compute.cc b/lite/kernels/opencl/concat_image_compute.cc index 748b243742..ce1a5e6cea 100644 --- a/lite/kernels/opencl/concat_image_compute.cc +++ b/lite/kernels/opencl/concat_image_compute.cc @@ -96,7 +96,7 @@ class ConcatComputeImage : public KernelLite(); const auto& x_dims = param.output->dims(); auto image_shape = InitImageDimInfoWith(x_dims); - auto* out_buf = param.output->mutable_data( + auto* out_buf = param.output->mutable_data( image_shape["width"], image_shape["height"]); const auto& y_dims = param.output->dims(); // useless: check dim only @@ -107,21 +107,41 @@ class ConcatComputeImage : public KernelLitedims()[-1]; - auto global_work_size = cl::NDRange{ - static_cast(x_dims[-1]), - static_cast(image_shape["width"] / x_dims[-1]), - static_cast(image_shape["height"])}; + int width = inputs[0]->dims()[inputs[0]->dims().size() - 1]; + + LOG(INFO) << "concat 输入尺寸: "; + for (size_t i = 0; i < inputs.size(); i++) { + LOG(INFO) << "inputs [" << i << "]" + << "[" << inputs[i]->dims().size() << "D]:" + << " dims:" << inputs[i]->dims()[0] << " " + << inputs[i]->dims()[1] << " " << inputs[i]->dims()[2] << " " + << inputs[i]->dims()[3]; + } + LOG(INFO) << "concat 输出尺寸: "; + LOG(INFO) << " out dims: " + << "[" << x_dims.size() << "D]:" << x_dims[0] << " " << x_dims[1] + << " " << x_dims[2] << " " << x_dims[3]; + LOG(INFO) << "axis_: " << axis_; + LOG(INFO) << "flag_: " << flag_; + auto global_work_size = + cl::NDRange{static_cast(x_dims[x_dims.size() - 1]), + static_cast(image_shape["width"] / + x_dims[x_dims.size() - 1]), + static_cast(image_shape["height"])}; VLOG(4) << TargetToStr(param.output->target()); VLOG(4) << "image_shape(w,h):" << image_shape["width"] << " " << image_shape["height"]; VLOG(4) << "x_dims[" << x_dims.size() << "D]:" << x_dims[0] << " " - << x_dims[1] << " " << x_dims[2] << " " << x_dims[3]; + << x_dims[1] << " " << x_dims[2] << " " << x_dims[3] + << "x_dims[x_dims.size() - 1]" << x_dims[x_dims.size() - 1]; VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " " << y_dims[1] << " " << y_dims[2] << " " << y_dims[3]; - VLOG(4) << "width_: " << width_ << ", flag_: " << flag_; + LOG(INFO) << "width_: " << width_ << ", flag_: " << flag_; + VLOG(4) << "global_work_size: " << x_dims[x_dims.size() - 1] << " " + << (image_shape["width"] / x_dims[x_dims.size() - 1]) << " " + << (image_shape["height"]); auto kernel = context.cl_context()->GetKernel(kernel_key.str()); - int out_w = x_dims[-1]; + int out_w = x_dims[x_dims.size() - 1]; int out_c = x_dims[1]; if (inputs.size() == 2) { auto* x_buf0 = inputs[0]->data(); @@ -159,13 +179,14 @@ class ConcatComputeImage : public KernelLitedims(); image_shape = InitImageDimInfoWith(in_dims); auto* x_buf = inputs[i]->data(); - auto in_w = in_dims[-1]; + int in_w = in_dims[in_dims.size() - 1]; VLOG(4) << "image_shape(w,h):" << image_shape["width"] << " " << image_shape["height"]; - global_work_size = cl::NDRange{ - static_cast(in_dims[-1]), - static_cast(image_shape["width"] / in_dims[-1]), - static_cast(image_shape["height"])}; + global_work_size = + cl::NDRange{static_cast(in_dims[in_dims.size() - 1]), + static_cast(image_shape["width"] / + in_dims[in_dims.size() - 1]), + static_cast(image_shape["height"])}; cl_int status = kernel.setArg(arg_idx, *x_buf); CL_CHECK_FATAL(status); status = kernel.setArg(++arg_idx, *out_buf); @@ -205,7 +226,7 @@ class ConcatComputeImage : public KernelLite event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/nearest_interp_image_compute.cc b/lite/kernels/opencl/nearest_interp_image_compute.cc index ab7de85ae7..c22e38a8c2 100644 --- a/lite/kernels/opencl/nearest_interp_image_compute.cc +++ b/lite/kernels/opencl/nearest_interp_image_compute.cc @@ -45,15 +45,16 @@ class NearestInterpComputeImageDefault void Run() override { auto& param = *param_.get_mutable(); const auto& x_dims = param.X->dims(); + const auto& y_dims = param.Out->dims(); auto* x_buf = param.X->data(); // use half_t represents half float - auto image_shape = InitImageDimInfoWith(x_dims); + auto out_image_shape = InitImageDimInfoWith(y_dims); auto* out_buf = param.Out->mutable_data( // use half_t // represents half float - image_shape["width"], - image_shape["height"]); - const auto& y_dims = param.Out->dims(); // useless: check dim only + out_image_shape["width"], + out_image_shape["height"]); + float scale_h = y_dims[2] / x_dims[2]; float scale_w = y_dims[3] / x_dims[3]; int in_dims_h = x_dims[2]; @@ -87,16 +88,22 @@ class NearestInterpComputeImageDefault VLOG(4) << TargetToStr(param.X->target()); VLOG(4) << TargetToStr(param.Out->target()); - VLOG(4) << "image_shape(w,h):" << image_shape["width"] << " " - << image_shape["height"]; + VLOG(4) << "out_image_shape(w,h):" << out_image_shape["width"] << " " + << out_image_shape["height"]; VLOG(4) << "x_dims[" << x_dims.size() << "D]:" << x_dims[0] << " " << x_dims[1] << " " << x_dims[2] << " " << x_dims[3]; VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " " << y_dims[1] << " " << y_dims[2] << " " << y_dims[3]; + const std::vector& default_work_size = + DefaultWorkSize(y_dims, + DDim(std::vector{ + static_cast(out_image_shape["width"]), + static_cast(out_image_shape["height"])})); auto global_work_size = - cl::NDRange{static_cast(image_shape["width"]), - static_cast(image_shape["height"])}; + cl::NDRange{static_cast(default_work_size.data()[0]), + static_cast(default_work_size.data()[1]), + static_cast(default_work_size.data()[2])}; status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( kernel, cl::NullRange, @@ -112,7 +119,7 @@ class NearestInterpComputeImageDefault private: std::string kernel_func_name_{"nearest_interp"}; - std::string build_options_{"-DCL_DTYPE_half"}; + std::string build_options_{" -DCL_DTYPE_half"}; std::shared_ptr event_{new cl::Event}; }; -- GitLab