From 4c9e8384bd84e9b5f36cc5d78d38be9873d01ace Mon Sep 17 00:00:00 2001 From: xiebaiyuan Date: Mon, 3 Feb 2020 10:33:33 +0800 Subject: [PATCH] =?UTF-8?q?[LITE][OPENCL]develop=20basic=20image=20depthwi?= =?UTF-8?q?seconv,passed=20loop=20test,test=E2=80=A6=20(#2788)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [LITE][OPENCL]develop basic image depthwiseconv,passed loop test,test=develop * [LITE][OPENCL]log to vlog(4),test=develop * [LITE][OPENCL]fix depthwise buffer conv kernel name ,test=develop --- .../image/depthwise_conv2d_basic_kernel.cl | 103 +++++ lite/kernels/opencl/CMakeLists.txt | 4 + .../depthwise_conv2d_basic_compute_test.cc | 389 ++++++++++++++++++ .../opencl/depthwise_conv2d_compute.cc | 210 +++++++++- 4 files changed, 705 insertions(+), 1 deletion(-) create mode 100755 lite/backends/opencl/cl_kernel/image/depthwise_conv2d_basic_kernel.cl create mode 100644 lite/kernels/opencl/depthwise_conv2d_basic_compute_test.cc diff --git a/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_basic_kernel.cl b/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_basic_kernel.cl new file mode 100755 index 0000000000..70e429634f --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_basic_kernel.cl @@ -0,0 +1,103 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +__kernel void depth_conv2d(__private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, + __read_only image2d_t input, + __read_only image2d_t filter, +#if defined(BIASE_CH) || defined(BIASE_ELE) + __read_only image2d_t bias, +#endif +#ifdef BATCH_NORM + __read_only image2d_t new_scale, + __read_only image2d_t new_biase, +#endif + __write_only image2d_t output_image, + __private const int stride, + __private const int offset, + __private const int input_c, + __private const int dilation, + __private const int input_width, /* of one block */ + __private const int input_height, /* of one block */ + __private const int output_width, + __private const int output_height, + __private const int filter_width, + __private const int filter_height) { + + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + + int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh); + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + const int batch_index = out_nh / output_height; + const int out_nh_in_one_batch = out_nh % output_height; + int2 stride_xy = (int2)(stride, stride); + int2 ouput_pos_in_one_block = (int2)(out_w, out_nh_in_one_batch); + int2 in_pos_in_one_block = + ouput_pos_in_one_block * stride_xy + (int2)(offset, offset); +#ifdef BIASE_CH + CL_DTYPE4 output = + READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(out_c, 0)); +#elif defined(BIASE_ELE) + CL_DTYPE4 output = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, output_pos); +#else + CL_DTYPE4 output = 0.0f; +#endif + + int2 pos_in_input_block = + (int2)(out_c * input_width, batch_index * input_height); + int2 pos_in_filter_block = + (int2)(out_c * filter_width, batch_index * filter_height); + int filter_x = pos_in_filter_block.x; + int filter_y = pos_in_filter_block.y; + int input_x_base = pos_in_input_block.x + in_pos_in_one_block.x; + int input_y_base = pos_in_input_block.y + in_pos_in_one_block.y; + int2 align = {filter_width / 2, filter_height / 2}; + for (int fy = 0; fy < filter_height; ++fy) { + for (int fx = 0; fx < filter_width; ++fx) { + int x_off = fx - align.x; + int y_off = fy - align.y; + CL_DTYPE4 in = select( + READ_IMG_TYPE(CL_DTYPE_CHAR, + input, + sampler, + (int2)(input_x_base + x_off, input_y_base + y_off)), + (CL_DTYPE4)(0.0f), + (ushort4)((in_pos_in_one_block.x + x_off < 0 || + in_pos_in_one_block.y + y_off < 0 || + in_pos_in_one_block.x + x_off >= input_width || + in_pos_in_one_block.y + y_off >= input_height) + << 15)); + CL_DTYPE4 f = READ_IMG_TYPE( + CL_DTYPE_CHAR, filter, sampler, (int2)(filter_x + fx, filter_y + fy)); + output += in * f; + } + } +#ifdef BATCH_NORM + output = output * READ_IMG_TYPE( + CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + + READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0)); +#endif + +#ifdef RELU + output = activation_type4(output); + +#endif + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos, output); +} \ No newline at end of file diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt index f4d3254a7b..8bb49e428a 100644 --- a/lite/kernels/opencl/CMakeLists.txt +++ b/lite/kernels/opencl/CMakeLists.txt @@ -49,6 +49,10 @@ lite_cc_test(test_depthwise_conv2d_opencl SRCS depthwise_conv2d_compute_test.cc DEPS depthwise_conv2d_opencl op_registry program context cl_image_converter ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) +lite_cc_test(test_depthwise_conv2d_basic_opencl SRCS depthwise_conv2d_basic_compute_test.cc + DEPS depthwise_conv2d_opencl op_registry program context cl_image_converter + ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) + #lite_cc_test(test_conv2d_1x1_opencl SRCS conv2d_1x1_compute_test.cc # DEPS conv2d_1x1_opencl cl_image_converter op_registry program context # ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) diff --git a/lite/kernels/opencl/depthwise_conv2d_basic_compute_test.cc b/lite/kernels/opencl/depthwise_conv2d_basic_compute_test.cc new file mode 100644 index 0000000000..96ee99e538 --- /dev/null +++ b/lite/kernels/opencl/depthwise_conv2d_basic_compute_test.cc @@ -0,0 +1,389 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include + +#include "lite/backends/opencl/cl_image_converter.h" +#include "lite/backends/opencl/target_wrapper.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { + +template +void depth_conv(const T* input_data, + const lite::DDim& input_dims, + const T* filter_data, + const lite::DDim& filter_dims, + T* output_data, + const lite::DDim& output_dims) { + int stride_h = STRIDE_H, stride_w = STRIDE_W; + + int64_t batches = input_dims[0]; + int64_t channels = input_dims[1]; + int64_t h = input_dims[2]; + int64_t w = input_dims[3]; + + int64_t num_output = output_dims[1]; + int64_t outh = output_dims[2]; + int64_t outw = output_dims[3]; + + int64_t filter_h = filter_dims[2]; + int64_t filter_w = filter_dims[3]; + + const int64_t in_batch_size = channels * h * w; + const int64_t out_batch_size = num_output * outh * outw; + + auto kernel_offset = std::unique_ptr(new int[filter_h * filter_w]); + { + int p = 0; + int offset = 0; + int gap = w - filter_w; + for (int i = 0; i < filter_h; i++) { + for (int j = 0; j < filter_w; j++) { + kernel_offset[p++] = offset; + offset += 1; + } + offset += gap; + } + } + + for (int b = 0; b < batches; b++) { + auto* input_batch_start = input_data + b * in_batch_size; + auto* output_batch_start = output_data + b * out_batch_size; + for (int p = 0; p < num_output; p++) { + float* output_ptr = output_batch_start + p * outh * outw; + const float* filter_ptr = filter_data + p * filter_h * filter_w; + const float* input_ptr = input_batch_start + p * h * w; + + for (int i = 0; i < outh; i++) { + for (int j = 0; j < outw; j++) { + float sum = 0; + const float* input_ch_start = + input_ptr + i * stride_h * w + j * stride_w; + + for (int fh = 0; fh < filter_h; ++fh) { + for (int fw = 0; fw < filter_w; ++fw) { + float val = input_ch_start[kernel_offset[fh * filter_w + fw]]; + float w = filter_ptr[fh * filter_w + fw]; + sum += val * w; + } + } + output_ptr[j] = sum; + } + + output_ptr += outw; + } + } + } +} +int ConvOutputSize(int input_size, + int filter_size, + int dilation, + int pad_left, + int pad_right, + int stride) { + const int dkernel = dilation * (filter_size - 1) + 1; + int output_size = + (input_size + (pad_left + pad_right) - dkernel) / stride + 1; + + return output_size; +} + +TEST(depthwise_conv2d_basic, compute) { + // conv infos + // const int ksize = 1; + const int stride = 1; + const int pad = 0; + const int group = 1; + const int dilation = 1; + const int fc = 1; + const int batch_size = 1; + const int bias_flag = false; + const bool relu_flag = false; + +// int loop_cnt = 0; + +#ifdef LOOP_TEST + // for (int batch_size = 1; batch_size < 2; ++batch_size) { + for (int oc = 4; oc < 10; oc += 1) { // oc = ic + for (int fw = 3; fw < 10; fw += 2) { // fh = fw + for (int ih = fw; ih < 15; ih += 1) { // ih + for (int iw = fw; iw < 15; iw += 1) { // iw +#else + const int oc = 32; + const int ih = 112; + const int iw = 112; + const int fw = 5; + +#endif + + const int fb = oc; + const int ic = oc; + const int fh = fw; + + const int oh = ConvOutputSize(ih, fh, dilation, pad, pad, stride); + const int ow = ConvOutputSize(iw, fw, dilation, pad, pad, stride); + + VLOG(4) << "to get kernel ..."; + auto kernels = + KernelRegistry::Global().Create("depthwise_conv2d_basic", + TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault)); + ASSERT_FALSE(kernels.empty()); + + auto kernel = std::move(kernels.front()); + VLOG(4) << "created depthconv2d kernel"; + + VLOG(4) << "prepare kernel ------"; + + lite::Tensor input, filter, bias, output; + operators::ConvParam param; + param.x = &input; + param.filter = &filter; + param.output = &output; + if (bias_flag) { + param.bias = &bias; + } + param.fuse_relu = relu_flag; + + std::vector paddings = {pad, pad, pad, pad}; + std::vector dilations = {dilation, dilation}; + + param.paddings = std::make_shared>(paddings); + param.dilations = std::make_shared>(dilations); + param.strides = std::vector{stride, stride}; + + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + + std::unique_ptr depth_conv_context(new KernelContext); + context->As().CopySharedTo( + &(depth_conv_context->As())); + kernel->SetContext(std::move(depth_conv_context)); + + const DDim& input_dim = + lite::DDim{std::vector({batch_size, ic, ih, iw})}; + + const DDim& filter_dim = + lite::DDim{std::vector({fb, fc, fh, fw})}; + const DDim& out_dim = + lite::DDim{std::vector({batch_size, oc, oh, ow})}; + // element wise bias + const DDim& bias_dim = lite::DDim{std::vector({oc})}; + + param.x->Resize(input_dim); + param.filter->Resize(filter_dim); + param.output->Resize(out_dim); + if (bias_flag) { + param.bias->Resize(bias_dim); + } + + kernel->SetParam(param); + + size_t input_image_width = iw * ((ic + 3) / 4); + size_t input_image_height = ih * batch_size; + + size_t out_image_width = ow * ((oc + 3) / 4); + size_t out_image_height = oh * batch_size; + + size_t bias_image_width = ow * ((oc + 3) / 4); + size_t bias_image_height = oh * batch_size; + + size_t filter_image_width = fw * ((fb + 3) / 4); + size_t filter_image_height = fc * fh; + + const size_t cl_image2d_row_pitch{0}; + const size_t cl_image2d_slice_pitch{0}; + + std::default_random_engine engine; + std::uniform_real_distribution gen(-5, 5); + + std::vector input_v(batch_size * ic * ih * iw); + std::vector filter_v(fb * fc * fh * fw); + std::vector output_v(batch_size * oc * ih * iw); + std::vector bias_v(oc); + + VLOG(4) << "gen input and filter ..."; + + for (auto& i : input_v) { + i = gen(engine); + } + for (auto& f : filter_v) { + f = gen(engine); + } + + VLOG(4) << "after gen input and filter ..."; + VLOG(4) << "input_v.size(): " << input_v.size(); + VLOG(4) << "filter_v.size(): " << filter_v.size(); + VLOG(4) << "output_v.size(): " << output_v.size(); + VLOG(4) << "bias_v.size(): " << bias_v.size(); + VLOG(4) << "input_dim.production(): " << input_dim.production(); + VLOG(4) << "filter_dim.production(): " << filter_dim.production(); + VLOG(4) << "out_dim.production(): " << out_dim.production(); + VLOG(4) << "bias_dim.production(): " << bias_dim.production(); + VLOG(4) << "4 * input_image_height * input_image_width: " + << 4 * input_image_height * input_image_width; + VLOG(4) << "4 * filter_image_width * filter_image_height: " + << 4 * filter_image_width * filter_image_height; + + CHECK(input_dim.production() == input_v.size()); + CHECK_LE(input_dim.production(), + 4 * input_image_height * input_image_width); + CHECK(filter_dim.production() == filter_v.size()); + CHECK_LE(filter_dim.production(), + 4 * filter_image_width * filter_image_height); + + paddle::lite::CLImageConverterDefault default_convertor; + VLOG(4) << "set mapped input ..."; + std::vector x_image_v(input_image_width * input_image_height * + 4); // 4 : RGBA + std::vector filter_image_v( + filter_image_width * filter_image_height * 4); // 4 : RGBA + std::vector bias_image_v(bias_image_width * bias_image_height * + 4); // 4 : RGBA + std::vector out_image_v(out_image_width * out_image_height * + 4); // 4 : RGBA + + default_convertor.NCHWToImage( + input_v.data(), x_image_v.data(), input_dim); + + VLOG(4) << "set mapped filter ..."; + paddle::lite::CLImageConverterNWBlock nw_convertor; + nw_convertor.NCHWToImage( + filter_v.data(), filter_image_v.data(), filter_dim); + + auto* input_image2d = input.mutable_data( + input_image_width, input_image_height, x_image_v.data()); + auto* filter_image2d = filter.mutable_data( + filter_image_width, filter_image_height, filter_image_v.data()); + + if (bias_flag) { + nw_convertor.NCHWToImage( + filter_v.data(), filter_image_v.data(), filter_dim); + + for (int i = 0; i < bias_dim.production(); ++i) { + bias_v[i] = static_cast(gen(engine)); + } + CLImageConverterFolder folder_convertor; + folder_convertor.NCHWToImage( + bias_v.data(), bias_image_v.data(), bias_dim); + auto* bias_data = bias.mutable_data( + bias_image_width, bias_image_height, bias_image_v.data()); + } + + VLOG(4) << "resize output ..."; + output.Resize(out_dim); + + // cpu conv basic calc + lite::Tensor out_ref; + out_ref.Resize(out_dim); + + VLOG(4) << "prepare kernel ready"; + + VLOG(4) << "kernel launch ..."; + kernel->Launch(); + VLOG(4) << "mutable output ..."; + auto* output_image2d = output.mutable_data( + out_image_width, out_image_height); + + auto* wait_list = context->As().cl_wait_list(); + auto* out_ptr = param.output->data(); + auto it = wait_list->find(out_ptr); + + if (it != wait_list->end()) { + VLOG(4) << "--- Find the sync event for the target cl " + "tensor. ---"; + auto& event = *(it->second); + event.wait(); + } else { + LOG(FATAL) << "Could not find the sync event for the target " + "cl tensor."; + } + + TargetWrapperCL::ImgcpySync(out_image_v.data(), + output.data(), + out_image_width, + out_image_height, + cl_image2d_row_pitch, + cl_image2d_slice_pitch, + IoDirection::DtoH); + DDim out_image_shape = + default_convertor.InitImageDimInfoWith(output.dims()); + + default_convertor.ImageToNCHW(out_image_v.data(), + output_v.data(), + out_image_shape, + output.dims()); + + // for (int j = 0; j < input_v.size(); j += 1) { + // VLOG(4) << "input_v input[" << j + // << "]: " << input_v.data()[j]; + // std::cout<< j << " " << input_v.data()[j] << std::endl; + // } + // std::cout << std::endl; + + // for (int j = 0; j < output_v.size(); j += 1) { + // VLOG(4) << "output_v output_v[" << j + // << "]:" << output_v.data()[j]; + // std::cout << j << " " << output_v.data()[j] << + // std::endl; + // } + + VLOG(4) << "mutable_data out_ref_data: "; + + // run cpu ref + auto* out_ref_data = out_ref.mutable_data(TARGET(kARM)); + + VLOG(4) << " conv_basic beigin ..... "; + depth_conv(input_v.data(), + input.dims(), + filter_v.data(), + filter.dims(), + out_ref_data, + out_dim); + VLOG(4) << " conv_basic end ..... "; + + VLOG(4) << " input_dim: " << input_dim; + VLOG(4) << " filter_dim: " << filter_dim; + const DDim& out_image_dims = lite::DDim{ + std::vector({static_cast(out_image_width), + static_cast(out_image_height)})}; + + for (int i = 0; i < out_dim.production(); i++) { + EXPECT_NEAR(output_v[i], out_ref_data[i], 1e-2); + if (abs(output_v[i] - out_ref_data[i]) > 1e-2) { + LOG(FATAL) << "error idx:" << i; + } + } + +#ifdef LOOP_TEST + } + } + } + } +#else +// nothing to do. +#endif +} + +} // namespace lite +} // namespace paddle +USE_LITE_KERNEL( + depthwise_conv2d_basic, kOpenCL, kFloat, kImageDefault, image2d); diff --git a/lite/kernels/opencl/depthwise_conv2d_compute.cc b/lite/kernels/opencl/depthwise_conv2d_compute.cc index 5d573c14f7..b90796f384 100644 --- a/lite/kernels/opencl/depthwise_conv2d_compute.cc +++ b/lite/kernels/opencl/depthwise_conv2d_compute.cc @@ -13,6 +13,7 @@ // limitations under the License. #include + #include "lite/backends/opencl/cl_include.h" #include "lite/core/kernel.h" #include "lite/core/op_registry.h" @@ -114,7 +115,7 @@ class DepthwiseConv2dCompute } private: - std::string kernel_func_name_{"depthwise_conv2d_3x3"}; + std::string kernel_func_name_{"depthwise_conv2d"}; std::string build_options_{"-DCL_DTYPE=float"}; std::shared_ptr event_{new cl::Event}; }; @@ -341,6 +342,189 @@ class DepthwiseConv2d3x3s1ComputeFP16Image std::shared_ptr event_{new cl::Event}; }; +class DepthwiseConv2dBasicComputeFP32Image + : public KernelLite { + public: + using param_t = operators::ConvParam; + + std::string doc() const override { + return "DepthwiseConv2d basic using cl::Image2D/kImageDefault, kFloat32"; + } + + void PrepareForRun() override { + const auto& param = *param_.get_mutable(); + const bool has_bias = param.bias != nullptr; + const bool is_element_wise_bias = + has_bias && param.output->dims() == param.bias->dims(); + if (param.fuse_relu) { + build_options_ += " -DRELU"; + } + if (has_bias) { + build_options_ += is_element_wise_bias ? " -DBIASE_ELE" : " -DBIASE_CH"; + } + auto& context = ctx_->As(); + context.cl_context()->AddKernel(kernel_func_name_, + "image/depthwise_conv2d_basic_kernel.cl", + build_options_); + } + + void Run() override { + const auto& param = *param_.get_mutable(); + auto input_dims = param.x->dims(); + auto paddings = *param.paddings; + auto strides = param.strides; + auto* input_image = param.x->data(); + auto* filter_image = param.filter->data(); + auto filter_dims = param.filter->dims(); + auto output_dims = param.output->dims(); + + int input_width = input_dims[3]; + int input_height = input_dims[2]; + int output_width = output_dims[3]; + int output_height = output_dims[2]; + int filter_width = filter_dims[3]; + int filter_height = filter_dims[2]; + auto out_image_shape = InitImageDimInfoWith(output_dims); + auto* out_image = param.output->mutable_data( + out_image_shape["width"], out_image_shape["height"]); + + const bool has_bias = param.bias != nullptr; + const bool is_element_wise_bias = + has_bias && param.output->dims() == param.bias->dims(); + int offset = static_cast(param.filter->dims()[2]) / 2 - + static_cast(paddings[0]); + + // calc input_c_block + auto input_image_shape = InitImageDimInfoWith(input_dims); + int input_c_block = input_image_shape["width"] / input_dims[3]; + int input_c = input_dims[1]; + auto dilations = *param.dilations; + + const std::vector& default_work_size = + DefaultWorkSize(output_dims, + DDim(std::vector{ + static_cast(out_image_shape["width"]), + static_cast(out_image_shape["height"])})); + + int c_block = default_work_size[0]; + int w = default_work_size[1]; + int nh = default_work_size[2]; + + VLOG(4) << "============ depthwise conv2d params ============"; + VLOG(4) << "input_image_shape: " << input_image_shape["width"] << "," + << input_image_shape["height"]; + VLOG(4) << "input_c_block: " << input_c_block; + VLOG(4) << "input_c: " << input_c; + VLOG(4) << "input_image: " << input_image; + VLOG(4) << "filter_dims: " << filter_dims; + VLOG(4) << "filter_image: " << filter_image; + VLOG(4) << "output_dims: " << output_dims; + VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", " + << out_image_shape["height"]; + VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1]; + VLOG(4) << "has bias: " << has_bias; + VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias; + VLOG(4) << "strides: " << strides[0] << "," << strides[1]; + VLOG(4) << "offset: " << offset; + VLOG(4) << "dilations.size : " << dilations.size(); + VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1]; + VLOG(4) << "default work size{c_block, w, nh}: " + << "{" << c_block << ", " << w << ", " << nh << "" + << "}"; + + CHECK_GE(dilations.size(), 2); + CHECK(dilations[0] == dilations[1]); + CHECK_GE(input_dims.size(), 4); + CHECK_GE(paddings.size(), 2); + CHECK(paddings[0] == paddings[1]); + CHECK_GE(strides.size(), 2); + CHECK(strides[0] == strides[1]); + + // handle bias use buffer for channel wise , use image for element wise + const cl::Buffer* bias_buf = nullptr; + const cl::Image2D* bias_image = nullptr; + if (has_bias) { + bias_image = param.bias->data(); + } + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + VLOG(4) << "kernel_key: " << kernel_key.str(); + VLOG(4) << "kernel ready ... " << kernel_key.str(); + VLOG(4) << "w: " << w; + + cl_int status; + int arg_idx = 0; + status = kernel.setArg(arg_idx, c_block); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, w); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, nh); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *input_image); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *filter_image); + CL_CHECK_FATAL(status); + if (has_bias) { + VLOG(4) << "set bias_image: "; + status = kernel.setArg(++arg_idx, *bias_image); + CL_CHECK_FATAL(status); + } + status = kernel.setArg(++arg_idx, *out_image); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, strides[0]); + CL_CHECK_FATAL(status); + + status = kernel.setArg(++arg_idx, offset); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, input_c_block); + CL_CHECK_FATAL(status); + + status = kernel.setArg(++arg_idx, dilations[0]); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, input_width); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, input_height); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, output_width); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, output_height); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, filter_width); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, filter_height); + CL_CHECK_FATAL(status); + + auto global_work_size = + cl::NDRange{static_cast(default_work_size.data()[0]), + static_cast(default_work_size.data()[1]), + static_cast(default_work_size.data()[2])}; + + VLOG(4) << "out_image: " << out_image; + VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << "," + << global_work_size[1] << "," << global_work_size[2] << "}"; + + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + context.cl_wait_list()->emplace(out_image, event_); + } + + private: + std::string kernel_func_name_{"depth_conv2d"}; + std::string build_options_{"-DCL_DTYPE_float"}; + std::shared_ptr event_{new cl::Event}; +}; } // namespace opencl } // namespace kernels } // namespace lite @@ -382,3 +566,27 @@ REGISTER_LITE_KERNEL( PRECISION(kFP16), DATALAYOUT(kImageDefault))}) .Finalize(); +REGISTER_LITE_KERNEL( + depthwise_conv2d_basic, + kOpenCL, + kFloat, + kImageDefault, + paddle::lite::kernels::opencl::DepthwiseConv2dBasicComputeFP32Image, + image2d) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault))}) + .BindInput("Bias", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault))}) + .BindInput("Filter", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageNW))}) + .BindOutput("Output", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault))}) + .Finalize(); -- GitLab