diff --git a/lite/api/mobilenetv1_test.cc b/lite/api/mobilenetv1_test.cc index 79f9bea762e099b249f597dddb7df790361edc2a..085f7f3ad7101a59b8035ac3a8ad8a1e602fb102 100644 --- a/lite/api/mobilenetv1_test.cc +++ b/lite/api/mobilenetv1_test.cc @@ -123,10 +123,10 @@ TEST(MobileNetV1, test_arm) { #ifdef LITE_WITH_OPENCL TEST(MobileNetV1, test_opencl) { std::vector valid_places({ - Place{TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kNCHW)}, - Place{TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kNHWC)}, + Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kImageDefault)}, Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW)}, - Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNHWC)}, + Place{TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kImageDefault)}, + Place{TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kNCHW)}, TARGET(kARM), // enable kARM CPU kernel when no opencl kernel }); diff --git a/lite/api/opt.cc b/lite/api/opt.cc index c172169e59ec074b81a07e4fc96cd0363c50a10a..2435a29878f5eafa4abffe045ce55a1f2606acd3 100644 --- a/lite/api/opt.cc +++ b/lite/api/opt.cc @@ -89,13 +89,13 @@ std::vector ParserValidPlaces() { valid_places.emplace_back(TARGET(kARM)); } else if (target_repr == "opencl") { valid_places.emplace_back( - Place{TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kNCHW)}); - valid_places.emplace_back( - Place{TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kNHWC)}); + Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kImageDefault)}); valid_places.emplace_back( Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW)}); valid_places.emplace_back( - Place{TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNHWC)}); + Place{TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kImageDefault)}); + valid_places.emplace_back( + Place{TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kNCHW)}); valid_places.emplace_back( TARGET(kARM)); // enable kARM CPU kernel when no opencl kernel } else if (target_repr == "x86") { diff --git a/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_kernel.cl b/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_kernel.cl index 1e3586b7fde8d79fe49327185c623ac613cd080d..14086dcd16bd1a8770f444bdcd0b6bea78e23b7e 100755 --- a/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_kernel.cl @@ -142,7 +142,7 @@ __kernel void depth_conv2d_3x3(__private const int global_size_dim0, #endif #ifdef RELU - output = activation(output); + output = activation_type4(output); #endif @@ -309,8 +309,8 @@ __kernel void depth_conv2d_3x3s1(__private const int ou_ch_blk, #endif #ifdef RELU - output[0] = activation(output[0]); - output[1] = activation(output[1]); + output[0] = activation_type4(output[0]); + output[1] = activation_type4(output[1]); #endif WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, (int2)(ou_x, ou_nh_id), output[0]); diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt index e81fdf307e94fbb6593962052b911c34a944777a..f87b37fc62343b00aedd92fc7c30de3ea42c3c9d 100644 --- a/lite/kernels/opencl/CMakeLists.txt +++ b/lite/kernels/opencl/CMakeLists.txt @@ -16,7 +16,6 @@ add_kernel(io_copy_compute_opencl OPENCL basic SRCS io_copy_compute.cc DEPS ${te add_kernel(relu_opencl OPENCL basic SRCS relu_compute.cc DEPS ${cl_kernel_deps}) add_kernel(sigmoid_opencl OPENCL basic SRCS sigmoid_compute.cc DEPS ${cl_kernel_deps}) add_kernel(depthwise_conv2d_opencl OPENCL basic SRCS depthwise_conv2d_compute.cc DEPS ${cl_kernel_deps}) -#add_kernel(conv2d_1x1_opencl OPENCL basic SRCS conv2d_1x1_compute.cc DEPS ${cl_kernel_deps}) add_kernel(reshape_opencl OPENCL basic SRCS reshape_compute.cc DEPS ${cl_kernel_deps}) add_kernel(conv_opencl OPENCL basic SRCS conv_compute.cc DEPS ${cl_kernel_deps} cl_image_converter) add_kernel(layout_opencl OPENCL basic SRCS layout_compute.cc DEPS ${cl_kernel_deps}) @@ -62,14 +61,10 @@ lite_cc_test(test_depthwise_conv2d_opencl SRCS depthwise_conv2d_compute_test.cc DEPS depthwise_conv2d_opencl op_registry program context ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) -lite_cc_test(test_depthwise_conv2d_basic_opencl SRCS depthwise_conv2d_basic_compute_test.cc - DEPS depthwise_conv2d_opencl op_registry program context +lite_cc_test(test_depthwise_conv2d_image2d_opencl SRCS depthwise_conv2d_image2d_compute_test.cc + DEPS conv_opencl op_registry program context ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) -#lite_cc_test(test_conv2d_1x1_opencl SRCS conv2d_1x1_compute_test.cc -# DEPS conv2d_1x1_opencl op_registry program context -# ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) - lite_cc_test(test_reshape_opencl SRCS reshape_compute_test.cc DEPS reshape_opencl op_registry program context ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) diff --git a/lite/kernels/opencl/concat_compute.cc b/lite/kernels/opencl/concat_compute.cc index 0f25439ed00a9ff579bbd59a543dba3c8c3b090b..c57602e39aea27250eabfcf7a0570d80d7ff3dc4 100644 --- a/lite/kernels/opencl/concat_compute.cc +++ b/lite/kernels/opencl/concat_compute.cc @@ -356,17 +356,17 @@ REGISTER_LITE_KERNEL( DATALAYOUT(kImageDefault))}) .Finalize(); -REGISTER_LITE_KERNEL(concat, kOpenCL, kFloat, kNCHW, Concat_buffer, def) - .BindInput("X", - {LiteType::GetTensorTy(TARGET(kOpenCL), - PRECISION(kFloat), - DATALAYOUT(kNCHW))}) - .BindInput("AxisTensor", - {LiteType::GetTensorTy(TARGET(kOpenCL), - PRECISION(kInt32), - DATALAYOUT(kNCHW))}) - .BindOutput("Out", - {LiteType::GetTensorTy(TARGET(kOpenCL), - PRECISION(kFloat), - DATALAYOUT(kNCHW))}) - .Finalize(); +// REGISTER_LITE_KERNEL(concat, kOpenCL, kFloat, kNCHW, Concat_buffer, def) +// .BindInput("X", +// {LiteType::GetTensorTy(TARGET(kOpenCL), +// PRECISION(kFloat), +// DATALAYOUT(kNCHW))}) +// .BindInput("AxisTensor", +// {LiteType::GetTensorTy(TARGET(kOpenCL), +// PRECISION(kInt32), +// DATALAYOUT(kNCHW))}) +// .BindOutput("Out", +// {LiteType::GetTensorTy(TARGET(kOpenCL), +// PRECISION(kFloat), +// DATALAYOUT(kNCHW))}) +// .Finalize(); diff --git a/lite/kernels/opencl/concat_compute_test.cc b/lite/kernels/opencl/concat_compute_test.cc index 37e7b6658be2eaa60285474b3766ce462ea3779b..9af0666cc9bdef184654a026bbfb6004c2ccdd18 100644 --- a/lite/kernels/opencl/concat_compute_test.cc +++ b/lite/kernels/opencl/concat_compute_test.cc @@ -73,7 +73,7 @@ void concat_mul_compute_ref(std::vector ins_data, } } } -#if 1 // concat_buffer +#if 0 // concat_buffer TEST(opencl_concat_buffer, compute) { // prepare data const DDim x0_dim = DDim(std::vector{1, 2, 3, 4}); @@ -382,7 +382,7 @@ TEST(concat_image2d_fp32, compute) { } // namespace paddle // concat buffer -USE_LITE_KERNEL(concat, kOpenCL, kFloat, kNCHW, def); +// USE_LITE_KERNEL(concat, kOpenCL, kFloat, kNCHW, def); // concat image2d fp32 USE_LITE_KERNEL(layout, kOpenCL, kAny, kImageDefault, NCHW_to_ImageDefault); diff --git a/lite/kernels/opencl/conv2d_1x1_compute.cc b/lite/kernels/opencl/conv2d_1x1_compute.cc deleted file mode 100644 index 975105fd41cd0b0224b760222a3e08a5ea4601aa..0000000000000000000000000000000000000000 --- a/lite/kernels/opencl/conv2d_1x1_compute.cc +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include "lite/backends/opencl/cl_include.h" -#include "lite/core/kernel.h" -#include "lite/core/op_registry.h" -#include "lite/kernels/opencl/image_helper.h" -#include "lite/operators/op_params.h" -#include "lite/utils/replace_stl/stream.h" - -namespace paddle { -namespace lite { -namespace kernels { -namespace opencl { - -#define USE_BUFFER_FOR_CONV1x1_BIAS -class Conv2d1x1Image2DCompute : public KernelLite { - public: - using param_t = operators::ConvParam; - - void PrepareForRun() override { - const auto& param = *param_.get_mutable(); - if (param.fuse_relu) { - build_options_ += " -DRELU"; - } - - const bool has_bias = param.bias != nullptr; - const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); - if (has_bias) { - build_options_ += is_element_wise_bias ? " -DBIASE_ELE" : " -DBIASE_CH"; - } - auto& context = ctx_->As(); - if (param.x->dims()[1] % 4 == 0) { - context.cl_context()->AddKernel(kernel_func_name_simple_, - "image/conv2d_1x1_kernel.cl", - build_options_); - } else { - context.cl_context()->AddKernel( - kernel_func_name_, "image/conv2d_1x1_kernel.cl", build_options_); - } - } - - void Run() override { - const auto& param = *param_.get_mutable(); - auto input_dims = param.x->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - auto* input_image = param.x->data(); - auto* filter_image = param.filter->data(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - - int input_width = input_dims[3]; - int input_height = input_dims[2]; - int output_width = output_dims[3]; - int output_height = output_dims[2]; - auto out_image_shape = InitImageDimInfoWith(output_dims); - auto* out_image = param.output->mutable_data( - out_image_shape["width"], out_image_shape["height"]); - - const bool has_bias = param.bias != nullptr; - const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); - int offset = static_cast(param.filter->dims()[2]) / 2 - - static_cast(paddings[0]); - - // calc input_c_block - auto input_image_shape = InitImageDimInfoWith(input_dims); - int input_c_block = input_image_shape["width"] / input_dims[3]; - int input_c = input_dims[1]; - auto dilations = *param.dilations; - - const std::vector& default_work_size = - DefaultWorkSize(output_dims, - DDim(std::vector{ - static_cast(out_image_shape["width"]), - static_cast(out_image_shape["height"])})); - - int c_block = default_work_size[0]; - int w = default_work_size[1]; - int nh = default_work_size[2]; - - VLOG(4) << "============ conv2d_1x1 params ============"; - VLOG(4) << "input_image_shape: " << input_image_shape["width"] << "," - << input_image_shape["height"]; - VLOG(4) << "input_c_block: " << input_c_block; - VLOG(4) << "input_c: " << input_c; - VLOG(4) << "input_image: " << input_image; - VLOG(4) << "filter_dims: " << filter_dims; - VLOG(4) << "filter_image: " << filter_image; - VLOG(4) << "output_dims: " << output_dims; - VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", " - << out_image_shape["height"]; - VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1]; - VLOG(4) << "has bias: " << has_bias; - VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias; - VLOG(4) << "strides: " << strides[0] << "," << strides[1]; - VLOG(4) << "offset: " << offset; - VLOG(4) << "dilations.size : " << dilations.size(); - VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1]; - VLOG(4) << "default work size{c_block, w, nh}: " - << "{" << c_block << ", " << w << ", " << nh << "" - << "}"; - - CHECK_GE(dilations.size(), 2); - CHECK(dilations[0] == dilations[1]); - CHECK_GE(input_dims.size(), 4); - CHECK_GE(paddings.size(), 2); - CHECK(paddings[0] == paddings[1]); - CHECK_GE(strides.size(), 2); - CHECK(strides[0] == strides[1]); - - // handle bias use buffer for channel wise , use image for element wise - const cl::Buffer* bias_buf = nullptr; - const cl::Image2D* bias_image = nullptr; - if (has_bias) { -#ifndef USE_BUFFER_FOR_CONV1x1_BIAS - is_element_wise_bias - ? (bias_image = param.bias->data()) - : (bias_buf = param.bias->data()); -#else - bias_image = param.bias->data(); -#endif - } - - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - STL::stringstream kernel_key; - if (input_dims[1] % 4 == 0) { - kernel_key << kernel_func_name_simple_ << build_options_; - } else { - kernel_key << kernel_func_name_ << build_options_; - } - auto kernel = context.cl_context()->GetKernel(kernel_key.str()); - int maped_w = maptofactor(w, 4); - - VLOG(4) << "kernel_key: " << kernel_key.str(); - VLOG(4) << "kernel ready ... " << kernel_key.str(); - VLOG(4) << "maped_w: " << maped_w; - - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, c_block); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, maped_w); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, nh); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_image); - CL_CHECK_FATAL(status); - if (has_bias) { -#ifndef USE_BUFFER_FOR_CONV1x1_BIAS - if (is_element_wise_bias != 0) { - VLOG(4) << "set bias_image: "; - status = kernel.setArg(++arg_idx, *bias_image); - } else { - VLOG(4) << "set bias_buf: "; - status = kernel.setArg(++arg_idx, *bias_buf); - } -#else - status = kernel.setArg(++arg_idx, *bias_image); -#endif - CL_CHECK_FATAL(status); - } - status = kernel.setArg(++arg_idx, *out_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, strides[0]); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, offset); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_c_block); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_c); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, dilations[0]); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, w); - CL_CHECK_FATAL(status); - - auto global_work_size = - cl::NDRange{static_cast(default_work_size.data()[0]), - static_cast(maped_w), - static_cast(default_work_size.data()[2])}; - - VLOG(4) << "out_image: " << out_image; - VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << "," - << global_work_size[1] << "," << global_work_size[2] << "}"; - - status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( - kernel, - cl::NullRange, - global_work_size, - cl::NullRange, - nullptr, - event_.get()); - CL_CHECK_FATAL(status); - context.cl_wait_list()->emplace(out_image, event_); - } - - private: - std::string kernel_func_name_{"conv2d_1x1"}; - std::string kernel_func_name_simple_{"conv2d_1x1_simple"}; - std::string build_options_{"-DCL_DTYPE_float"}; - std::shared_ptr event_{new cl::Event}; -}; - -} // namespace opencl -} // namespace kernels -} // namespace lite -} // namespace paddle - -REGISTER_LITE_KERNEL(conv2d_1x1, - kOpenCL, - kFloat, - kImageDefault, - paddle::lite::kernels::opencl::Conv2d1x1Image2DCompute, - image2d) - .BindInput("Input", - {LiteType::GetTensorTy(TARGET(kOpenCL), - PRECISION(kFloat), - DATALAYOUT(kImageDefault))}) - .BindInput("Bias", - {LiteType::GetTensorTy(TARGET(kOpenCL), - PRECISION(kFloat), - DATALAYOUT(kImageDefault))}) - .BindInput("Filter", - {LiteType::GetTensorTy(TARGET(kOpenCL), - PRECISION(kFloat), - DATALAYOUT(kImageNW))}) - .BindOutput("Output", - {LiteType::GetTensorTy(TARGET(kOpenCL), - PRECISION(kFloat), - DATALAYOUT(kImageDefault))}) - .Finalize(); diff --git a/lite/kernels/opencl/conv2d_1x1_compute_test.cc b/lite/kernels/opencl/conv2d_1x1_compute_test.cc deleted file mode 100644 index 6879d2ba38cdb98cd1dd4df8fe2f3b3c90cc22f2..0000000000000000000000000000000000000000 --- a/lite/kernels/opencl/conv2d_1x1_compute_test.cc +++ /dev/null @@ -1,413 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include - -#include "lite/backends/opencl/cl_image_converter.h" -#include "lite/backends/opencl/target_wrapper.h" -#include "lite/core/op_registry.h" -#include "lite/core/tensor.h" -#include "lite/utils/logging.h" - -namespace paddle { -namespace lite { - -template -static void conv_basic(const Dtype1* din, - Dtype2* dout, - int num, - int chout, - int hout, - int wout, - int chin, - int hin, - int win, - const Dtype1* weights, - const Dtype2* bias, - int group, - int kernel_w, - int kernel_h, - int stride_w, - int stride_h, - int dila_w, - int dila_h, - int pad_w, - int pad_h, - bool flag_bias, - bool flag_relu) { - Dtype2 beta = 0; - auto src_data = din; - auto dst_data_ref = dout; - auto weights_data = weights; - auto with_bias = flag_bias; - auto bias_data = bias; - - int in_num = num; - int out_channels = chout; - int out_h = hout; - int out_w = wout; - - int in_channel = chin; - int in_h = hin; - int in_w = win; - int out_c_group = out_channels / group; - int in_c_group = in_channel / group; - - for (int n = 0; n < in_num; ++n) { - for (int g = 0; g < group; ++g) { - for (int oc = 0; oc < out_c_group; ++oc) { - for (int oh = 0; oh < out_h; ++oh) { - for (int ow = 0; ow < out_w; ++ow) { - int out_idx = n * group * out_c_group * out_h * out_w + - g * out_c_group * out_h * out_w + oc * out_h * out_w + - oh * out_w + ow; - Dtype2 bias_d = - with_bias ? (bias_data[g * out_c_group + oc]) : (Dtype2)0; - dst_data_ref[out_idx] = bias_d; // + dst_data_ref[out_idx] * beta; - for (int ic = 0; ic < in_c_group; ++ic) { - for (int kh = 0; kh < kernel_h; ++kh) { - for (int kw = 0; kw < kernel_w; ++kw) { - int iw = ow * stride_w - pad_w + kw * (dila_w); - int ih = oh * stride_h - pad_h + kh * (dila_h); - if (iw < 0 || iw >= in_w) continue; - if (ih < 0 || ih >= in_h) continue; - - int iidx = n * in_channel * in_h * in_w + - g * in_c_group * in_h * in_w + ic * in_h * in_w + - ih * in_w + iw; - int widx = - g * out_c_group * in_c_group * kernel_h * kernel_w + - oc * in_c_group * kernel_h * kernel_w + - ic * kernel_h * kernel_w + kh * kernel_w + kw; - - dst_data_ref[out_idx] += src_data[iidx] * weights_data[widx]; - } - } - } - if (flag_relu) { - dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0 - ? dst_data_ref[out_idx] - : (Dtype2)0; - } - } - } - } - } - } -} -TEST(conv2d_1x1, compute) { - // conv infos - const int ksize = 1; - const int stride = 1; - const int pad = 0; - const int group = 1; - const int dilation = 0; -// int loop_cnt = 0; - -#ifdef LOOP_TEST - for (int batch_size = 1; batch_size < 4; ++batch_size) { - for (int oc = 4; oc < 10; oc += 1) { // oc - for (int ih = 4; ih < 9; ih += 1) { // ih - /*int iw = ih;*/ for (int iw = 4; iw < 10; iw += 1) { // iw - for (int ic = 4; ic < 10; ic += 1) { // ic - for (bool bias_flag : {true, false}) { - for (bool relu_flag : {true, false}) { -#else - const int batch_size = 1; - const int oc = 4; - const int ih = 8; - const int iw = 8; - const int ic = 4; - const bool bias_flag = false; - const bool relu_flag = false; -#endif - const int oh = ih; - const int ow = iw; - - VLOG(4) << "to get kernel ..."; - auto kernels = - KernelRegistry::Global().Create("conv2d_1x1", - TARGET(kOpenCL), - PRECISION(kFloat), - DATALAYOUT(kImageDefault)); - ASSERT_FALSE(kernels.empty()); - - auto kernel = std::move(kernels.front()); - VLOG(4) << "created conv2d_1x1 kernel"; - - VLOG(4) << "prepare kernel ------"; - - lite::Tensor input, filter, bias, output; - operators::ConvParam param; - param.x = &input; - param.filter = &filter; - param.output = &output; - if (bias_flag) { - param.bias = &bias; - } - param.fuse_relu = relu_flag; - - std::vector paddings = {pad, pad, pad, pad}; - std::vector dilations = {dilation, dilation}; - - param.paddings = std::make_shared>(paddings); - param.dilations = std::make_shared>(dilations); - param.strides = std::vector{stride, stride}; - - std::unique_ptr context(new KernelContext); - context->As().InitOnce(); - - std::unique_ptr conv_1x1_context( - new KernelContext); - context->As().CopySharedTo( - &(conv_1x1_context->As())); - kernel->SetContext(std::move(conv_1x1_context)); - - const DDim& input_dim = - lite::DDim{std::vector({batch_size, ic, ih, iw})}; - - const DDim& filter_dim = - lite::DDim{std::vector({oc, ic, ksize, ksize})}; - const DDim& out_dim = - lite::DDim{std::vector({batch_size, oc, ih, iw})}; - // element wise bias - const DDim& bias_dim = lite::DDim{std::vector({oc})}; - - param.x->Resize(input_dim); - param.filter->Resize(filter_dim); - param.output->Resize(out_dim); - if (bias_flag) { - param.bias->Resize(bias_dim); - } - - kernel->SetParam(param); - - size_t input_image_width = iw * ((ic + 3) / 4); - size_t input_image_height = ih * batch_size; - - size_t out_image_width = ow * ((oc + 3) / 4); - size_t out_image_height = oh * batch_size; - - size_t bias_image_width = ow * ((oc + 3) / 4); - size_t bias_image_height = oh * batch_size; - - size_t filter_image_width = ksize * ((oc + 3) / 4); - size_t filter_image_height = ic * ksize; - - const size_t cl_image2d_row_pitch{0}; - const size_t cl_image2d_slice_pitch{0}; - - std::default_random_engine engine; - std::uniform_real_distribution gen(-5, 5); - - std::vector input_v(batch_size * ic * ih * iw); - std::vector filter_v(oc * ic * ksize * ksize); - std::vector output_v(batch_size * oc * ih * iw); - std::vector bias_v(oc); - - VLOG(4) << "gen input and filter ..."; - - for (auto& i : input_v) { - i = gen(engine); - } - for (auto& f : filter_v) { - f = gen(engine); - } - - VLOG(4) << "after gen input and filter ..."; - VLOG(4) << "input_v.size(): " << input_v.size(); - VLOG(4) << "filter_v.size(): " << filter_v.size(); - VLOG(4) << "output_v.size(): " << output_v.size(); - VLOG(4) << "bias_v.size(): " << bias_v.size(); - VLOG(4) << "input_dim.production(): " << input_dim.production(); - VLOG(4) << "filter_dim.production(): " - << filter_dim.production(); - VLOG(4) << "out_dim.production(): " << out_dim.production(); - VLOG(4) << "bias_dim.production(): " << bias_dim.production(); - VLOG(4) << "4 * input_image_height * input_image_width: " - << 4 * input_image_height * input_image_width; - VLOG(4) << "4 * filter_image_width * filter_image_height: " - << 4 * filter_image_width * filter_image_height; - - CHECK(input_dim.production() == input_v.size()); - CHECK_LE(input_dim.production(), - 4 * input_image_height * input_image_width); - CHECK(filter_dim.production() == filter_v.size()); - CHECK_LE(filter_dim.production(), - 4 * filter_image_width * filter_image_height); - - paddle::lite::CLImageConverterDefault default_convertor; - VLOG(4) << "set mapped input ..."; - std::vector x_image_v( - input_image_width * input_image_height * 4); // 4 : RGBA - std::vector filter_image_v( - filter_image_width * filter_image_height * 4); // 4 : RGBA - std::vector bias_image_v( - bias_image_width * bias_image_height * 4); // 4 : RGBA - std::vector out_image_v( - out_image_width * out_image_height * 4); // 4 : RGBA - - default_convertor.NCHWToImage( - input_v.data(), x_image_v.data(), input_dim); - - /* for (int j = 0; j < input_v.size(); j += 1) { - // VLOG(4) << "input_v - input[" << j << "]: - // " << input_v.data()[j]; - std::cout << j << " " << input_v.data()[j] << - std::endl; - } - std::cout << std::endl; - - for (int j = 0; j < x_image_v.size(); j += 1) { - // VLOG(4) << "x_image_v - input[" << j << - // "]: " << - x_image_v.data()[j]; - std::cout << j << " " << x_image_v.data()[j] - << std::endl; - }*/ - - VLOG(4) << "set mapped filter ..."; - paddle::lite::CLImageConverterNWBlock nw_convertor; - nw_convertor.NCHWToImage( - filter_v.data(), filter_image_v.data(), filter_dim); - - auto* input_image2d = input.mutable_data( - input_image_width, input_image_height, x_image_v.data()); - auto* filter_image2d = filter.mutable_data( - filter_image_width, - filter_image_height, - filter_image_v.data()); - - if (bias_flag) { - nw_convertor.NCHWToImage( - filter_v.data(), filter_image_v.data(), filter_dim); - - for (int i = 0; i < bias_dim.production(); ++i) { - bias_v[i] = static_cast(gen(engine)); - } - CLImageConverterFolder folder_convertor; - folder_convertor.NCHWToImage( - bias_v.data(), bias_image_v.data(), bias_dim); - auto* bias_data = bias.mutable_data( - bias_image_width, bias_image_height, bias_image_v.data()); - } - - VLOG(4) << "resize output ..."; - output.Resize(out_dim); - - // cpu conv basic calc - lite::Tensor out_ref; - out_ref.Resize(out_dim); - - VLOG(4) << "prepare kernel ready"; - - VLOG(4) << "kernel launch ..."; - kernel->Launch(); - VLOG(4) << "mutable output ..."; - auto* output_image2d = output.mutable_data( - out_image_width, out_image_height); - - auto* wait_list = context->As().cl_wait_list(); - auto* out_ptr = param.output->data(); - auto it = wait_list->find(out_ptr); - - if (it != wait_list->end()) { - VLOG(4) << "--- Find the sync event for the target cl " - "tensor. ---"; - auto& event = *(it->second); - event.wait(); - } else { - LOG(FATAL) << "Could not find the sync event for the target " - "cl tensor."; - } - - TargetWrapperCL::ImgcpySync(out_image_v.data(), - output.data(), - out_image_width, - out_image_height, - cl_image2d_row_pitch, - cl_image2d_slice_pitch, - IoDirection::DtoH); - - DDim out_image_shape = - default_convertor.InitImageDimInfoWith(output.dims()); - - default_convertor.ImageToNCHW(out_image_v.data(), - output_v.data(), - out_image_shape, - output.dims()); - VLOG(4) << "mutable_data out_ref_data: "; - - // run cpu ref - auto* out_ref_data = out_ref.mutable_data(TARGET(kARM)); - - VLOG(4) << " conv_basic beigin ..... "; - - conv_basic(input_v.data(), - out_ref_data, - batch_size, - oc, - oh, - ow, - ic, - ih, - iw, - filter_v.data(), - bias_v.data(), // mapped_bias, - group, - ksize, - ksize, - stride, - stride, - dilation, - dilation, - pad, - pad, - bias_flag, - relu_flag); - VLOG(4) << " conv_basic end ..... "; - - VLOG(4) << " out_dim: " << out_dim; - const DDim& out_image_dims = lite::DDim{std::vector( - {static_cast(out_image_width), - static_cast(out_image_height)})}; - - for (int i = 0; i < out_dim.production(); i++) { - EXPECT_NEAR(output_v[i], out_ref_data[i], 1e-2); - if (abs(output_v[i] - out_ref_data[i]) > 1e-2) { - LOG(FATAL) << "error idx:" << i; - } - } - -#ifdef LOOP_TEST - } - } - } - } - } - } - } -#else -// nothing to do. -#endif -} - -} // namespace lite -} // namespace paddle - -USE_LITE_KERNEL(conv2d_1x1, kOpenCL, kFloat, kImageDefault, image2d); diff --git a/lite/kernels/opencl/conv_compute.cc b/lite/kernels/opencl/conv_compute.cc index c3d3e2a6c27f794268ef42ac97ab492ddd4e9de1..a9cfb32aa5b141b4e3b3c7d28d2d3694524fa34c 100644 --- a/lite/kernels/opencl/conv_compute.cc +++ b/lite/kernels/opencl/conv_compute.cc @@ -362,6 +362,40 @@ void ConvImageCompute::PrepareForRun() { filter_image_dims[0], filter_image_dims[1], filter_image_v.data()); impl_ = &ConvImageCompute::Conv2d1x1; + } else if (filter_dims[1] == 1 && x_dims[1] == output_dims[1] && + kernel_h == 3 && kernel_w == 3 && groups > 1) { + // depth_conv2d_3x3s1, depth_conv2d_3x3 + if (stride_h == 1 && dilations[0] == 1) { + kernel_func_names_.push_back("depth_conv2d_3x3s1"); + impl_ = &ConvImageCompute::DepthwiseConv2d3x3s1; + } else { + kernel_func_names_.push_back("depth_conv2d_3x3"); + impl_ = &ConvImageCompute::DepthwiseConv2d3x3; + } + kernel_func_paths_.push_back("image/depthwise_conv2d_kernel.cl"); + + CLImageConverterDWBlock converter; + const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims); + std::vector filter_image_v(filter_image_dims[0] * + filter_image_dims[1] * 4); // 4 : RGBA + converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims); + filter_gpu_image_.mutable_data( + filter_image_dims[0], filter_image_dims[1], filter_image_v.data()); + } else if (filter_dims[1] == 1 && x_dims[1] == output_dims[1] && + kernel_h != 3 && groups > 1) { + // depth_conv2d + kernel_func_names_.push_back("depth_conv2d"); + kernel_func_paths_.push_back("image/depthwise_conv2d_basic_kernel.cl"); + + CLImageConverterDWBlock converter; + const DDim& filter_image_dims = converter.InitImageDimInfoWith(filter_dims); + std::vector filter_image_v(filter_image_dims[0] * + filter_image_dims[1] * 4); // 4 : RGBA + converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims); + filter_gpu_image_.mutable_data( + filter_image_dims[0], filter_image_dims[1], filter_image_v.data()); + + impl_ = &ConvImageCompute::DepthwiseConv2d; } else if (kernel_h == 3 && kernel_h == 3) { // conv2d_3x3 kernel_func_names_.push_back("conv2d_3x3"); @@ -407,6 +441,8 @@ void ConvImageCompute::PrepareForRun() { } else { LOG(FATAL) << "conv image compute not support this condition yet! "; } + VLOG(1) << "kernel_func_names_[0]:" << kernel_func_names_[0] + << " kernel_func_paths_[0]:" << kernel_func_paths_[0]; std::string build_options_single(" -DCL_DTYPE_float"); // relu options @@ -1064,6 +1100,326 @@ void ConvImageCompute::Conv2d7x7() { context.cl_wait_list()->emplace(out_image, event_); } +void ConvImageCompute::DepthwiseConv2d3x3s1() { + const auto& param = *param_.get_mutable(); + auto x_dims = param.x->dims(); + auto filter_dims = param.filter->dims(); + auto output_dims = param.output->dims(); + auto paddings = *param.paddings; + auto strides = param.strides; + auto dilations = *param.dilations; + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + auto* input_img = param.x->data(); + auto* filter_img = filter_gpu_image_.data(); + + const cl::Image2D* bias_img = nullptr; + if (param.bias) { + bias_img = bias_gpu_image_.data(); + } + + auto image_shape = InitImageDimInfoWith(output_dims); + + auto* output_img = param.output->mutable_data( + image_shape["width"], image_shape["height"]); + + STL::stringstream kernel_key; + kernel_key << kernel_func_names_[0] << build_options_[0]; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + + int c_block = (output_dims[1] + 3) / 4; + int w = output_dims[3]; + int nh = output_dims[0] * output_dims[2]; + + int w_blk_size = 2; + int w_blk = (w + w_blk_size - 1) / w_blk_size; + + auto global_work_size = cl::NDRange(c_block, w_blk, nh); + + cl_int status; + int arg_idx = 0; + status = kernel.setArg(arg_idx, static_cast(c_block)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(w_blk)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(nh)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *input_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *filter_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *output_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(strides[0])); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(paddings[0])); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(dilations[0])); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(x_dims[1])); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(x_dims[3])); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(x_dims[2])); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(output_dims[3])); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(output_dims[2])); + CL_CHECK_FATAL(status); + + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + context.cl_wait_list()->emplace(output_img, event_); +} + +void ConvImageCompute::DepthwiseConv2d3x3() { + const auto& param = *param_.get_mutable(); + auto x_dims = param.x->dims(); + auto filter_dims = param.filter->dims(); + auto output_dims = param.output->dims(); + auto paddings = *param.paddings; + auto strides = param.strides; + auto dilations = *param.dilations; + int offset = filter_dims[2] / 2 - paddings[0]; + int input_c_block = (x_dims[1] + 3) / 4; + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + auto* input_img = param.x->data(); + auto* filter_img = filter_gpu_image_.data(); + + const cl::Image2D* bias_img = nullptr; + if (param.bias) { + bias_img = bias_gpu_image_.data(); + } + + auto image_shape = InitImageDimInfoWith(output_dims); + + auto* output_img = param.output->mutable_data( + image_shape["width"], image_shape["height"]); + + STL::stringstream kernel_key; + kernel_key << kernel_func_names_[0] << build_options_[0]; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + + int c_block = (output_dims[1] + 3) / 4; + int w = output_dims[3]; + int nh = output_dims[0] * output_dims[2]; + auto global_work_size = cl::NDRange(c_block, w, nh); + + VLOG(4) << "setArg"; + VLOG(4) << "c_block = " << c_block; + VLOG(4) << "w = " << w; + VLOG(4) << "nh = " << nh; + + VLOG(4) << "strides = " << strides[0]; + VLOG(4) << "offset = " << offset; + VLOG(4) << "dilations = " << dilations[0]; + VLOG(4) << "input_c_block = " << input_c_block; + VLOG(4) << "x_dims[3] = " << x_dims[3]; + VLOG(4) << "x_dims[2] = " << x_dims[2]; + VLOG(4) << "output_dims[3] = " << output_dims[3]; + VLOG(4) << "output_dims[2] = " << output_dims[2]; + + cl_int status; + int arg_idx = 0; + status = kernel.setArg(arg_idx, static_cast(c_block)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(w)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(nh)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *input_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *filter_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *output_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(strides[0])); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(offset)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(dilations[0])); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(input_c_block)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(x_dims[3])); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(x_dims[2])); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(output_dims[3])); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(output_dims[2])); + CL_CHECK_FATAL(status); + + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + context.cl_wait_list()->emplace(output_img, event_); +} + +void ConvImageCompute::DepthwiseConv2d() { + const auto& param = *param_.get_mutable(); + auto input_dims = param.x->dims(); + auto paddings = *param.paddings; + auto strides = param.strides; + auto* input_image = param.x->data(); + auto* filter_image = filter_gpu_image_.data(); + auto filter_dims = param.filter->dims(); + auto output_dims = param.output->dims(); + + int input_width = input_dims[3]; + int input_height = input_dims[2]; + int output_width = output_dims[3]; + int output_height = output_dims[2]; + int filter_width = filter_dims[3]; + int filter_height = filter_dims[2]; + auto out_image_shape = InitImageDimInfoWith(output_dims); + auto* out_image = param.output->mutable_data( + out_image_shape["width"], out_image_shape["height"]); + + const bool has_bias = param.bias != nullptr; + const bool is_element_wise_bias = + has_bias && param.output->dims() == param.bias->dims(); + int offset = static_cast(param.filter->dims()[2]) / 2 - + static_cast(paddings[0]); + + // calc input_c_block + auto input_image_shape = InitImageDimInfoWith(input_dims); + int input_c_block = input_image_shape["width"] / input_dims[3]; + int input_c = input_dims[1]; + auto dilations = *param.dilations; + + const std::vector& default_work_size = + DefaultWorkSize(output_dims, + DDim(std::vector{ + static_cast(out_image_shape["width"]), + static_cast(out_image_shape["height"])})); + + int c_block = default_work_size[0]; + int w = default_work_size[1]; + int nh = default_work_size[2]; + + VLOG(4) << "============ depthwise conv2d params ============"; + VLOG(4) << "input_image_shape: " << input_image_shape["width"] << "," + << input_image_shape["height"]; + VLOG(4) << "input_c_block: " << input_c_block; + VLOG(4) << "input_c: " << input_c; + VLOG(4) << "input_image: " << input_image; + VLOG(4) << "filter_dims: " << filter_dims; + VLOG(4) << "filter_image: " << filter_image; + VLOG(4) << "output_dims: " << output_dims; + VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", " + << out_image_shape["height"]; + VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1]; + VLOG(4) << "has bias: " << has_bias; + VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias; + VLOG(4) << "strides: " << strides[0] << "," << strides[1]; + VLOG(4) << "offset: " << offset; + VLOG(4) << "dilations.size : " << dilations.size(); + VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1]; + VLOG(4) << "default work size{c_block, w, nh}: " + << "{" << c_block << ", " << w << ", " << nh << "" + << "}"; + + CHECK_GE(dilations.size(), 2); + CHECK(dilations[0] == dilations[1]); + CHECK_GE(input_dims.size(), 4); + CHECK_GE(paddings.size(), 2); + CHECK(paddings[0] == paddings[1]); + CHECK_GE(strides.size(), 2); + CHECK(strides[0] == strides[1]); + + // handle bias use buffer for channel wise , use image for element wise + const cl::Buffer* bias_buf = nullptr; + const cl::Image2D* bias_image = nullptr; + if (has_bias) { + bias_image = bias_gpu_image_.data(); + } + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + STL::stringstream kernel_key; + kernel_key << kernel_func_names_[0] << build_options_[0]; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + VLOG(4) << "kernel_key: " << kernel_key.str(); + VLOG(4) << "kernel ready ... " << kernel_key.str(); + VLOG(4) << "w: " << w; + + cl_int status; + int arg_idx = 0; + status = kernel.setArg(arg_idx, c_block); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, w); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, nh); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *input_image); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *filter_image); + CL_CHECK_FATAL(status); + if (has_bias) { + VLOG(4) << "set bias_image: "; + status = kernel.setArg(++arg_idx, *bias_image); + CL_CHECK_FATAL(status); + } + status = kernel.setArg(++arg_idx, *out_image); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, strides[0]); + CL_CHECK_FATAL(status); + + status = kernel.setArg(++arg_idx, offset); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, input_c_block); + CL_CHECK_FATAL(status); + + status = kernel.setArg(++arg_idx, dilations[0]); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, input_width); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, input_height); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, output_width); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, output_height); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, filter_width); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, filter_height); + CL_CHECK_FATAL(status); + + auto global_work_size = + cl::NDRange{static_cast(default_work_size.data()[0]), + static_cast(default_work_size.data()[1]), + static_cast(default_work_size.data()[2])}; + + VLOG(4) << "out_image: " << out_image; + VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << "," + << global_work_size[1] << "," << global_work_size[2] << "}"; + + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + context.cl_wait_list()->emplace(out_image, event_); +} + void ConvImageCompute::Run() { (this->*impl_)(); } } // namespace opencl @@ -1071,19 +1427,37 @@ void ConvImageCompute::Run() { (this->*impl_)(); } } // namespace lite } // namespace paddle +// REGISTER_LITE_KERNEL(conv2d, +// kOpenCL, +// kFloat, +// kNCHW, +// paddle::lite::kernels::opencl::ConvCompute, +// def) +// .BindInput("Input", {LiteType::GetTensorTy(TARGET(kOpenCL))}) +// .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kOpenCL))}) +// .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kOpenCL))}) +// .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kOpenCL))}) +// .Finalize(); + REGISTER_LITE_KERNEL(conv2d, kOpenCL, kFloat, - kNCHW, - paddle::lite::kernels::opencl::ConvCompute, - def) - .BindInput("Input", {LiteType::GetTensorTy(TARGET(kOpenCL))}) - .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kOpenCL))}) - .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kOpenCL))}) - .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kOpenCL))}) + kImageDefault, + paddle::lite::kernels::opencl::ConvImageCompute, + image2d) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault))}) + .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Output", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault))}) .Finalize(); -REGISTER_LITE_KERNEL(conv2d, +REGISTER_LITE_KERNEL(depthwise_conv2d, kOpenCL, kFloat, kImageDefault, diff --git a/lite/kernels/opencl/conv_compute.h b/lite/kernels/opencl/conv_compute.h index d5dd65cdc855ebc25624e8316866a5944a2418b8..672ba9d223031edf1ebc3d955908c4ab8edc0834 100644 --- a/lite/kernels/opencl/conv_compute.h +++ b/lite/kernels/opencl/conv_compute.h @@ -74,6 +74,9 @@ class ConvImageCompute : public KernelLite kernel_func_names_{}; diff --git a/lite/kernels/opencl/conv_compute_test.cc b/lite/kernels/opencl/conv_compute_test.cc index 1c7cca63ae4d1c0a5183b512827f4b6943f994af..af59873336fb154b34d7ada398d7fe8e568e7655 100644 --- a/lite/kernels/opencl/conv_compute_test.cc +++ b/lite/kernels/opencl/conv_compute_test.cc @@ -166,6 +166,8 @@ void PrintData(std::string name, } } +// buffer +#if 0 // #define PRINT_RESULT #define LOOP_TEST TEST(conv2d, compute_conv2d_1x1) { @@ -623,8 +625,9 @@ TEST(conv2d, compute_conv2d_gemm) { } // batch_size #endif } +#endif } // namespace lite } // namespace paddle -USE_LITE_KERNEL(conv2d, kOpenCL, kFloat, kNCHW, def); +// USE_LITE_KERNEL(conv2d, kOpenCL, kFloat, kNCHW, def); diff --git a/lite/kernels/opencl/conv_image2d_compute_test.cc b/lite/kernels/opencl/conv_image2d_compute_test.cc index 3e698a4ae838a74882317014df42cee9d2c7961c..4c81978b405e3acb4bc0e3ecc44b1ec10ac903b7 100644 --- a/lite/kernels/opencl/conv_image2d_compute_test.cc +++ b/lite/kernels/opencl/conv_image2d_compute_test.cc @@ -559,9 +559,11 @@ TEST(conv2d, compute_image2d_3x3) { // element wise bias const DDim& bias_dim = lite::DDim{std::vector({oc})}; - LOG(INFO) << "input_dim:" << input_dim - << " filter_dim:" << filter_dim - << " out_dim:" << out_dim; + VLOG(2) << "input_dim:" << input_dim + << " filter_dim:" << filter_dim << " out_dim:" << out_dim + << " bias_flag:" << bias_flag << " bias_dim:" << bias_dim + << " group:" << group << " stride:" << stride + << " pad:" << pad << " dilation:" << dilation; param.x->Resize(input_dim); param.filter->Resize(filter_dim); @@ -902,6 +904,12 @@ TEST(conv2d, compute_image2d_5x5) { // element wise bias const DDim& bias_dim = lite::DDim{std::vector({oc})}; + VLOG(2) << "input_dim:" << input_dim + << " filter_dim:" << filter_dim << " out_dim:" << out_dim + << " bias_flag:" << bias_flag << " bias_dim:" << bias_dim + << " group:" << group << " stride:" << stride + << " pad:" << pad << " dilation:" << dilation; + param.x->Resize(input_dim); param.filter->Resize(filter_dim); param.output->Resize(out_dim); diff --git a/lite/kernels/opencl/depthwise_conv2d_compute.cc b/lite/kernels/opencl/depthwise_conv2d_compute.cc index 554cc87c5f21e283316df402d195ec8bf8c4d738..0c88509926041411eddac66bea08b5d3a08d6a3c 100644 --- a/lite/kernels/opencl/depthwise_conv2d_compute.cc +++ b/lite/kernels/opencl/depthwise_conv2d_compute.cc @@ -123,420 +123,6 @@ class DepthwiseConv2dCompute std::shared_ptr event_{new cl::Event}; }; -class DepthwiseConv2dComputeFP16Image - : public KernelLite { - public: - using param_t = operators::ConvParam; - - std::string doc() const override { - return "DepthwiseConv2d using cl::Image2D/kImageDefault, kFP16"; - } - - void PrepareForRun() override { - const auto& param = *param_.get_mutable(); - if (param.fuse_relu) { - build_options_ += " -DRELU"; - } else if (param.activation_param.active_type == - lite_api::ActivationType::kRelu6) { - build_options_ += " -DRELU6"; - } - auto& context = ctx_->As(); - context.cl_context()->AddKernel( - kernel_func_name_, "image/depthwise_conv2d_kernel.cl", build_options_); - } - - void Run() override { - const auto& param = *param_.get_mutable(); - auto x_dims = param.x->dims(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - auto dilations = *param.dilations; - int offset = filter_dims[2] / 2 - paddings[0]; - int input_c_block = (x_dims[1] + 3) / 4; - - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - auto* input_img = param.x->data(); - auto* filter_img = param.filter->data(); - - auto* bias_img = param.bias == nullptr - ? static_cast(nullptr) - : param.bias->data(); - - auto image_shape = InitImageDimInfoWith(output_dims); - - auto* output_img = param.output->mutable_data( - image_shape["width"], image_shape["height"]); - - STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; - auto kernel = context.cl_context()->GetKernel(kernel_key.str()); - - int c_block = (output_dims[1] + 3) / 4; - int w = output_dims[3]; - int nh = output_dims[0] * output_dims[2]; - auto global_work_size = cl::NDRange(c_block, w, nh); - - VLOG(4) << "setArg"; - VLOG(4) << "c_block = " << c_block; - VLOG(4) << "w = " << w; - VLOG(4) << "nh = " << nh; - - VLOG(4) << "strides = " << strides[0]; - VLOG(4) << "offset = " << offset; - VLOG(4) << "dilations = " << dilations[0]; - VLOG(4) << "input_c_block = " << input_c_block; - VLOG(4) << "x_dims[3] = " << x_dims[3]; - VLOG(4) << "x_dims[2] = " << x_dims[2]; - VLOG(4) << "output_dims[3] = " << output_dims[3]; - VLOG(4) << "output_dims[2] = " << output_dims[2]; - - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, static_cast(c_block)); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(w)); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(nh)); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_img); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_img); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *output_img); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(strides[0])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(offset)); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(dilations[0])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(input_c_block)); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(x_dims[3])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(x_dims[2])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(output_dims[3])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(output_dims[2])); - CL_CHECK_FATAL(status); - - status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( - kernel, - cl::NullRange, - global_work_size, - cl::NullRange, - nullptr, - event_.get()); - CL_CHECK_FATAL(status); - context.cl_wait_list()->emplace(output_img, event_); - } - - private: - std::string kernel_func_name_{"depth_conv2d_3x3"}; - std::string build_options_{"-DCL_DTYPE_half"}; - std::shared_ptr event_{new cl::Event}; -}; - -class DepthwiseConv2d3x3s1ComputeFP16Image - : public KernelLite { - public: - using param_t = operators::ConvParam; - - std::string doc() const override { - return "DepthwiseConv2d3x3s1 using cl::Image2D/kImageDefault, kFP16"; - } - - void PrepareForRun() override { - const auto& param = *param_.get_mutable(); - if (param.fuse_relu) { - build_options_ += " -DRELU"; - } else if (param.activation_param.active_type == - lite_api::ActivationType::kRelu6) { - build_options_ += " -DRELU6"; - } - auto& context = ctx_->As(); - context.cl_context()->AddKernel( - kernel_func_name_, "image/depthwise_conv2d_kernel.cl", build_options_); - } - - void Run() override { - const auto& param = *param_.get_mutable(); - auto x_dims = param.x->dims(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - auto dilations = *param.dilations; - - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - auto* input_img = param.x->data(); - auto* filter_img = param.filter->data(); - - auto* bias_img = param.bias == nullptr - ? static_cast(nullptr) - : param.bias->data(); - - auto image_shape = InitImageDimInfoWith(output_dims); - - auto* output_img = param.output->mutable_data( - image_shape["width"], image_shape["height"]); - - STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; - auto kernel = context.cl_context()->GetKernel(kernel_key.str()); - - int c_block = (output_dims[1] + 3) / 4; - int w = output_dims[3]; - int nh = output_dims[0] * output_dims[2]; - - int w_blk_size = 2; - int w_blk = (w + w_blk_size - 1) / w_blk_size; - - auto global_work_size = cl::NDRange(c_block, w_blk, nh); - - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, static_cast(c_block)); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(w_blk)); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(nh)); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_img); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_img); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *output_img); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(strides[0])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(paddings[0])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(dilations[0])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(x_dims[1])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(x_dims[3])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(x_dims[2])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(output_dims[3])); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, static_cast(output_dims[2])); - CL_CHECK_FATAL(status); - - status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( - kernel, - cl::NullRange, - global_work_size, - cl::NullRange, - nullptr, - event_.get()); - CL_CHECK_FATAL(status); - context.cl_wait_list()->emplace(output_img, event_); - } - - private: - std::string kernel_func_name_{"depth_conv2d_3x3s1"}; - std::string build_options_{"-DCL_DTYPE_half"}; - std::shared_ptr event_{new cl::Event}; -}; - -class DepthwiseConv2dBasicComputeFP32Image - : public KernelLite { - public: - using param_t = operators::ConvParam; - - std::string doc() const override { - return "DepthwiseConv2d basic using cl::Image2D/kImageDefault, kFloat32"; - } - - void PrepareForRun() override { - const auto& param = *param_.get_mutable(); - const bool has_bias = param.bias != nullptr; - const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); - if (param.fuse_relu) { - build_options_ += " -DRELU"; - } else if (param.activation_param.active_type == - lite_api::ActivationType::kRelu6) { - build_options_ += " -DRELU6"; - } - if (has_bias) { - build_options_ += is_element_wise_bias ? " -DBIASE_ELE" : " -DBIASE_CH"; - } - auto& context = ctx_->As(); - context.cl_context()->AddKernel(kernel_func_name_, - "image/depthwise_conv2d_basic_kernel.cl", - build_options_); - } - - void Run() override { - const auto& param = *param_.get_mutable(); - auto input_dims = param.x->dims(); - auto paddings = *param.paddings; - auto strides = param.strides; - auto* input_image = param.x->data(); - auto* filter_image = param.filter->data(); - auto filter_dims = param.filter->dims(); - auto output_dims = param.output->dims(); - - int input_width = input_dims[3]; - int input_height = input_dims[2]; - int output_width = output_dims[3]; - int output_height = output_dims[2]; - int filter_width = filter_dims[3]; - int filter_height = filter_dims[2]; - auto out_image_shape = InitImageDimInfoWith(output_dims); - auto* out_image = param.output->mutable_data( - out_image_shape["width"], out_image_shape["height"]); - - const bool has_bias = param.bias != nullptr; - const bool is_element_wise_bias = - has_bias && param.output->dims() == param.bias->dims(); - int offset = static_cast(param.filter->dims()[2]) / 2 - - static_cast(paddings[0]); - - // calc input_c_block - auto input_image_shape = InitImageDimInfoWith(input_dims); - int input_c_block = input_image_shape["width"] / input_dims[3]; - int input_c = input_dims[1]; - auto dilations = *param.dilations; - - const std::vector& default_work_size = - DefaultWorkSize(output_dims, - DDim(std::vector{ - static_cast(out_image_shape["width"]), - static_cast(out_image_shape["height"])})); - - int c_block = default_work_size[0]; - int w = default_work_size[1]; - int nh = default_work_size[2]; - - VLOG(4) << "============ depthwise conv2d params ============"; - VLOG(4) << "input_image_shape: " << input_image_shape["width"] << "," - << input_image_shape["height"]; - VLOG(4) << "input_c_block: " << input_c_block; - VLOG(4) << "input_c: " << input_c; - VLOG(4) << "input_image: " << input_image; - VLOG(4) << "filter_dims: " << filter_dims; - VLOG(4) << "filter_image: " << filter_image; - VLOG(4) << "output_dims: " << output_dims; - VLOG(4) << "out_image_shape: " << out_image_shape["width"] << ", " - << out_image_shape["height"]; - VLOG(4) << "paddings: " << paddings[0] << "," << paddings[1]; - VLOG(4) << "has bias: " << has_bias; - VLOG(4) << "is_element_wise_bias : " << is_element_wise_bias; - VLOG(4) << "strides: " << strides[0] << "," << strides[1]; - VLOG(4) << "offset: " << offset; - VLOG(4) << "dilations.size : " << dilations.size(); - VLOG(4) << "dilations: " << dilations[0] << ", " << dilations[1]; - VLOG(4) << "default work size{c_block, w, nh}: " - << "{" << c_block << ", " << w << ", " << nh << "" - << "}"; - - CHECK_GE(dilations.size(), 2); - CHECK(dilations[0] == dilations[1]); - CHECK_GE(input_dims.size(), 4); - CHECK_GE(paddings.size(), 2); - CHECK(paddings[0] == paddings[1]); - CHECK_GE(strides.size(), 2); - CHECK(strides[0] == strides[1]); - - // handle bias use buffer for channel wise , use image for element wise - const cl::Buffer* bias_buf = nullptr; - const cl::Image2D* bias_image = nullptr; - if (has_bias) { - bias_image = param.bias->data(); - } - - auto& context = ctx_->As(); - CHECK(context.cl_context() != nullptr); - STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; - auto kernel = context.cl_context()->GetKernel(kernel_key.str()); - VLOG(4) << "kernel_key: " << kernel_key.str(); - VLOG(4) << "kernel ready ... " << kernel_key.str(); - VLOG(4) << "w: " << w; - - cl_int status; - int arg_idx = 0; - status = kernel.setArg(arg_idx, c_block); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, w); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, nh); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *input_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, *filter_image); - CL_CHECK_FATAL(status); - if (has_bias) { - VLOG(4) << "set bias_image: "; - status = kernel.setArg(++arg_idx, *bias_image); - CL_CHECK_FATAL(status); - } - status = kernel.setArg(++arg_idx, *out_image); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, strides[0]); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, offset); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_c_block); - CL_CHECK_FATAL(status); - - status = kernel.setArg(++arg_idx, dilations[0]); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, input_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, output_height); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, filter_width); - CL_CHECK_FATAL(status); - status = kernel.setArg(++arg_idx, filter_height); - CL_CHECK_FATAL(status); - - auto global_work_size = - cl::NDRange{static_cast(default_work_size.data()[0]), - static_cast(default_work_size.data()[1]), - static_cast(default_work_size.data()[2])}; - - VLOG(4) << "out_image: " << out_image; - VLOG(4) << "global_work_size[3D]: {" << global_work_size[0] << "," - << global_work_size[1] << "," << global_work_size[2] << "}"; - - status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( - kernel, - cl::NullRange, - global_work_size, - cl::NullRange, - nullptr, - event_.get()); - CL_CHECK_FATAL(status); - context.cl_wait_list()->emplace(out_image, event_); - } - - private: - std::string kernel_func_name_{"depth_conv2d"}; - std::string build_options_{"-DCL_DTYPE_float"}; - std::shared_ptr event_{new cl::Event}; -}; } // namespace opencl } // namespace kernels } // namespace lite @@ -553,52 +139,3 @@ REGISTER_LITE_KERNEL(depthwise_conv2d, .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kOpenCL))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kOpenCL))}) .Finalize(); - -REGISTER_LITE_KERNEL( - depthwise_conv2d, - kOpenCL, - kFP16, - kImageDefault, - paddle::lite::kernels::opencl::DepthwiseConv2dComputeFP16Image, - image2d) - .BindInput("Input", - {LiteType::GetTensorTy(TARGET(kOpenCL), - PRECISION(kFP16), - DATALAYOUT(kImageDefault))}) - .BindInput("Bias", - {LiteType::GetTensorTy(TARGET(kOpenCL), - PRECISION(kFP16), - DATALAYOUT(kImageDefault))}) - .BindInput("Filter", - {LiteType::GetTensorTy(TARGET(kOpenCL), - PRECISION(kFP16), - DATALAYOUT(kImageNW))}) - .BindOutput("Output", - {LiteType::GetTensorTy(TARGET(kOpenCL), - PRECISION(kFP16), - DATALAYOUT(kImageDefault))}) - .Finalize(); -REGISTER_LITE_KERNEL( - depthwise_conv2d_basic, - kOpenCL, - kFloat, - kImageDefault, - paddle::lite::kernels::opencl::DepthwiseConv2dBasicComputeFP32Image, - image2d) - .BindInput("Input", - {LiteType::GetTensorTy(TARGET(kOpenCL), - PRECISION(kFloat), - DATALAYOUT(kImageDefault))}) - .BindInput("Bias", - {LiteType::GetTensorTy(TARGET(kOpenCL), - PRECISION(kFloat), - DATALAYOUT(kImageDefault))}) - .BindInput("Filter", - {LiteType::GetTensorTy(TARGET(kOpenCL), - PRECISION(kFloat), - DATALAYOUT(kImageNW))}) - .BindOutput("Output", - {LiteType::GetTensorTy(TARGET(kOpenCL), - PRECISION(kFloat), - DATALAYOUT(kImageDefault))}) - .Finalize(); diff --git a/lite/kernels/opencl/depthwise_conv2d_compute_test.cc b/lite/kernels/opencl/depthwise_conv2d_compute_test.cc index c52aa87a73c8f9cbd91851c96162cde817f299b4..40cfdfffab452a004d45d804f62309dc71e0b0d9 100644 --- a/lite/kernels/opencl/depthwise_conv2d_compute_test.cc +++ b/lite/kernels/opencl/depthwise_conv2d_compute_test.cc @@ -177,135 +177,7 @@ TEST(depthwise_conv2d_buffer_fp32, compute) { TargetWrapperCL::Unmap(input_data, mapped_input); } -TEST(depthwise_conv2d_image2d_fp16, compute) { - LOG(INFO) << "to get kernel ..."; - auto kernels = KernelRegistry::Global().Create("depthwise_conv2d", - TARGET(kOpenCL), - PRECISION(kFP16), - DATALAYOUT(kImageDefault)); - ASSERT_FALSE(kernels.empty()); - - auto kernel = std::move(kernels.front()); - - LOG(INFO) << "get kernel"; - lite::Tensor input, filter, output; - operators::ConvParam param; - param.x = &input; - param.filter = &filter; - param.output = &output; - std::vector paddings = {0, 0}; - param.paddings = std::make_shared>(paddings); - param.strides = std::vector{1, 1}; - std::vector dilations = {1, 1}; - param.dilations = std::make_shared>(dilations); - - std::unique_ptr context(new KernelContext); - context->As().InitOnce(); - - kernel->SetParam(param); - std::unique_ptr dep_context(new KernelContext); - context->As().CopySharedTo( - &(dep_context->As())); - kernel->SetContext(std::move(dep_context)); - - LOG(INFO) << "kernel ready"; - std::default_random_engine engine; - std::uniform_real_distribution gen(-5, 5); - std::vector input_v(1 * 32 * 112 * 112); - std::vector filter_v(32 * 1 * 3 * 3); - for (auto& i : input_v) { - i = gen(engine); - } - for (auto& f : filter_v) { - f = gen(engine); - } - - LOG(INFO) << "prepare input"; - input.Resize({1, 32, 112, 112}); - CLImageConverterDefault* default_converter = new CLImageConverterDefault(); - DDim input_image_shape = - default_converter->InitImageDimInfoWith(input.dims()); - LOG(INFO) << "input_image_shape = " << input_image_shape[0] << " " - << input_image_shape[1]; - std::vector input_image_data(input_image_shape.production() * - 4); // 4 : RGBA - default_converter->NCHWToImage( - input_v.data(), input_image_data.data(), input.dims()); - auto* input_image = input.mutable_data( - input_image_shape[0], input_image_shape[1], input_image_data.data()); - - LOG(INFO) << "prepare kernel"; - filter.Resize({32, 1, 3, 3}); - CLImageConverterNWBlock* nw_converter = new CLImageConverterNWBlock(); - DDim filter_image_shape = nw_converter->InitImageDimInfoWith(filter.dims()); - LOG(INFO) << "filter_image_shape = " << filter_image_shape[0] << " " - << filter_image_shape[1]; - std::vector filter_image_data(filter_image_shape.production() * - 4); // 4 : RGBA - nw_converter->NCHWToImage( - filter_v.data(), filter_image_data.data(), filter.dims()); - auto* filter_image = filter.mutable_data( - filter_image_shape[0], filter_image_shape[1], filter_image_data.data()); - - LOG(INFO) << "launch"; - output.Resize({1, 32, 110, 110}); - DDim output_image_shape = - default_converter->InitImageDimInfoWith(output.dims()); - LOG(INFO) << "output_image_shape = " << output_image_shape[0] << " " - << output_image_shape[1]; - auto* output_image = output.mutable_data( - output_image_shape[0], output_image_shape[1]); - - kernel->Launch(); - - auto* wait_list = context->As().cl_wait_list(); - auto* out_ptr = param.output->data(); - auto it = wait_list->find(out_ptr); - if (it != wait_list->end()) { - VLOG(4) << "--- Find the sync event for the target cl tensor. ---"; - LOG(INFO) << "--- Find the sync event for the target cl tensor. ---"; - auto& event = *(it->second); - event.wait(); - } else { - LOG(FATAL) << "Could not find the sync event for the target cl tensor."; - LOG(INFO) << "Could not find the sync event for the target cl tensor."; - } - - lite::Tensor output_ref; - output_ref.Resize({1, 32, 110, 110}); - auto* output_ref_data = output_ref.mutable_data(TARGET(kARM)); - depth_conv(input_v.data(), - input.dims(), - filter_v.data(), - filter.dims(), - output_ref_data, - output_ref.dims()); - - const size_t cl_image2d_row_pitch{0}; - const size_t cl_image2d_slice_pitch{0}; - - float* output_image_data = new float[output_image_shape.production() * 4]; - TargetWrapperCL::ImgcpySync(output_image_data, - output_image, - output_image_shape[0], - output_image_shape[1], - cl_image2d_row_pitch, - cl_image2d_slice_pitch, - IoDirection::DtoH); - - float* output_data = new float[output_image_shape.production() * 4]; - default_converter->ImageToNCHW( - output_image_data, output_data, output_image_shape, output.dims()); - - LOG(INFO) << "output_data vs output_ref_data"; - for (int i = 0; i < output.dims().production(); i++) { - EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-4); - LOG(INFO) << output_data[i] << " " << output_ref_data[i]; - } -} - } // namespace lite } // namespace paddle USE_LITE_KERNEL(depthwise_conv2d, kOpenCL, kFloat, kNCHW, def); -USE_LITE_KERNEL(depthwise_conv2d, kOpenCL, kFP16, kImageDefault, image2d); diff --git a/lite/kernels/opencl/depthwise_conv2d_basic_compute_test.cc b/lite/kernels/opencl/depthwise_conv2d_image2d_compute_test.cc similarity index 72% rename from lite/kernels/opencl/depthwise_conv2d_basic_compute_test.cc rename to lite/kernels/opencl/depthwise_conv2d_image2d_compute_test.cc index 96ee99e538cc2f293d1f97b2b70a678a0a8ef7b9..1b96ffe0502c3e2d654f88e9c9ac35d20704ca01 100644 --- a/lite/kernels/opencl/depthwise_conv2d_basic_compute_test.cc +++ b/lite/kernels/opencl/depthwise_conv2d_image2d_compute_test.cc @@ -142,7 +142,7 @@ TEST(depthwise_conv2d_basic, compute) { VLOG(4) << "to get kernel ..."; auto kernels = - KernelRegistry::Global().Create("depthwise_conv2d_basic", + KernelRegistry::Global().Create("depthwise_conv2d", TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kImageDefault)); @@ -383,7 +383,133 @@ TEST(depthwise_conv2d_basic, compute) { #endif } +TEST(depthwise_conv2d_image2d_fp16, compute) { + LOG(INFO) << "to get kernel ..."; + auto kernels = KernelRegistry::Global().Create("depthwise_conv2d", + TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault)); + ASSERT_FALSE(kernels.empty()); + + auto kernel = std::move(kernels.front()); + + LOG(INFO) << "get kernel"; + lite::Tensor input, filter, output; + operators::ConvParam param; + param.x = &input; + param.filter = &filter; + param.output = &output; + std::vector paddings = {0, 0}; + param.paddings = std::make_shared>(paddings); + param.strides = std::vector{1, 1}; + std::vector dilations = {1, 1}; + param.dilations = std::make_shared>(dilations); + + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + + kernel->SetParam(param); + std::unique_ptr dep_context(new KernelContext); + context->As().CopySharedTo( + &(dep_context->As())); + kernel->SetContext(std::move(dep_context)); + + LOG(INFO) << "kernel ready"; + std::default_random_engine engine; + std::uniform_real_distribution gen(-5, 5); + std::vector input_v(1 * 32 * 112 * 112); + std::vector filter_v(32 * 1 * 3 * 3); + for (auto& i : input_v) { + i = gen(engine); + } + for (auto& f : filter_v) { + f = gen(engine); + } + + LOG(INFO) << "prepare input"; + input.Resize({1, 32, 112, 112}); + CLImageConverterDefault* default_converter = new CLImageConverterDefault(); + DDim input_image_shape = + default_converter->InitImageDimInfoWith(input.dims()); + LOG(INFO) << "input_image_shape = " << input_image_shape[0] << " " + << input_image_shape[1]; + std::vector input_image_data(input_image_shape.production() * + 4); // 4 : RGBA + default_converter->NCHWToImage( + input_v.data(), input_image_data.data(), input.dims()); + auto* input_image = input.mutable_data( + input_image_shape[0], input_image_shape[1], input_image_data.data()); + + LOG(INFO) << "prepare kernel"; + filter.Resize({32, 1, 3, 3}); + CLImageConverterNWBlock* nw_converter = new CLImageConverterNWBlock(); + DDim filter_image_shape = nw_converter->InitImageDimInfoWith(filter.dims()); + LOG(INFO) << "filter_image_shape = " << filter_image_shape[0] << " " + << filter_image_shape[1]; + std::vector filter_image_data(filter_image_shape.production() * + 4); // 4 : RGBA + nw_converter->NCHWToImage( + filter_v.data(), filter_image_data.data(), filter.dims()); + auto* filter_image = filter.mutable_data( + filter_image_shape[0], filter_image_shape[1], filter_image_data.data()); + + LOG(INFO) << "launch"; + output.Resize({1, 32, 110, 110}); + DDim output_image_shape = + default_converter->InitImageDimInfoWith(output.dims()); + LOG(INFO) << "output_image_shape = " << output_image_shape[0] << " " + << output_image_shape[1]; + auto* output_image = output.mutable_data( + output_image_shape[0], output_image_shape[1]); + + kernel->Launch(); + + auto* wait_list = context->As().cl_wait_list(); + auto* out_ptr = param.output->data(); + auto it = wait_list->find(out_ptr); + if (it != wait_list->end()) { + VLOG(4) << "--- Find the sync event for the target cl tensor. ---"; + LOG(INFO) << "--- Find the sync event for the target cl tensor. ---"; + auto& event = *(it->second); + event.wait(); + } else { + LOG(FATAL) << "Could not find the sync event for the target cl tensor."; + LOG(INFO) << "Could not find the sync event for the target cl tensor."; + } + + lite::Tensor output_ref; + output_ref.Resize({1, 32, 110, 110}); + auto* output_ref_data = output_ref.mutable_data(TARGET(kARM)); + depth_conv(input_v.data(), + input.dims(), + filter_v.data(), + filter.dims(), + output_ref_data, + output_ref.dims()); + + const size_t cl_image2d_row_pitch{0}; + const size_t cl_image2d_slice_pitch{0}; + + float* output_image_data = new float[output_image_shape.production() * 4]; + TargetWrapperCL::ImgcpySync(output_image_data, + output_image, + output_image_shape[0], + output_image_shape[1], + cl_image2d_row_pitch, + cl_image2d_slice_pitch, + IoDirection::DtoH); + + float* output_data = new float[output_image_shape.production() * 4]; + default_converter->ImageToNCHW( + output_image_data, output_data, output_image_shape, output.dims()); + + LOG(INFO) << "output_data vs output_ref_data"; + for (int i = 0; i < output.dims().production(); i++) { + EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-4); + LOG(INFO) << output_data[i] << " " << output_ref_data[i]; + } +} + } // namespace lite } // namespace paddle -USE_LITE_KERNEL( - depthwise_conv2d_basic, kOpenCL, kFloat, kImageDefault, image2d); +USE_LITE_KERNEL(depthwise_conv2d, kOpenCL, kFloat, kImageDefault, image2d); diff --git a/lite/kernels/opencl/fc_compute_test.cc b/lite/kernels/opencl/fc_compute_test.cc index 7f0c9c49a9920b10ceaa29cd1b548f59d5758f3b..863eab6297a88bcb2827c6ed09dfd1cecd7fae2d 100644 --- a/lite/kernels/opencl/fc_compute_test.cc +++ b/lite/kernels/opencl/fc_compute_test.cc @@ -66,6 +66,8 @@ void PrintData(std::string name, float* a, const int rows, const int cols) { } } +// buffer +#if 0 // fc_buffer // #define PRINT_RESULT #define LOOP_TEST TEST(fc, compute) { @@ -193,8 +195,9 @@ TEST(fc, compute) { } // m #endif } +#endif // fc_buffer } // namespace lite } // namespace paddle -USE_LITE_KERNEL(fc, kOpenCL, kFloat, kNCHW, def); +// USE_LITE_KERNEL(fc, kOpenCL, kFloat, kNCHW, def); diff --git a/lite/kernels/opencl/pool_compute.cc b/lite/kernels/opencl/pool_compute.cc index fca2cbe96d56b65e5f33acacff20c781b3400ed0..c0a00e87b8ad67ba0028ff4fa57f0811d52c1f0a 100644 --- a/lite/kernels/opencl/pool_compute.cc +++ b/lite/kernels/opencl/pool_compute.cc @@ -229,15 +229,15 @@ class PoolComputeImage2D : public KernelLite{3, 6, 10, 10}); @@ -414,7 +415,7 @@ TEST(sigmoid_image2d_fp16, compute) { } // namespace paddle // sigmoid buffer -USE_LITE_KERNEL(sigmoid, kOpenCL, kFloat, kNCHW, def); +// USE_LITE_KERNEL(sigmoid, kOpenCL, kFloat, kNCHW, def); // sigmoid image2d fp32 USE_LITE_KERNEL(layout, kOpenCL, kAny, kImageDefault, NCHW_to_ImageDefault);