diff --git a/lite/backends/opencl/cl_image_converter.h b/lite/backends/opencl/cl_image_converter.h index 6faa8045576f06d8c636372de644e6b5c164a5f4..e318a0b86a88b3d42b8291c46e8b17a9f1128db4 100644 --- a/lite/backends/opencl/cl_image_converter.h +++ b/lite/backends/opencl/cl_image_converter.h @@ -103,6 +103,7 @@ class CLImageConverterNormal : public CLImageConverterBase { }; class CLImageConverterNWBlock : public CLImageConverterBase { + public: DDim InitImageDimInfoWith(const DDim &tensor_dim) override; void NCHWToImage(float *tensor, float *image, diff --git a/lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl b/lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl index c9c16581d67db0c9143e91e13249edfd5901ddb8..532f947dd342b1ee4db69a084111a97ec014237f 100644 --- a/lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl +++ b/lite/backends/opencl/cl_kernel/buffer/layout_kernel.cl @@ -61,6 +61,57 @@ __kernel void buffer_to_image2d(__global CL_DTYPE *in, write_imagef(output_image, output_pos, output); } +// buffer -> image2d_nw +__kernel void buffer_to_image2d_nw(__global CL_DTYPE* in, + __write_only image2d_t output_image, + __private const int out_H, + __private const int out_W, + __private const int out_N, + __private const int Stride0, + __private const int Stride1, + __private const int Stride2) { + const int out_n = get_global_id(0); + const int out_w = get_global_id(1); + const int out_ch = get_global_id(2); + + const int out_c = out_ch / out_H; + const int out_h = out_ch % out_H; + + const int in_c = out_c; // index of c in h direction + + const int in_n0 = out_n * 4 + 0; + const int in_n1 = out_n * 4 + 1; + const int in_n2 = out_n * 4 + 2; + const int in_n3 = out_n * 4 + 3; + + const int in_h = out_h; + const int in_w = out_w; + + int input_pos0 = in_n0 * Stride2 + in_c * Stride1 + in_h * Stride0 + in_w; + int input_pos1 = in_n1 * Stride2 + in_c * Stride1 + in_h * Stride0 + in_w; + int input_pos2 = in_n2 * Stride2 + in_c * Stride1 + in_h * Stride0 + in_w; + int input_pos3 = in_n3 * Stride2 + in_c * Stride1 + in_h * Stride0 + in_w; + + int2 output_pos; + output_pos.x = out_n * out_W + out_w; + output_pos.y = out_ch; + + CL_DTYPE4 output = (CL_DTYPE4)0.0f; + output.x = convert_float(in[input_pos0]); + if (out_N - 4 * out_n >= 2) { + output.y = convert_float(in[input_pos1]); + } + if (out_N - 4 * out_n >= 3) { + output.z = convert_float(in[input_pos2]); + } + if (out_N - 4 * out_n >= 4) { + output.w = convert_float(in[input_pos3]); + } + write_imagef(output_image, output_pos, output); +} + + + // image2d -> buffer __kernel void image2d_to_buffer(__read_only image2d_t input, __private const int in_width, diff --git a/lite/backends/opencl/cl_kernel/cl_common.h b/lite/backends/opencl/cl_kernel/cl_common.h index 7f901fc994ffd82ccfe99f59614a3422260d0dc5..815409eefdaa858dc2e2ddcc3efb8ebbf0d73ad2 100644 --- a/lite/backends/opencl/cl_kernel/cl_common.h +++ b/lite/backends/opencl/cl_kernel/cl_common.h @@ -61,3 +61,19 @@ inline CL_DTYPE activation(CL_DTYPE in #endif return output; } +inline CL_DTYPE4 activation_type4(CL_DTYPE4 in +#ifdef PRELU + , + CL_DTYPE4 prelu_alpha +#endif + ) { + CL_DTYPE4 output; +#ifdef PRELU + output = select(prelu_alpha * in, in, in >= (CL_DTYPE4)0.0); +#endif + +#ifdef RELU + output = fmax(in, (CL_DTYPE4)0); +#endif + return output; +} diff --git a/lite/backends/opencl/cl_kernel/image/conv_1x1_kernel.cl b/lite/backends/opencl/cl_kernel/image/conv_1x1_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..361d55531e2157c5070eaf50e302265933a93f73 --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/conv_1x1_kernel.cl @@ -0,0 +1,203 @@ +#include + +__kernel void conv_1x1( + __private const int global_size_dim0, + __private const int global_size_dim1, + __private const int global_size_dim2, + __read_only image2d_t input_image, + __read_only image2d_t filter, +#if defined(BIASE_CH) || defined(BIASE_ELE) + __read_only image2d_t bias, +#endif + #ifdef BATCH_NORM + __read_only image2d_t new_scale, __read_only image2d_t new_biase, + #endif + __write_only image2d_t output_image, + __private const int stride, + __private const int offset, + __private const int input_c, + __private const int input_c_origin, + __private const int dilation, + __private const int input_width, /* of one block */ + __private const int input_height, /* of one block */ + __private const int output_width, + __private const int output_height, + __private const int old_w) { + CL_DTYPE zero = 0.0f; + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + + int out_w0 = out_w; + int out_w1 = out_w + global_size_dim1; + int out_w2 = out_w + global_size_dim1 * 2; + int out_w3 = out_w + global_size_dim1 * 3; + + int outpos_main = mul24(out_c, old_w); + int2 output_pos0 = (int2)(outpos_main + out_w0, out_nh); + int2 output_pos1 = (int2)(outpos_main + out_w1, out_nh); + int2 output_pos2 = (int2)(outpos_main + out_w2, out_nh); + int2 output_pos3 = (int2)(outpos_main + out_w3, out_nh); + + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + int2 stride_xy = (int2)(stride, stride); + + int2 ouput_pos_in_one_block0 = (int2)(out_w0, out_nh); + int2 in_pos_in_one_block0 = + ouput_pos_in_one_block0 * stride_xy + (int2)(offset, offset); + + int2 ouput_pos_in_one_block1 = (int2)(out_w1, out_nh); + int2 in_pos_in_one_block1 = + ouput_pos_in_one_block1 * stride_xy + (int2)(offset, offset); + + int2 ouput_pos_in_one_block2 = (int2)(out_w2, out_nh); + int2 in_pos_in_one_block2 = + ouput_pos_in_one_block2 * stride_xy + (int2)(offset, offset); + + int2 ouput_pos_in_one_block3 = (int2)(out_w3, out_nh); + int2 in_pos_in_one_block3 = + ouput_pos_in_one_block3 * stride_xy + (int2)(offset, offset); + +#ifdef BIASE_CH + CL_DTYPE4 output0 = read_imagef(bias, sampler, (int2)(out_c, 0)); + CL_DTYPE4 output1 = output0; + CL_DTYPE4 output2 = output0; + CL_DTYPE4 output3 = output0; +#elif defined(BIASE_ELE) + CL_DTYPE4 output0 = read_imagef(bias, sampler, output_pos0); + CL_DTYPE4 output1 = output0; + CL_DTYPE4 output2 = output0; + CL_DTYPE4 output3 = output0; + +#else + CL_DTYPE4 output0 = 0.0f; + CL_DTYPE4 output1 = 0.0f; + CL_DTYPE4 output2 = 0.0f; + CL_DTYPE4 output3 = 0.0f; +#endif + + int max_w_bound = input_c * input_width; + int burndary_index = input_c * 4 - input_c_origin; + bool burndary_index_w = + burndary_index == 1 || burndary_index == 2 || burndary_index == 3; + bool burndary_index_z = burndary_index == 2 || burndary_index == 3; + bool burndary_index_y = burndary_index == 3; + + for (int i = 0; i < input_c; ++i) { + // ------------0--------------- + int2 pos_in = (int2)(i * input_width + in_pos_in_one_block0.x, + in_pos_in_one_block0.y); + CL_DTYPE4 input0 = read_imagef(input_image, sampler, pos_in); + + CL_DTYPE4 weight0 = read_imagef(filter, sampler, (int2)(out_c, i * 4 + 0)); + CL_DTYPE4 weight1 = read_imagef(filter, sampler, (int2)(out_c, i * 4 + 1)); + CL_DTYPE4 weight2 = read_imagef(filter, sampler, (int2)(out_c, i * 4 + 2)); + CL_DTYPE4 weight3 = read_imagef(filter, sampler, (int2)(out_c, i * 4 + 3)); + int bound_gap = max_w_bound - pos_in.x - 1; + + bool outof_bound = bound_gap < input_width && bound_gap >= 0; + input0.w = select(input0.w, zero, outof_bound && burndary_index_w); + input0.z = select(input0.z, zero, outof_bound && burndary_index_z); + input0.y = select(input0.y, zero, outof_bound && burndary_index_y); + + output0 = mad(input0.x, weight0, output0); + output0 = mad(input0.y, weight1, output0); + output0 = mad(input0.z, weight2, output0); + output0 = mad(input0.w, weight3, output0); + // -------------1-------------- + pos_in = (int2)(i * input_width + in_pos_in_one_block1.x, + in_pos_in_one_block1.y); + CL_DTYPE4 input1 = read_imagef(input_image, sampler, pos_in); + + bound_gap = max_w_bound - pos_in.x - 1; + + outof_bound = bound_gap < input_width && bound_gap >= 0; + input1.w = select(input1.w, zero, outof_bound && burndary_index_w); + input1.z = select(input1.z, zero, outof_bound && burndary_index_z); + input1.y = select(input1.y, zero, outof_bound && burndary_index_y); + + output1 = mad(input1.x, weight0, output1); + output1 = mad(input1.y, weight1, output1); + output1 = mad(input1.z, weight2, output1); + output1 = mad(input1.w, weight3, output1); + + // -------------2-------------- + pos_in = (int2)(i * input_width + in_pos_in_one_block2.x, + in_pos_in_one_block2.y); + CL_DTYPE4 input2 = read_imagef(input_image, sampler, pos_in); + + bound_gap = max_w_bound - pos_in.x - 1; + + outof_bound = bound_gap < input_width && bound_gap >= 0; + input2.w = select(input2.w, zero, outof_bound && burndary_index_w); + input2.z = select(input2.z, zero, outof_bound && burndary_index_z); + input2.y = select(input2.y, zero, outof_bound && burndary_index_y); + + output2 = mad(input2.x, weight0, output2); + output2 = mad(input2.y, weight1, output2); + output2 = mad(input2.z, weight2, output2); + output2 = mad(input2.w, weight3, output2); + + // -------------3-------------- + pos_in = (int2)(i * input_width + in_pos_in_one_block3.x, + in_pos_in_one_block3.y); + CL_DTYPE4 input3 = read_imagef(input_image, sampler, pos_in); + bound_gap = max_w_bound - pos_in.x - 1; + + outof_bound = bound_gap < input_width && bound_gap >= 0; + input3.w = + select(input3.w, + zero, + outof_bound && (burndary_index == 1 || burndary_index == 2 || + burndary_index == 3)); + input3.z = + select(input3.z, + zero, + outof_bound && (burndary_index == 2 || burndary_index == 3)); + input3.y = select(input3.y, zero, outof_bound && burndary_index == 3); + + output3 = mad(input3.x, weight0, output3); + output3 = mad(input3.y, weight1, output3); + output3 = mad(input3.z, weight2, output3); + output3 = mad(input3.w, weight3, output3); + } + +#ifdef BATCH_NORM + output0 = output0 * read_imagef(new_scale, sampler, (int2)(out_c, 0)) + + read_imagef(new_biase, sampler, (int2)(out_c, 0)); + + output1 = output1 * read_imagef(new_scale, sampler, (int2)(out_c, 0)) + + read_imagef(new_biase, sampler, (int2)(out_c, 0)); + + output2 = output2 * read_imagef(new_scale, sampler, (int2)(out_c, 0)) + + read_imagef(new_biase, sampler, (int2)(out_c, 0)); + + output3 = output3 * read_imagef(new_scale, sampler, (int2)(out_c, 0)) + + read_imagef(new_biase, sampler, (int2)(out_c, 0)); +#endif + +#ifdef RELU + output0 = activation_type4(output0); + output1 = activation_type4(output1); + output2 = activation_type4(output2); + output3 = activation_type4(output3); +#endif + + if (out_w0 < old_w) { + write_imagef(output_image, output_pos0, output0); + } + + if (out_w1 < old_w) { + write_imagef(output_image, output_pos1, output1); + } + + if (out_w2 < old_w) { + write_imagef(output_image, output_pos2, output2); + } + + if (out_w3 < old_w) { + write_imagef(output_image, output_pos3, output3); + } +} \ No newline at end of file diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt index 99b23c19f0f5870102782f0b4d639f6103257c31..ebdd51259719c37376f2449e49c812c52b8b01eb 100644 --- a/lite/kernels/opencl/CMakeLists.txt +++ b/lite/kernels/opencl/CMakeLists.txt @@ -14,6 +14,7 @@ add_kernel(pool_opencl OPENCL basic SRCS pool_compute.cc DEPS ${cl_kernel_deps}) add_kernel(io_copy_compute_opencl OPENCL basic SRCS io_copy_compute.cc DEPS ${tensor_lite} ${cl_kernel_deps}) add_kernel(relu_opencl OPENCL basic SRCS relu_compute.cc DEPS ${cl_kernel_deps}) add_kernel(depthwise_conv2d_opencl OPENCL basic SRCS depthwise_conv2d_compute.cc DEPS ${cl_kernel_deps}) +add_kernel(conv2d_1x1_opencl OPENCL basic SRCS conv2d_1x1_compute.cc DEPS ${cl_kernel_deps}) add_kernel(conv_opencl OPENCL basic SRCS conv_compute.cc DEPS ${cl_kernel_deps}) add_kernel(layout_opencl OPENCL basic SRCS layout_compute.cc DEPS ${cl_kernel_deps}) @@ -47,10 +48,14 @@ lite_cc_test(test_depthwise_conv2d_opencl SRCS depthwise_conv2d_compute_test.cc DEPS depthwise_conv2d_opencl op_registry program context ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) +lite_cc_test(test_conv2d_1x1_opencl SRCS conv2d_1x1_compute_test.cc + DEPS conv2d_1x1_opencl cl_image_converter op_registry program context + ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) + lite_cc_test(test_conv_opencl SRCS conv_compute_test.cc DEPS conv_opencl op_registry program context ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) lite_cc_test(test_layout_opencl SRCS layout_compute_test.cc - DEPS layout_opencl op_registry program context + DEPS layout_opencl op_registry program context cl_image_converter ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) diff --git a/lite/kernels/opencl/conv2d_1x1_compute.cc b/lite/kernels/opencl/conv2d_1x1_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..e02b37b09610b99298ce505df19c922d65763f2f --- /dev/null +++ b/lite/kernels/opencl/conv2d_1x1_compute.cc @@ -0,0 +1,294 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/backends/opencl/cl_include.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/kernels/opencl/image_helper.h" +#include "lite/operators/op_params.h" +#include "lite/utils/replace_stl/stream.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace opencl { + +#define USE_BUFFER_FOR_CONV1x1_BIAS +class Conv2d1x1Image2DCompute + : public KernelLite { + public: + using param_t = operators::ConvParam; + + void PrepareForRun() override { + LOG(INFO) << "PrepareForRun ..."; + const auto& param = *param_.get_mutable(); + if (param.fuse_relu) { + build_options_ += " -DRELU"; + } + + const bool has_bias = param.bias != nullptr; + const bool is_element_wise_bias = + has_bias && param.output->dims() == param.bias->dims(); + if (has_bias) { + build_options_ += is_element_wise_bias ? " -DBIASE_ELE" : " -DBIASE_CH"; + } + auto& context = ctx_->As(); + context.cl_context()->AddKernel( + kernel_func_name_, "image/conv_1x1_kernel.cl", build_options_); + LOG(INFO) << "PrepareForRun Ready"; + } + + void Run() override { + LOG(INFO) << "opencl conv 1x1 run begin ..."; + LOG(INFO) << "param_.get_mutable ..."; + const auto& param = *param_.get_mutable(); + + LOG(INFO) << "get param dims ..."; + auto input_dims = param.x->dims(); + CHECK_GE(input_dims.size(), 4); + LOG(INFO) << "input_dims: " << input_dims; + + int input_width = input_dims[3]; + int input_height = input_dims[2]; + + auto filter_dims = param.filter->dims(); + + LOG(INFO) << "filter_dims: " << filter_dims; + + auto output_dims = param.output->dims(); + + LOG(INFO) << "output_dims: " << output_dims; + + int output_width = output_dims[3]; + int output_height = output_dims[2]; + // mute output image + auto out_image_shape = InitImageDimInfoWith(output_dims); + + LOG(INFO) << "out_image_shape: " << out_image_shape["width"] << ", " + << out_image_shape["height"]; + + auto* out_image = param.output->mutable_data( + out_image_shape["width"], out_image_shape["height"]); + + // gen default_work_size + + const std::vector& default_work_size = + DefaultWorkSize(output_dims, + DDim(std::vector{ + static_cast(out_image_shape["width"]), + static_cast(out_image_shape["height"])})); + + int c_block = default_work_size[0]; + int w = default_work_size[1]; + int nh = default_work_size[2]; + LOG(INFO) << "default work size: " + << "{" << c_block << ", " << w << ", " << nh << "" + << "}"; + + auto paddings = *param.paddings; + LOG(INFO) << "paddings: " << paddings[0] << "," << paddings[1]; + + auto strides = param.strides; + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + auto* input_image = param.x->data(); + auto* filter_image = param.filter->data(); + + // handle bias use buffer for channel wise , use image for element wise + const bool has_bias = param.bias != nullptr; + const bool is_element_wise_bias = + has_bias && param.output->dims() == param.bias->dims(); + + LOG(INFO) << "has bias: " << has_bias; + LOG(INFO) << "is_element_wise_bias : " << is_element_wise_bias; + LOG(INFO) << "get kernel ..."; + const cl::Buffer* bias_buf = nullptr; + const cl::Image2D* bias_image = nullptr; + if (has_bias) { +#ifndef USE_BUFFER_FOR_CONV1x1_BIAS + is_element_wise_bias + ? (bias_image = param.bias->data()) + : (bias_buf = param.bias->data()); +#else + bias_image = param.bias->data(); +#endif + } + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_; + + LOG(INFO) << "kernel_key: " << kernel_key.str(); + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + + LOG(INFO) << "kernel ready ... " << kernel_key.str(); + cl_int status; + auto numel = output_dims.production(); + int arg_idx = 0; + + status = kernel.setArg(arg_idx, c_block); + CL_CHECK_FATAL(status); + + int maped_w = maptofactor(w, 4); + LOG(INFO) << "maped_w: " << maped_w; + + status = kernel.setArg(++arg_idx, maped_w); + CL_CHECK_FATAL(status); + + status = kernel.setArg(++arg_idx, nh); + LOG(INFO) << "nh: " << nh; + + CL_CHECK_FATAL(status); + + status = kernel.setArg(++arg_idx, *input_image); + LOG(INFO) << "input_image: "; + + CL_CHECK_FATAL(status); + + status = kernel.setArg(++arg_idx, *filter_image); + LOG(INFO) << "filter_image: "; + + CL_CHECK_FATAL(status); + + if (has_bias) { +#ifndef USE_BUFFER_FOR_CONV1x1_BIAS + if (is_element_wise_bias != 0) { + LOG(INFO) << "set bias_image: "; + status = kernel.setArg(++arg_idx, *bias_image); + } else { + LOG(INFO) << "set bias_buf: "; + status = kernel.setArg(++arg_idx, *bias_buf); + } +#else + status = kernel.setArg(++arg_idx, *bias_image); +#endif + CL_CHECK_FATAL(status); + } + + status = kernel.setArg(++arg_idx, *out_image); + LOG(INFO) << "out_image: "; + + CL_CHECK_FATAL(status); + + CHECK_GE(strides.size(), 2); + CHECK(strides[0] == strides[1]); + status = kernel.setArg(++arg_idx, strides[0]); + LOG(INFO) << "strides: " << strides[0] << "," << strides[1]; + + CL_CHECK_FATAL(status); + + CHECK_GE(paddings.size(), 2); + CHECK(paddings[0] == paddings[1]); + + int offset = static_cast(param.filter->dims()[2]) / 2 - + static_cast(paddings[0]); + LOG(INFO) << "offset: " << offset; + + status = kernel.setArg(++arg_idx, offset); + CL_CHECK_FATAL(status); + + // calc input_c_block + auto input_image_shape = InitImageDimInfoWith(input_dims); + LOG(INFO) << "input_image_shape: " << input_image_shape["width"] << "," + << input_image_shape["height"]; + + int input_c_block = input_image_shape["width"] / input_dims[3]; + LOG(INFO) << "input_c_block: " << input_c_block; + + status = kernel.setArg(++arg_idx, input_c_block); + CL_CHECK_FATAL(status); + + int input_c = input_dims[1]; + LOG(INFO) << "input_c: " << input_c; + + status = kernel.setArg(++arg_idx, input_c); + CL_CHECK_FATAL(status); + + auto dilations = *param.dilations; + LOG(INFO) << "dilations.size : " << dilations.size(); + LOG(INFO) << "dilations: " << dilations[0] << ", " << dilations[1]; + + CHECK_GE(dilations.size(), 2); + CHECK(dilations[0] == dilations[1]); + status = kernel.setArg(++arg_idx, dilations[0]); + CL_CHECK_FATAL(status); + + status = kernel.setArg(++arg_idx, input_width); + CL_CHECK_FATAL(status); + + status = kernel.setArg(++arg_idx, input_height); + CL_CHECK_FATAL(status); + + status = kernel.setArg(++arg_idx, output_width); + CL_CHECK_FATAL(status); + + status = kernel.setArg(++arg_idx, output_height); + CL_CHECK_FATAL(status); + + status = kernel.setArg(++arg_idx, w); + CL_CHECK_FATAL(status); + + // clac gloabl_work_size + auto global_work_size = + cl::NDRange{static_cast(default_work_size.data()[0]), + static_cast(maped_w), + static_cast(default_work_size.data()[2])}; + + LOG(INFO) << "global_work_size :" << global_work_size; + + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + context.cl_wait_list()->emplace(out_image, event_); + } + + private: + std::string kernel_func_name_{"conv_1x1"}; + std::string build_options_{"-DCL_DTYPE=float "}; + std::shared_ptr event_{new cl::Event}; +}; + +} // namespace opencl +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(conv2d_1x1, + kOpenCL, + kFloat, + kNHWC, + paddle::lite::kernels::opencl::Conv2d1x1Image2DCompute, + image2d) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault))}) + .BindInput("Bias", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault))}) + .BindInput("Filter", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageNW))}) + .BindOutput("Output", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageDefault))}) + .Finalize(); diff --git a/lite/kernels/opencl/conv2d_1x1_compute_test.cc b/lite/kernels/opencl/conv2d_1x1_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..591e9ad795b96c832a5b169570b4773646276695 --- /dev/null +++ b/lite/kernels/opencl/conv2d_1x1_compute_test.cc @@ -0,0 +1,376 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "lite/backends/opencl/cl_image_converter.h" +#include "lite/backends/opencl/target_wrapper.h" + +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { + +template +static void conv_basic(const Dtype1* din, + Dtype2* dout, + int num, + int chout, + int hout, + int wout, + int chin, + int hin, + int win, + const Dtype1* weights, + const Dtype2* bias, + int group, + int kernel_w, + int kernel_h, + int stride_w, + int stride_h, + int dila_w, + int dila_h, + int pad_w, + int pad_h, + bool flag_bias, + bool flag_relu) { + Dtype2 beta = 0; + auto src_data = din; + auto dst_data_ref = dout; + auto weights_data = weights; + auto with_bias = flag_bias; + auto bias_data = bias; + + int in_num = num; + int out_channels = chout; + int out_h = hout; + int out_w = wout; + + int in_channel = chin; + int in_h = hin; + int in_w = win; + int out_c_group = out_channels / group; + int in_c_group = in_channel / group; + + for (int n = 0; n < in_num; ++n) { + for (int g = 0; g < group; ++g) { + for (int oc = 0; oc < out_c_group; ++oc) { + for (int oh = 0; oh < out_h; ++oh) { + for (int ow = 0; ow < out_w; ++ow) { + int out_idx = n * group * out_c_group * out_h * out_w + + g * out_c_group * out_h * out_w + oc * out_h * out_w + + oh * out_w + ow; + Dtype2 bias_d = + with_bias ? (bias_data[g * out_c_group + oc]) : (Dtype2)0; + dst_data_ref[out_idx] = bias_d; // + dst_data_ref[out_idx] * beta; + for (int ic = 0; ic < in_c_group; ++ic) { + for (int kh = 0; kh < kernel_h; ++kh) { + for (int kw = 0; kw < kernel_w; ++kw) { + int iw = ow * stride_w - pad_w + kw * (dila_w); + int ih = oh * stride_h - pad_h + kh * (dila_h); + if (iw < 0 || iw >= in_w) continue; + if (ih < 0 || ih >= in_h) continue; + + int iidx = n * in_channel * in_h * in_w + + g * in_c_group * in_h * in_w + ic * in_h * in_w + + ih * in_w + iw; + int widx = + g * out_c_group * in_c_group * kernel_h * kernel_w + + oc * in_c_group * kernel_h * kernel_w + + ic * kernel_h * kernel_w + kh * kernel_w + kw; + + dst_data_ref[out_idx] += src_data[iidx] * weights_data[widx]; + } + } + } + if (flag_relu) { + dst_data_ref[out_idx] = dst_data_ref[out_idx] > (Dtype2)0 + ? dst_data_ref[out_idx] + : (Dtype2)0; + } + } + } + } + } + } +} + +TEST(conv2d_1x1, compute) { + // conv infos + const int ksize = 1; + const int stride = 1; + const int pad = 0; + const int group = 1; + const int dilation = 0; + // int loop_cnt = 0; + + const bool bias_flag = true; + const bool relu_flag = true; + const int batch_size = 8; + const int oc = 64; + const int ih = 28; + const int iw = 28; + const int ic = 63; + + const int oh = ih; + const int ow = iw; + + LOG(INFO) << "to get kernel ..."; + auto kernels = KernelRegistry::Global().Create( + "conv2d_1x1", TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNHWC)); + ASSERT_FALSE(kernels.empty()); + + auto kernel = std::move(kernels.front()); + LOG(INFO) << "created conv2d_1x1 kernel"; + + LOG(INFO) << "prepare kernel ------"; + + lite::Tensor input, filter, bias, output; + operators::ConvParam param; + param.x = &input; + param.filter = &filter; + param.output = &output; + if (bias_flag) { + param.bias = &bias; + } + param.fuse_relu = relu_flag; + + std::vector paddings = {pad, pad, pad, pad}; + std::vector dilations = {dilation, dilation}; + + param.paddings = std::make_shared>(paddings); + param.dilations = std::make_shared>(dilations); + param.strides = std::vector{stride, stride}; + + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + + std::unique_ptr conv_1x1_context(new KernelContext); + context->As().CopySharedTo( + &(conv_1x1_context->As())); + kernel->SetContext(std::move(conv_1x1_context)); + + const DDim& input_dim = + lite::DDim{std::vector({batch_size, ic, ih, iw})}; + + const DDim& filter_dim = + lite::DDim{std::vector({oc, ic, ksize, ksize})}; + const DDim& out_dim = + lite::DDim{std::vector({batch_size, oc, ih, iw})}; + // element wise bias + const DDim& bias_dim = lite::DDim{std::vector({oc})}; + + param.x->Resize(input_dim); + param.filter->Resize(filter_dim); + param.output->Resize(out_dim); + if (bias_flag) { + param.bias->Resize(bias_dim); + } + + kernel->SetParam(param); + + size_t input_image_width = iw * ((ic + 3) / 4); + size_t input_image_height = ih * batch_size; + + size_t out_image_width = ow * ((oc + 3) / 4); + size_t out_image_height = oh * batch_size; + + size_t bias_image_width = ow * ((oc + 3) / 4); + size_t bias_image_height = oh * batch_size; + + size_t filter_image_width = ksize * ((oc + 3) / 4); + size_t filter_image_height = ic * ksize; + + auto* input_data = input.mutable_data(input_image_width, + input_image_height); + auto* filter_data = filter.mutable_data( + filter_image_width, filter_image_height); + bias.mutable_data(bias_image_width, bias_image_height); + auto* bias_data = bias.mutable_data(bias_image_width, + bias_image_height); + + const size_t cl_image2d_row_pitch{0}; + const size_t cl_image2d_slice_pitch{0}; + + LOG(INFO) << "map input ..."; + auto* mapped_input = + static_cast(TargetWrapperCL::MapImage(input_data, + input_image_width, + input_image_height, + cl_image2d_row_pitch, + cl_image2d_slice_pitch)); + + LOG(INFO) << "map filter ..."; + auto* mapped_filter = + static_cast(TargetWrapperCL::MapImage(filter_data, + filter_image_width, + filter_image_height, + cl_image2d_row_pitch, + cl_image2d_slice_pitch)); + + std::default_random_engine engine; + std::uniform_real_distribution gen(-5, 5); + std::vector input_v(batch_size * ic * ih * iw); + std::vector filter_v(oc * ic * ksize * ksize); + std::vector output_v(batch_size * oc * ih * iw); + std::vector bias_v(oc); + + float* input_v_data = &input_v[0]; + float* filter_v_data = &filter_v[0]; + float* output_v_data = &output_v[0]; + float* bias_v_data = &bias_v[0]; + + LOG(INFO) << "gen input and filter ..."; + + for (auto& i : input_v) { + i = gen(engine); + } + for (auto& f : filter_v) { + f = gen(engine); + } + + LOG(INFO) << "after gen input and filter ..."; + LOG(INFO) << "input_v.size(): " << input_v.size(); + LOG(INFO) << "filter_v.size(): " << filter_v.size(); + LOG(INFO) << "output_v.size(): " << output_v.size(); + LOG(INFO) << "bias_v.size(): " << bias_v.size(); + LOG(INFO) << "input_dim.production(): " << input_dim.production(); + LOG(INFO) << "filter_dim.production(): " << filter_dim.production(); + LOG(INFO) << "out_dim.production(): " << out_dim.production(); + LOG(INFO) << "bias_dim.production(): " << bias_dim.production(); + LOG(INFO) << "4 * input_image_height * input_image_width: " + << 4 * input_image_height * input_image_width; + LOG(INFO) << "4 * filter_image_width * filter_image_height: " + << 4 * filter_image_width * filter_image_height; + + CHECK(input_dim.production() == input_v.size()); + CHECK_LE(input_dim.production(), 4 * input_image_height * input_image_width); + CHECK(filter_dim.production() == filter_v.size()); + CHECK_LE(filter_dim.production(), + 4 * filter_image_width * filter_image_height); + + paddle::lite::CLImageConverterDefault default_convertor; + LOG(INFO) << "set mapped input ..."; + default_convertor.NCHWToImage(input_v_data, mapped_input, input_dim); + LOG(INFO) << "set mapped filter ..."; + paddle::lite::CLImageConverterNWBlock nw_convertor; + nw_convertor.NCHWToImage(filter_v_data, mapped_filter, filter_dim); + + LOG(INFO) << "resize output ..."; + output.Resize(out_dim); + + // cpu conv basic calc + lite::Tensor out_ref; + out_ref.Resize(out_dim); + + float* mapped_bias = nullptr; + if (bias_flag) { + mapped_bias = + static_cast(TargetWrapperCL::MapImage(bias_data, + bias_image_width, + bias_image_height, + cl_image2d_row_pitch, + cl_image2d_slice_pitch)); + + for (int i = 0; i < bias_dim.production(); ++i) { + bias_v[i] = static_cast(gen(engine)); + } + CLImageConverterFolder folder_convertor; + folder_convertor.NCHWToImage(bias_v_data, mapped_bias, bias_dim); + } + LOG(INFO) << "prepare kernel ready"; + + LOG(INFO) << "kernel launch ..."; + kernel->Launch(); + LOG(INFO) << "mutable output ..."; + auto* output_data = output.mutable_data(out_image_width, + out_image_height); + + auto* wait_list = context->As().cl_wait_list(); + auto* out_ptr = param.output->data(); + auto it = wait_list->find(out_ptr); + + if (it != wait_list->end()) { + VLOG(4) << "--- Find the sync event for the target cl tensor. ---"; + auto& event = *(it->second); + event.wait(); + } else { + LOG(FATAL) << "Could not find the sync event for the target cl tensor."; + } + + auto* mapped_output = + static_cast(TargetWrapperCL::MapImage(output_data, + out_image_width, + out_image_height, + cl_image2d_row_pitch, + cl_image2d_slice_pitch)); + LOG(INFO) << "mutable_data out_ref_data: "; + + // run cpu ref + auto* out_ref_data = out_ref.mutable_data(TARGET(kARM)); + + LOG(INFO) << " conv_basic beigin ..... "; + + conv_basic(input_v_data, + out_ref_data, + batch_size, + oc, + oh, + ow, + ic, + ih, + iw, + filter_v_data, + bias_v_data, // mapped_bias, + group, + ksize, + ksize, + stride, + stride, + dilation, + dilation, + pad, + pad, + bias_flag, + relu_flag); + LOG(INFO) << " conv_basic end ..... "; + + LOG(INFO) << " out_dim: " << out_dim; + const DDim& out_image_dims = lite::DDim{ + std::vector({static_cast(out_image_width), + static_cast(out_image_height)})}; + default_convertor.ImageToNCHW( + mapped_output, output_v_data, out_image_dims, out_dim); + for (int i = 0; i < out_dim.production(); i++) { + EXPECT_NEAR(output_v_data[i], out_ref_data[i], 1e-3); + if (abs(output_v_data[i] - out_ref_data[i]) > 1e-3) { + LOG(FATAL) << "error idx:" << i; + } + } + + TargetWrapperCL::Unmap(output_data, mapped_output); + TargetWrapperCL::Unmap(filter_data, mapped_filter); + TargetWrapperCL::Unmap(input_data, mapped_input); + if (bias_flag) { + if (mapped_bias) { + TargetWrapperCL::Unmap(bias_data, mapped_bias); + } + } +} + +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(conv2d_1x1, kOpenCL, kFloat, kNHWC, image2d); diff --git a/lite/kernels/opencl/image_helper.h b/lite/kernels/opencl/image_helper.h index d164f1ef777a02e5fd3bd33f5cab117de17834b8..d0d282250d1c5658bc8f684b52b4b0d140895833 100644 --- a/lite/kernels/opencl/image_helper.h +++ b/lite/kernels/opencl/image_helper.h @@ -40,6 +40,39 @@ static std::map InitImageDimInfoWith( size_t height = H * N; return std::map({{"width", width}, {"height", height}}); } +inline static int maptofactor(int i, int factor) { + return (i + factor - 1) / factor; +} + +static std::vector DefaultWorkSize(const DDim& image_dim, + const DDim& image_shape) { + // n c h w + // auto image_dim = image.dims(); + if (image_dim.size() == 4) { + auto n = image_dim[0]; + auto h = image_dim[2]; + auto w = image_dim[3]; + auto image_width = image_shape[0]; + size_t work_size_0 = image_width / w; + size_t work_size_1 = w; + size_t work_size_2 = n * h; + return {work_size_0, work_size_1, work_size_2}; + } else if (image_dim.size() == 2) { + auto h = image_dim[0]; + auto w = image_dim[1]; + return {1, + static_cast(image_shape[0]), + static_cast(image_shape[1])}; + } else if (image_dim.size() == 1) { + return {1, static_cast(image_shape[0]), 1}; + } else if (image_dim.size() == 3) { + size_t c = image_dim[0]; + size_t h = image_dim[1]; + size_t w = image_dim[2]; + return {(c + 3) / 4, w, h}; + } + LOG(FATAL) << " not support this dim, need imp "; +} } // namespace opencl } // namespace kernels diff --git a/lite/kernels/opencl/layout_compute.cc b/lite/kernels/opencl/layout_compute.cc index a2869457fc8dabdfb39d3d447404c0a6f6f77375..e2e1530ba62010fdb930ccdf852cf2fc2ebc39a5 100644 --- a/lite/kernels/opencl/layout_compute.cc +++ b/lite/kernels/opencl/layout_compute.cc @@ -126,6 +126,102 @@ class LayoutComputeBufferChwToImage2DHwc std::shared_ptr event_{new cl::Event}; }; +// buffer chw 2 image2d nw +class LayoutComputeBufferChwToImage2DNw + : public KernelLite { + public: + using param_t = operators::LayoutParam; + + void PrepareForRun() override { + auto& context = ctx_->As(); + context.cl_context()->AddKernel( + kernel_func_name_, "buffer/layout_kernel.cl", build_options_); + } + + void Run() override { + auto& param = Param(); + auto* x_data = param.x->data(); + auto x_dims = param.x->dims(); + + CHECK(x_dims.size() == 4) << " Tensor dim is not 4."; + size_t image_width = x_dims[3] * ((x_dims[0] + 3) / 4); + size_t image_height = x_dims[1] * x_dims[2]; + + auto* y_data = + param.y->mutable_data(image_width, image_height); + auto y_dims = param.y->dims(); + + // out info + std::vector new_dims = {1, 1, 1, 1}; + for (int tidx = 0; tidx < x_dims.size(); ++tidx) { + new_dims[4 - x_dims.size() + tidx] = x_dims[tidx]; + } + + const int out_N = new_dims[0]; + const int out_C = new_dims[1]; + const int out_H = new_dims[2]; + const int out_W = new_dims[3]; + + const int Stride2 = out_C * out_H * out_W; + const int Stride1 = out_H * out_W; + const int Stride0 = out_W; + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + + int arg_idx = 0; + cl_int status = kernel.setArg(arg_idx, *x_data); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *y_data); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(out_H)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(out_W)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(out_N)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(Stride0)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(Stride1)); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, static_cast(Stride2)); + CL_CHECK_FATAL(status); + + VLOG(4) << "gws:[3D]" << ((out_N + 3) / 4) << " " << out_W << " " + << (out_C * out_H); + auto global_work_size = + cl::NDRange{static_cast((out_N + 3) / 4), // N blocks + static_cast(out_W), // w + static_cast(out_C * out_H)}; // ch + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + // TODO(ysh329): io_copy(device->host) jammed if emplace to `cl_wait_list` + // context.cl_wait_list()->emplace(y_data, event_); + context.cl_context()->GetCommandQueue().finish(); + // auto image_shape = InitImageDimInfoWith(x_dims); + } + + std::string doc() const override { + return "Trans Layout from cl::Buffer(NCHW) to cl::Image2D(CLNW)"; + } + + private: + std::string kernel_func_name_{"buffer_to_image2d_nw"}; + std::string build_options_{"-DCL_DTYPE_float "}; + std::shared_ptr event_{new cl::Event}; +}; + class LayoutComputeImage2DHwcToBufferChw : public KernelLite { public: @@ -293,3 +389,21 @@ REGISTER_LITE_KERNEL( PRECISION(kAny), DATALAYOUT(kNCHW))}) .Finalize(); + +// [hwc] -> [chw] +REGISTER_LITE_KERNEL( + layout_once, + kOpenCL, + kFloat, + kImageNW, + paddle::lite::kernels::opencl::LayoutComputeBufferChwToImage2DNw, + buffer_chw_to_image2d_nw_opencl_fp32) + .BindInput("Input", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageNW))}) + .Finalize(); diff --git a/lite/kernels/opencl/layout_compute_test.cc b/lite/kernels/opencl/layout_compute_test.cc index 3e8dd78f616d4d1e3fabf51ba8d3ddf43dd561f1..3968e23d6be10bf050bcfc478d278398bf16fd7e 100644 --- a/lite/kernels/opencl/layout_compute_test.cc +++ b/lite/kernels/opencl/layout_compute_test.cc @@ -144,7 +144,141 @@ TEST(layout, compute) { // nothing to do. #endif } +TEST(layout, compute_buffer2image2dnw) { +#ifdef LOOP_TEST + for (int n = 1; n <= 100; n += 21) { + for (auto c : {1, 3}) { + for (int h = 1; h <= 100; h += 13) { + for (int w = 1; w <= 100; w += 17) { +#else + const int n = 1; + const int c = 1; + const int h = 1; + const int w = 100; +#endif // LOOP_TEST + + LOG(INFO) << "======== input shape[n,c,h,w]:" << n << " " << c << " " + << h << " " << w << " ========"; + // set layout kernels + auto buf_to_img_nw_kernels = + KernelRegistry::Global().Create("layout_once", + TARGET(kOpenCL), + PRECISION(kFloat), + DATALAYOUT(kImageNW)); + ASSERT_FALSE(buf_to_img_nw_kernels.empty()); + auto buf_to_img_nw_kernel = std::move(buf_to_img_nw_kernels.front()); + LOG(INFO) << "get 1st kernel: " << buf_to_img_nw_kernel->doc(); + + // set tensors about op param + operators::LayoutParam bufferToImageParam; + lite::Tensor x, y, cpu_y; + bufferToImageParam.x = &x; + bufferToImageParam.y = &y; + + const DDim x_dim = DDim(std::vector{n, c, h, w}); + x.Resize(x_dim); + y.Resize(x_dim); // useless for image2D + cpu_y.Resize(x_dim); + + // initialize tensors + LOG(INFO) << "initialize tensors"; + + // mute in buffer + auto* x_data = x.mutable_data(TARGET(kOpenCL)); + // mute out image nw + size_t image_width = w * ((n + 3) / 4); + size_t image_height = c * h; + auto* y_data = + y.mutable_data(image_width, image_height); + auto* cpu_y_data = + cpu_y.mutable_data(image_width, image_height); + + auto* mapped_x = static_cast(TargetWrapperCL::Map( + x_data, 0, sizeof(float) * x_dim.production())); + + const size_t cl_image2d_row_pitch{0}; + const size_t cl_image2d_slice_pitch{0}; + + auto* mapped_y = static_cast( + TargetWrapperCL::MapImage(y_data, + image_width, + image_height, + cl_image2d_row_pitch, + cl_image2d_slice_pitch)); + + auto* mapped_cpu_y = static_cast( + TargetWrapperCL::MapImage(cpu_y_data, + image_width, + image_height, + cl_image2d_row_pitch, + cl_image2d_slice_pitch)); + + // random datas + std::default_random_engine engine; + std::uniform_real_distribution gen(-5, 5); + + for (int i = 0; i < x_dim.production(); ++i) { + mapped_x[i] = gen(engine); + } + + // gen cpu y_data + CLImageConverterNWBlock nw_converter; + nw_converter.NCHWToImage(mapped_x, mapped_cpu_y, x_dim); + + // set context and kernel args + LOG(INFO) << "set context and kernel args"; + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + + // set kernel params + buf_to_img_nw_kernel->SetParam(bufferToImageParam); + + std::unique_ptr buf_to_img_context(new KernelContext); + context->As().CopySharedTo( + &(buf_to_img_context->As())); + + // set context + buf_to_img_nw_kernel->SetContext(std::move(buf_to_img_context)); + + // run kernels + LOG(INFO) << "run kernel: buf_to_img_kernel"; + buf_to_img_nw_kernel->Launch(); + +// result +#ifdef PRINT_RESULT + LOG(INFO) << "---- print result ----"; + for (int eidx = 0; i < x_dim.production(); ++eidx) { + std::cout << mapped_x[eidx] << " -> " << mapped_y[eidx] + << std::endl; + } +#endif // PRINT_RESULT + // check result: compare input and output + for (int eidx = 0; eidx < x_dim.production(); eidx++) { + EXPECT_NEAR(mapped_cpu_y[eidx], mapped_y[eidx], 1e-3); + if (abs(mapped_cpu_y[eidx] - mapped_y[eidx]) > 1e-3) { + LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx + << " / " << x_dim.production() << ", mapped_x[" << eidx + << "]:" << mapped_cpu_y[eidx] << ", mapped_y[" << eidx + << "]:" << mapped_y[eidx]; + break; + } + } + + // free + LOG(INFO) << "free: unmap x, y"; + TargetWrapperCL::Unmap(x_data, mapped_x); + TargetWrapperCL::Unmap(y_data, mapped_y); + +#ifdef LOOP_TEST + } // w + } // h + } // c + } // n +#else +// nothing to do. +#endif +} } // namespace lite } // namespace paddle @@ -152,3 +286,8 @@ USE_LITE_KERNEL( layout, kOpenCL, kAny, kNHWC, buffer_chw_to_image2d_hwc_opencl_fp32); USE_LITE_KERNEL( layout, kOpenCL, kAny, kNCHW, image2d_hwc_to_buffer_chw_opencl_fp32); +USE_LITE_KERNEL(layout_once, + kOpenCL, + kFloat, + kImageNW, + buffer_chw_to_image2d_nw_opencl_fp32);