From 2d4e9c13aad52fbbacde163de789780c8353889e Mon Sep 17 00:00:00 2001 From: xiebaiyuan Date: Sun, 19 Jan 2020 12:54:26 +0800 Subject: [PATCH] =?UTF-8?q?[LITE][OPENCL]=20conv2d=5F1x1=5Fimage,=20choose?= =?UTF-8?q?=20simple=20kernel=20when=20in=20some=20ca=E2=80=A6=20(#2771)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [LITE][OPENCL] conv2d_1x1_image, choose simple kernel when in some case. for opencl ,test=develop * [LITE][OPENCL] conv2d_1x1_image, add looptest ,test=develop --- .../cl_kernel/image/conv2d_1x1_kernel.cl | 169 ++++++ lite/kernels/opencl/conv2d_1x1_compute.cc | 18 +- .../kernels/opencl/conv2d_1x1_compute_test.cc | 539 ++++++++++-------- 3 files changed, 471 insertions(+), 255 deletions(-) diff --git a/lite/backends/opencl/cl_kernel/image/conv2d_1x1_kernel.cl b/lite/backends/opencl/cl_kernel/image/conv2d_1x1_kernel.cl index 6fe5596a4c..2b037080b7 100644 --- a/lite/backends/opencl/cl_kernel/image/conv2d_1x1_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/conv2d_1x1_kernel.cl @@ -214,3 +214,172 @@ __kernel void conv2d_1x1(__private const int global_size_dim0, WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos3, output3); } } + +__kernel void conv2d_1x1_simple(__private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, + __read_only image2d_t input_image, + __read_only image2d_t filter, +#if defined(BIASE_CH) || defined(BIASE_ELE) + __read_only image2d_t bias, +#endif +#ifdef BATCH_NORM +__read_only image2d_t new_scale, + __read_only image2d_t new_biase, +#endif + __write_only image2d_t output_image, + __private const int stride, + __private const int offset, + __private const int input_c, + __private const int input_c_origin, + __private const int dilation, + __private const int input_width, /* of one block */ + __private const int input_height, /* of one block */ + __private const int output_width, + __private const int output_height, + __private const int old_w) { + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + + int out_w0 = out_w; + int out_w1 = out_w + global_size_dim1; + int out_w2 = out_w + global_size_dim1 * 2; + int out_w3 = out_w + global_size_dim1 * 3; + + int outpos_main = mul24(out_c, old_w); + int2 output_pos0 = (int2)(outpos_main + out_w0, out_nh); + int2 output_pos1 = (int2)(outpos_main + out_w1, out_nh); + int2 output_pos2 = (int2)(outpos_main + out_w2, out_nh); + int2 output_pos3 = (int2)(outpos_main + out_w3, out_nh); + + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + int2 stride_xy = (int2)(stride, stride); + + int2 ouput_pos_in_one_block0 = (int2)(out_w0, out_nh); + int2 in_pos_in_one_block0 = + ouput_pos_in_one_block0 * stride_xy + (int2)(offset, offset); + + int2 ouput_pos_in_one_block1 = (int2)(out_w1, out_nh); + int2 in_pos_in_one_block1 = + ouput_pos_in_one_block1 * stride_xy + (int2)(offset, offset); + + int2 ouput_pos_in_one_block2 = (int2)(out_w2, out_nh); + int2 in_pos_in_one_block2 = + ouput_pos_in_one_block2 * stride_xy + (int2)(offset, offset); + + int2 ouput_pos_in_one_block3 = (int2)(out_w3, out_nh); + int2 in_pos_in_one_block3 = + ouput_pos_in_one_block3 * stride_xy + (int2)(offset, offset); + +#ifdef BIASE_CH + CL_DTYPE4 output0 = + READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, (int2)(out_c, 0)); + CL_DTYPE4 output1 = output0; + CL_DTYPE4 output2 = output0; + CL_DTYPE4 output3 = output0; +#elif defined(BIASE_ELE) + CL_DTYPE4 output0 = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, output_pos0); + CL_DTYPE4 output1 = output0; + CL_DTYPE4 output2 = output0; + CL_DTYPE4 output3 = output0; + +#else + CL_DTYPE4 output0 = 0.0f; + CL_DTYPE4 output1 = 0.0f; + CL_DTYPE4 output2 = 0.0f; + CL_DTYPE4 output3 = 0.0f; +#endif + + for (int i = 0; i < input_c; ++i) { + // ------------0--------------- + int2 pos_in = (int2)(i * input_width + in_pos_in_one_block0.x, + in_pos_in_one_block0.y); + CL_DTYPE4 input0 = + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, pos_in); + + CL_DTYPE4 weight0 = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(out_c, i * 4 + 0)); + CL_DTYPE4 weight1 = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(out_c, i * 4 + 1)); + CL_DTYPE4 weight2 = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(out_c, i * 4 + 2)); + CL_DTYPE4 weight3 = + READ_IMG_TYPE(CL_DTYPE_CHAR, filter, sampler, (int2)(out_c, i * 4 + 3)); + + output0 = mad(input0.x, weight0, output0); + output0 = mad(input0.y, weight1, output0); + output0 = mad(input0.z, weight2, output0); + output0 = mad(input0.w, weight3, output0); + + pos_in = (int2)(i * input_width + in_pos_in_one_block1.x, + in_pos_in_one_block1.y); + CL_DTYPE4 input1 = + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, pos_in); + output1 = mad(input1.x, weight0, output1); + output1 = mad(input1.y, weight1, output1); + output1 = mad(input1.z, weight2, output1); + output1 = mad(input1.w, weight3, output1); + + pos_in = (int2)(i * input_width + in_pos_in_one_block2.x, + in_pos_in_one_block2.y); + CL_DTYPE4 input2 = + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, pos_in); + output2 = mad(input2.x, weight0, output2); + output2 = mad(input2.y, weight1, output2); + output2 = mad(input2.z, weight2, output2); + output2 = mad(input2.w, weight3, output2); + + pos_in = (int2)(i * input_width + in_pos_in_one_block3.x, + in_pos_in_one_block3.y); + CL_DTYPE4 input3 = + READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, pos_in); + output3 = mad(input3.x, weight0, output3); + output3 = mad(input3.y, weight1, output3); + output3 = mad(input3.z, weight2, output3); + output3 = mad(input3.w, weight3, output3); + } + +#ifdef BATCH_NORM + output0 = output0 * READ_IMG_TYPE( + CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + + READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0)); + + output1 = output1 * READ_IMG_TYPE( + CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + + READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0)); + + output2 = output2 * READ_IMG_TYPE( + CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + + READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0)); + + output3 = output3 * READ_IMG_TYPE( + CL_DTYPE_CHAR, new_scale, sampler, (int2)(out_c, 0)) + + READ_IMG_TYPE(CL_DTYPE_CHAR, new_biase, sampler, (int2)(out_c, 0)); +#endif + +#ifdef RELU + output0 = activation_type4(output0); + output1 = activation_type4(output1); + output2 = activation_type4(output2); + output3 = activation_type4(output3); +#endif + + if (out_w0 < old_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos0, output0); + } + + if (out_w1 < old_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos1, output1); + } + + if (out_w2 < old_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos2, output2); + } + + if (out_w3 < old_w) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos3, output3); + } +} diff --git a/lite/kernels/opencl/conv2d_1x1_compute.cc b/lite/kernels/opencl/conv2d_1x1_compute.cc index 3f313f542c..975105fd41 100644 --- a/lite/kernels/opencl/conv2d_1x1_compute.cc +++ b/lite/kernels/opencl/conv2d_1x1_compute.cc @@ -13,6 +13,7 @@ // limitations under the License. #include + #include "lite/backends/opencl/cl_include.h" #include "lite/core/kernel.h" #include "lite/core/op_registry.h" @@ -45,8 +46,14 @@ class Conv2d1x1Image2DCompute : public KernelLiteAs(); - context.cl_context()->AddKernel( - kernel_func_name_, "image/conv2d_1x1_kernel.cl", build_options_); + if (param.x->dims()[1] % 4 == 0) { + context.cl_context()->AddKernel(kernel_func_name_simple_, + "image/conv2d_1x1_kernel.cl", + build_options_); + } else { + context.cl_context()->AddKernel( + kernel_func_name_, "image/conv2d_1x1_kernel.cl", build_options_); + } } void Run() override { @@ -135,7 +142,11 @@ class Conv2d1x1Image2DCompute : public KernelLiteAs(); CHECK(context.cl_context() != nullptr); STL::stringstream kernel_key; - kernel_key << kernel_func_name_ << build_options_; + if (input_dims[1] % 4 == 0) { + kernel_key << kernel_func_name_simple_ << build_options_; + } else { + kernel_key << kernel_func_name_ << build_options_; + } auto kernel = context.cl_context()->GetKernel(kernel_key.str()); int maped_w = maptofactor(w, 4); @@ -215,6 +226,7 @@ class Conv2d1x1Image2DCompute : public KernelLite event_{new cl::Event}; }; diff --git a/lite/kernels/opencl/conv2d_1x1_compute_test.cc b/lite/kernels/opencl/conv2d_1x1_compute_test.cc index c35e734492..6879d2ba38 100644 --- a/lite/kernels/opencl/conv2d_1x1_compute_test.cc +++ b/lite/kernels/opencl/conv2d_1x1_compute_test.cc @@ -13,12 +13,14 @@ // limitations under the License. #include + #include + #include "lite/backends/opencl/cl_image_converter.h" #include "lite/backends/opencl/target_wrapper.h" - #include "lite/core/op_registry.h" #include "lite/core/tensor.h" +#include "lite/utils/logging.h" namespace paddle { namespace lite { @@ -106,7 +108,6 @@ static void conv_basic(const Dtype1* din, } } } - TEST(conv2d_1x1, compute) { // conv infos const int ksize = 1; @@ -114,262 +115,296 @@ TEST(conv2d_1x1, compute) { const int pad = 0; const int group = 1; const int dilation = 0; - // int loop_cnt = 0; - - const bool bias_flag = true; - const bool relu_flag = true; - const int batch_size = 8; - const int oc = 64; - const int ih = 28; - const int iw = 28; - const int ic = 63; - - const int oh = ih; - const int ow = iw; - - LOG(INFO) << "to get kernel ..."; - auto kernels = KernelRegistry::Global().Create("conv2d_1x1", - TARGET(kOpenCL), - PRECISION(kFloat), - DATALAYOUT(kImageDefault)); - ASSERT_FALSE(kernels.empty()); - - auto kernel = std::move(kernels.front()); - LOG(INFO) << "created conv2d_1x1 kernel"; - - LOG(INFO) << "prepare kernel ------"; - - lite::Tensor input, filter, bias, output; - operators::ConvParam param; - param.x = &input; - param.filter = &filter; - param.output = &output; - if (bias_flag) { - param.bias = &bias; - } - param.fuse_relu = relu_flag; - - std::vector paddings = {pad, pad, pad, pad}; - std::vector dilations = {dilation, dilation}; - - param.paddings = std::make_shared>(paddings); - param.dilations = std::make_shared>(dilations); - param.strides = std::vector{stride, stride}; - - std::unique_ptr context(new KernelContext); - context->As().InitOnce(); - - std::unique_ptr conv_1x1_context(new KernelContext); - context->As().CopySharedTo( - &(conv_1x1_context->As())); - kernel->SetContext(std::move(conv_1x1_context)); - - const DDim& input_dim = - lite::DDim{std::vector({batch_size, ic, ih, iw})}; - - const DDim& filter_dim = - lite::DDim{std::vector({oc, ic, ksize, ksize})}; - const DDim& out_dim = - lite::DDim{std::vector({batch_size, oc, ih, iw})}; - // element wise bias - const DDim& bias_dim = lite::DDim{std::vector({oc})}; - - param.x->Resize(input_dim); - param.filter->Resize(filter_dim); - param.output->Resize(out_dim); - if (bias_flag) { - param.bias->Resize(bias_dim); - } +// int loop_cnt = 0; + +#ifdef LOOP_TEST + for (int batch_size = 1; batch_size < 4; ++batch_size) { + for (int oc = 4; oc < 10; oc += 1) { // oc + for (int ih = 4; ih < 9; ih += 1) { // ih + /*int iw = ih;*/ for (int iw = 4; iw < 10; iw += 1) { // iw + for (int ic = 4; ic < 10; ic += 1) { // ic + for (bool bias_flag : {true, false}) { + for (bool relu_flag : {true, false}) { +#else + const int batch_size = 1; + const int oc = 4; + const int ih = 8; + const int iw = 8; + const int ic = 4; + const bool bias_flag = false; + const bool relu_flag = false; +#endif + const int oh = ih; + const int ow = iw; + + VLOG(4) << "to get kernel ..."; + auto kernels = + KernelRegistry::Global().Create("conv2d_1x1", + TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault)); + ASSERT_FALSE(kernels.empty()); + + auto kernel = std::move(kernels.front()); + VLOG(4) << "created conv2d_1x1 kernel"; + + VLOG(4) << "prepare kernel ------"; + + lite::Tensor input, filter, bias, output; + operators::ConvParam param; + param.x = &input; + param.filter = &filter; + param.output = &output; + if (bias_flag) { + param.bias = &bias; + } + param.fuse_relu = relu_flag; + + std::vector paddings = {pad, pad, pad, pad}; + std::vector dilations = {dilation, dilation}; + + param.paddings = std::make_shared>(paddings); + param.dilations = std::make_shared>(dilations); + param.strides = std::vector{stride, stride}; + + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + + std::unique_ptr conv_1x1_context( + new KernelContext); + context->As().CopySharedTo( + &(conv_1x1_context->As())); + kernel->SetContext(std::move(conv_1x1_context)); + + const DDim& input_dim = + lite::DDim{std::vector({batch_size, ic, ih, iw})}; + + const DDim& filter_dim = + lite::DDim{std::vector({oc, ic, ksize, ksize})}; + const DDim& out_dim = + lite::DDim{std::vector({batch_size, oc, ih, iw})}; + // element wise bias + const DDim& bias_dim = lite::DDim{std::vector({oc})}; + + param.x->Resize(input_dim); + param.filter->Resize(filter_dim); + param.output->Resize(out_dim); + if (bias_flag) { + param.bias->Resize(bias_dim); + } - kernel->SetParam(param); - - size_t input_image_width = iw * ((ic + 3) / 4); - size_t input_image_height = ih * batch_size; - - size_t out_image_width = ow * ((oc + 3) / 4); - size_t out_image_height = oh * batch_size; - - size_t bias_image_width = ow * ((oc + 3) / 4); - size_t bias_image_height = oh * batch_size; - - size_t filter_image_width = ksize * ((oc + 3) / 4); - size_t filter_image_height = ic * ksize; - - auto* input_data = input.mutable_data(input_image_width, - input_image_height); - auto* filter_data = filter.mutable_data( - filter_image_width, filter_image_height); - bias.mutable_data(bias_image_width, bias_image_height); - auto* bias_data = bias.mutable_data(bias_image_width, - bias_image_height); - - const size_t cl_image2d_row_pitch{0}; - const size_t cl_image2d_slice_pitch{0}; - - LOG(INFO) << "map input ..."; - auto* mapped_input = - static_cast(TargetWrapperCL::MapImage(input_data, - input_image_width, - input_image_height, - cl_image2d_row_pitch, - cl_image2d_slice_pitch)); - - LOG(INFO) << "map filter ..."; - auto* mapped_filter = - static_cast(TargetWrapperCL::MapImage(filter_data, - filter_image_width, - filter_image_height, - cl_image2d_row_pitch, - cl_image2d_slice_pitch)); - - std::default_random_engine engine; - std::uniform_real_distribution gen(-5, 5); - std::vector input_v(batch_size * ic * ih * iw); - std::vector filter_v(oc * ic * ksize * ksize); - std::vector output_v(batch_size * oc * ih * iw); - std::vector bias_v(oc); - - float* input_v_data = &input_v[0]; - float* filter_v_data = &filter_v[0]; - float* output_v_data = &output_v[0]; - float* bias_v_data = &bias_v[0]; - - LOG(INFO) << "gen input and filter ..."; - - for (auto& i : input_v) { - i = gen(engine); - } - for (auto& f : filter_v) { - f = gen(engine); - } + kernel->SetParam(param); - LOG(INFO) << "after gen input and filter ..."; - LOG(INFO) << "input_v.size(): " << input_v.size(); - LOG(INFO) << "filter_v.size(): " << filter_v.size(); - LOG(INFO) << "output_v.size(): " << output_v.size(); - LOG(INFO) << "bias_v.size(): " << bias_v.size(); - LOG(INFO) << "input_dim.production(): " << input_dim.production(); - LOG(INFO) << "filter_dim.production(): " << filter_dim.production(); - LOG(INFO) << "out_dim.production(): " << out_dim.production(); - LOG(INFO) << "bias_dim.production(): " << bias_dim.production(); - LOG(INFO) << "4 * input_image_height * input_image_width: " - << 4 * input_image_height * input_image_width; - LOG(INFO) << "4 * filter_image_width * filter_image_height: " - << 4 * filter_image_width * filter_image_height; - - CHECK(input_dim.production() == input_v.size()); - CHECK_LE(input_dim.production(), 4 * input_image_height * input_image_width); - CHECK(filter_dim.production() == filter_v.size()); - CHECK_LE(filter_dim.production(), - 4 * filter_image_width * filter_image_height); - - paddle::lite::CLImageConverterDefault default_convertor; - LOG(INFO) << "set mapped input ..."; - default_convertor.NCHWToImage(input_v_data, mapped_input, input_dim); - LOG(INFO) << "set mapped filter ..."; - paddle::lite::CLImageConverterNWBlock nw_convertor; - nw_convertor.NCHWToImage(filter_v_data, mapped_filter, filter_dim); - - LOG(INFO) << "resize output ..."; - output.Resize(out_dim); - - // cpu conv basic calc - lite::Tensor out_ref; - out_ref.Resize(out_dim); - - float* mapped_bias = nullptr; - if (bias_flag) { - mapped_bias = - static_cast(TargetWrapperCL::MapImage(bias_data, - bias_image_width, - bias_image_height, - cl_image2d_row_pitch, - cl_image2d_slice_pitch)); - - for (int i = 0; i < bias_dim.production(); ++i) { - bias_v[i] = static_cast(gen(engine)); - } - CLImageConverterFolder folder_convertor; - folder_convertor.NCHWToImage(bias_v_data, mapped_bias, bias_dim); - } - LOG(INFO) << "prepare kernel ready"; - - LOG(INFO) << "kernel launch ..."; - kernel->Launch(); - LOG(INFO) << "mutable output ..."; - auto* output_data = output.mutable_data(out_image_width, - out_image_height); - - auto* wait_list = context->As().cl_wait_list(); - auto* out_ptr = param.output->data(); - auto it = wait_list->find(out_ptr); - - if (it != wait_list->end()) { - VLOG(4) << "--- Find the sync event for the target cl tensor. ---"; - auto& event = *(it->second); - event.wait(); - } else { - LOG(FATAL) << "Could not find the sync event for the target cl tensor."; - } + size_t input_image_width = iw * ((ic + 3) / 4); + size_t input_image_height = ih * batch_size; - auto* mapped_output = - static_cast(TargetWrapperCL::MapImage(output_data, - out_image_width, - out_image_height, - cl_image2d_row_pitch, - cl_image2d_slice_pitch)); - LOG(INFO) << "mutable_data out_ref_data: "; - - // run cpu ref - auto* out_ref_data = out_ref.mutable_data(TARGET(kARM)); - - LOG(INFO) << " conv_basic beigin ..... "; - - conv_basic(input_v_data, - out_ref_data, - batch_size, - oc, - oh, - ow, - ic, - ih, - iw, - filter_v_data, - bias_v_data, // mapped_bias, - group, - ksize, - ksize, - stride, - stride, - dilation, - dilation, - pad, - pad, - bias_flag, - relu_flag); - LOG(INFO) << " conv_basic end ..... "; - - LOG(INFO) << " out_dim: " << out_dim; - const DDim& out_image_dims = lite::DDim{ - std::vector({static_cast(out_image_width), - static_cast(out_image_height)})}; - default_convertor.ImageToNCHW( - mapped_output, output_v_data, out_image_dims, out_dim); - for (int i = 0; i < out_dim.production(); i++) { - EXPECT_NEAR(output_v_data[i], out_ref_data[i], 1e-3); - if (abs(output_v_data[i] - out_ref_data[i]) > 1e-3) { - LOG(FATAL) << "error idx:" << i; - } - } + size_t out_image_width = ow * ((oc + 3) / 4); + size_t out_image_height = oh * batch_size; + + size_t bias_image_width = ow * ((oc + 3) / 4); + size_t bias_image_height = oh * batch_size; + + size_t filter_image_width = ksize * ((oc + 3) / 4); + size_t filter_image_height = ic * ksize; + + const size_t cl_image2d_row_pitch{0}; + const size_t cl_image2d_slice_pitch{0}; - TargetWrapperCL::Unmap(output_data, mapped_output); - TargetWrapperCL::Unmap(filter_data, mapped_filter); - TargetWrapperCL::Unmap(input_data, mapped_input); - if (bias_flag) { - if (mapped_bias) { - TargetWrapperCL::Unmap(bias_data, mapped_bias); + std::default_random_engine engine; + std::uniform_real_distribution gen(-5, 5); + + std::vector input_v(batch_size * ic * ih * iw); + std::vector filter_v(oc * ic * ksize * ksize); + std::vector output_v(batch_size * oc * ih * iw); + std::vector bias_v(oc); + + VLOG(4) << "gen input and filter ..."; + + for (auto& i : input_v) { + i = gen(engine); + } + for (auto& f : filter_v) { + f = gen(engine); + } + + VLOG(4) << "after gen input and filter ..."; + VLOG(4) << "input_v.size(): " << input_v.size(); + VLOG(4) << "filter_v.size(): " << filter_v.size(); + VLOG(4) << "output_v.size(): " << output_v.size(); + VLOG(4) << "bias_v.size(): " << bias_v.size(); + VLOG(4) << "input_dim.production(): " << input_dim.production(); + VLOG(4) << "filter_dim.production(): " + << filter_dim.production(); + VLOG(4) << "out_dim.production(): " << out_dim.production(); + VLOG(4) << "bias_dim.production(): " << bias_dim.production(); + VLOG(4) << "4 * input_image_height * input_image_width: " + << 4 * input_image_height * input_image_width; + VLOG(4) << "4 * filter_image_width * filter_image_height: " + << 4 * filter_image_width * filter_image_height; + + CHECK(input_dim.production() == input_v.size()); + CHECK_LE(input_dim.production(), + 4 * input_image_height * input_image_width); + CHECK(filter_dim.production() == filter_v.size()); + CHECK_LE(filter_dim.production(), + 4 * filter_image_width * filter_image_height); + + paddle::lite::CLImageConverterDefault default_convertor; + VLOG(4) << "set mapped input ..."; + std::vector x_image_v( + input_image_width * input_image_height * 4); // 4 : RGBA + std::vector filter_image_v( + filter_image_width * filter_image_height * 4); // 4 : RGBA + std::vector bias_image_v( + bias_image_width * bias_image_height * 4); // 4 : RGBA + std::vector out_image_v( + out_image_width * out_image_height * 4); // 4 : RGBA + + default_convertor.NCHWToImage( + input_v.data(), x_image_v.data(), input_dim); + + /* for (int j = 0; j < input_v.size(); j += 1) { + // VLOG(4) << "input_v + input[" << j << "]: + // " << input_v.data()[j]; + std::cout << j << " " << input_v.data()[j] << + std::endl; + } + std::cout << std::endl; + + for (int j = 0; j < x_image_v.size(); j += 1) { + // VLOG(4) << "x_image_v + input[" << j << + // "]: " << + x_image_v.data()[j]; + std::cout << j << " " << x_image_v.data()[j] + << std::endl; + }*/ + + VLOG(4) << "set mapped filter ..."; + paddle::lite::CLImageConverterNWBlock nw_convertor; + nw_convertor.NCHWToImage( + filter_v.data(), filter_image_v.data(), filter_dim); + + auto* input_image2d = input.mutable_data( + input_image_width, input_image_height, x_image_v.data()); + auto* filter_image2d = filter.mutable_data( + filter_image_width, + filter_image_height, + filter_image_v.data()); + + if (bias_flag) { + nw_convertor.NCHWToImage( + filter_v.data(), filter_image_v.data(), filter_dim); + + for (int i = 0; i < bias_dim.production(); ++i) { + bias_v[i] = static_cast(gen(engine)); + } + CLImageConverterFolder folder_convertor; + folder_convertor.NCHWToImage( + bias_v.data(), bias_image_v.data(), bias_dim); + auto* bias_data = bias.mutable_data( + bias_image_width, bias_image_height, bias_image_v.data()); + } + + VLOG(4) << "resize output ..."; + output.Resize(out_dim); + + // cpu conv basic calc + lite::Tensor out_ref; + out_ref.Resize(out_dim); + + VLOG(4) << "prepare kernel ready"; + + VLOG(4) << "kernel launch ..."; + kernel->Launch(); + VLOG(4) << "mutable output ..."; + auto* output_image2d = output.mutable_data( + out_image_width, out_image_height); + + auto* wait_list = context->As().cl_wait_list(); + auto* out_ptr = param.output->data(); + auto it = wait_list->find(out_ptr); + + if (it != wait_list->end()) { + VLOG(4) << "--- Find the sync event for the target cl " + "tensor. ---"; + auto& event = *(it->second); + event.wait(); + } else { + LOG(FATAL) << "Could not find the sync event for the target " + "cl tensor."; + } + + TargetWrapperCL::ImgcpySync(out_image_v.data(), + output.data(), + out_image_width, + out_image_height, + cl_image2d_row_pitch, + cl_image2d_slice_pitch, + IoDirection::DtoH); + + DDim out_image_shape = + default_convertor.InitImageDimInfoWith(output.dims()); + + default_convertor.ImageToNCHW(out_image_v.data(), + output_v.data(), + out_image_shape, + output.dims()); + VLOG(4) << "mutable_data out_ref_data: "; + + // run cpu ref + auto* out_ref_data = out_ref.mutable_data(TARGET(kARM)); + + VLOG(4) << " conv_basic beigin ..... "; + + conv_basic(input_v.data(), + out_ref_data, + batch_size, + oc, + oh, + ow, + ic, + ih, + iw, + filter_v.data(), + bias_v.data(), // mapped_bias, + group, + ksize, + ksize, + stride, + stride, + dilation, + dilation, + pad, + pad, + bias_flag, + relu_flag); + VLOG(4) << " conv_basic end ..... "; + + VLOG(4) << " out_dim: " << out_dim; + const DDim& out_image_dims = lite::DDim{std::vector( + {static_cast(out_image_width), + static_cast(out_image_height)})}; + + for (int i = 0; i < out_dim.production(); i++) { + EXPECT_NEAR(output_v[i], out_ref_data[i], 1e-2); + if (abs(output_v[i] - out_ref_data[i]) > 1e-2) { + LOG(FATAL) << "error idx:" << i; + } + } + +#ifdef LOOP_TEST + } + } + } + } + } } } +#else +// nothing to do. +#endif } } // namespace lite -- GitLab