diff --git a/lite/backends/opencl/cl_kernel/image/pad2d_kernel.cl b/lite/backends/opencl/cl_kernel/image/pad2d_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..e65aad3d19bc674aff2f71d2403e611cd247abf1 --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/pad2d_kernel.cl @@ -0,0 +1,108 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +__kernel void pad2d_constant( + __read_only image2d_t input, __write_only image2d_t output, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_h0, const int pad_h1, + const int pad_w0, const int pad_w1, + const float pad_value) { + + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + const int out_n = out_nh / out_height; + const int out_h = out_nh % out_height; + + int2 output_pos = (int2)(mad24(out_c, out_width, out_w), out_nh); + + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + int x = out_w - pad_w0; + int y = out_h - pad_h0; + + if (x < 0 || y < 0 || x >= in_width || y >= in_height) { + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, output_pos, (CL_DTYPE4)(pad_value)); + } else { + int2 coor = (int2)(out_c * in_width + x, out_n * in_height + y); + CL_DTYPE4 pixel = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coor); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, output_pos, pixel); + } +} + +__kernel void pad2d_reflect( + __read_only image2d_t input, __write_only image2d_t output, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_h0, const int pad_h1, + const int pad_w0, const int pad_w1, + const float pad_value) { + + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + const int out_n = out_nh / out_height; + const int out_h = out_nh % out_height; + + int2 output_pos = (int2)(mad24(out_c, out_width, out_w), out_nh); + + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + int x = out_w - pad_w0; + int y = out_h - pad_h0; + + x = abs(x); + y = abs(y); + x = x < in_width ? x : 2 * in_width - 2 - x; + y = y < in_height ? y : 2 * in_height - 2 - y; + int2 coor = (int2)(out_c * in_width + x, out_n * in_height + y); + CL_DTYPE4 pixel = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coor); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, output_pos, pixel); +} + +__kernel void pad2d_edge( + __read_only image2d_t input, __write_only image2d_t output, + const int in_height, const int in_width, + const int out_height, const int out_width, + const int pad_h0, const int pad_h1, + const int pad_w0, const int pad_w1, + const float pad_value) { + + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + const int out_n = out_nh / out_height; + const int out_h = out_nh % out_height; + + int2 output_pos = (int2)(mad24(out_c, out_width, out_w), out_nh); + + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + int x = out_w - pad_w0; + int y = out_h - pad_h0; + + x = x > 0 ? x : 0; + x = x < in_width ? x : in_width - 1; + y = y > 0 ? y : 0; + y = y < in_height ? y : in_height - 1; + int2 coor = (int2)(out_c * in_width + x, out_n * in_height + y); + CL_DTYPE4 pixel = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coor); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, output_pos, pixel); +} diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt index 2cf0aae42009a8a92703d6690a61ac8a2296e290..25afb2fc399c6a4da8775440c1602031061267f7 100644 --- a/lite/kernels/opencl/CMakeLists.txt +++ b/lite/kernels/opencl/CMakeLists.txt @@ -32,6 +32,7 @@ add_kernel(bilinear_interp_opencl OPENCL basic SRCS bilinear_interp_image_comput add_kernel(slice_opencl OPENCL basic SRCS slice_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(instance_norm_opencl OPENCL basic SRCS instance_norm_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(dropout_opencl OPENCL basic SRCS dropout_image_compute.cc DEPS ${cl_kernel_deps}) +add_kernel(pad2d_opencl OPENCL basic SRCS pad2d_image_compute.cc DEPS ${cl_kernel_deps}) # extra # wait to add ... @@ -92,7 +93,10 @@ lite_cc_test(test_instance_norm_image_opencl SRCS instance_norm_image_compute_te DEPS instance_norm_opencl op_registry program context) lite_cc_test(test_dropout_image_opencl SRCS dropout_image_compute_test.cc - DEPS dropout_opencl op_registry program context) + DEPS dropout_opencl op_registry program context) + +lite_cc_test(test_pad2d_image_opencl SRCS pad2d_image_compute_test.cc + DEPS pad2d_opencl layout_opencl op_registry program context) ###################### # buffer kernel # ###################### diff --git a/lite/kernels/opencl/pad2d_image_compute.cc b/lite/kernels/opencl/pad2d_image_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..7f4838149d1e2364baf0b1b2286fef4a74ee9a4b --- /dev/null +++ b/lite/kernels/opencl/pad2d_image_compute.cc @@ -0,0 +1,175 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "lite/backends/opencl/cl_half.h" +#include "lite/backends/opencl/cl_image_converter.h" +#include "lite/backends/opencl/cl_include.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/kernels/opencl/image_helper.h" +#include "lite/operators/op_params.h" +#include "lite/utils/logging.h" +#include "lite/utils/replace_stl/stream.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace opencl { +class Pad2dCompute : public KernelLite { + public: + using param_t = operators::Pad2dParam; + + std::string doc() const override { + return "Pad2d using cl::Image2D(ImageDefault/RGBA), kFP16"; + } + + void PrepareForRun() override { + pad2d_param_ = param_.get_mutable(); + + if (pad2d_param_->mode == "constant") { + kernel_func_name_ = "pad2d_constant"; + } else if (pad2d_param_->mode == "reflect") { + kernel_func_name_ = "pad2d_reflect"; + } else if (pad2d_param_->mode == "edge") { + kernel_func_name_ = "pad2d_edge"; + } else { + LOG(FATAL) << "Unknown mode type"; + } + + auto& context = ctx_->As(); + context.cl_context()->AddKernel( + kernel_func_name_, "image/pad2d_kernel.cl", build_options_); + VLOG(1) << "kernel_func_name_:" << kernel_func_name_; + } + + void Run() override { + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + + auto* x = pad2d_param_->X; + auto* out = pad2d_param_->Out; + auto out_dims = out->dims(); + auto in_dims = x->dims(); + + int in_h = in_dims[2]; + int in_w = in_dims[3]; + int out_h = out_dims[2]; + int out_w = out_dims[3]; + + VLOG(4) << "x->target():" << TargetToStr(x->target()); + VLOG(4) << "out->target():" << TargetToStr(out->target()); + VLOG(4) << "x->dims():" << in_dims; + VLOG(4) << "out->dims():" << out_dims; + + auto out_image_shape = InitImageDimInfoWith(out_dims); + auto* x_img = x->data(); + + auto* out_img = out->mutable_data( + out_image_shape["width"], out_image_shape["height"]); + + VLOG(4) << "out_image_shape[w,h]: " << out_image_shape["width"] << " " + << out_image_shape["height"]; + + VLOG(4) << "in_h: " << in_h << ", in_w: " << in_w; + VLOG(4) << "out_h: " << out_h << ", out_w: " << out_w; + + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + + int arg_idx = 0; + auto default_work_size = + DefaultWorkSize(out_dims, + DDim(std::vector{ + static_cast(out_image_shape["width"]), + static_cast(out_image_shape["height"])})); + VLOG(4) << "default_work_size: " << default_work_size[0] << ", " + << default_work_size[1] << ", " << default_work_size[2]; + + int pad_h0 = pad2d_param_->paddings[0]; + int pad_h1 = pad2d_param_->paddings[1]; + int pad_w0 = pad2d_param_->paddings[2]; + int pad_w1 = pad2d_param_->paddings[3]; + float pad_value = pad2d_param_->pad_value; + + cl_int status = kernel.setArg(arg_idx++, *x_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, *out_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, in_h); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, in_w); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, out_h); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, out_w); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, pad_h0); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, pad_h1); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, pad_w0); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, pad_w1); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, pad_value); + CL_CHECK_FATAL(status); + + auto global_work_size = + cl::NDRange{static_cast(default_work_size[0]), + static_cast(default_work_size[1]), + static_cast(default_work_size[2])}; + + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + context.cl_wait_list()->emplace(out_img, event_); + + VLOG(4) << "global_work_size:[2D]:" << global_work_size[0] << " " + << global_work_size[1] << " " << global_work_size[2]; + } + + protected: + param_t* pad2d_param_{nullptr}; + std::string kernel_func_name_{}; + std::string build_options_{"-DCL_DTYPE_half"}; + std::shared_ptr event_{new cl::Event}; +}; + +} // namespace opencl +} // namespace kernels +} // namespace lite +} // namespace paddle + +namespace ocl = paddle::lite::kernels::opencl; +REGISTER_LITE_KERNEL( + pad2d, kOpenCL, kFP16, kImageDefault, ocl::Pad2dCompute, ImageDefault) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .Finalize(); diff --git a/lite/kernels/opencl/pad2d_image_compute_test.cc b/lite/kernels/opencl/pad2d_image_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d1e1e3bb4c8fc80fabacff52b66f20387dd7766f --- /dev/null +++ b/lite/kernels/opencl/pad2d_image_compute_test.cc @@ -0,0 +1,351 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "lite/backends/opencl/target_wrapper.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" +#include "lite/kernels/opencl/image_helper.h" + +namespace paddle { +namespace lite { + +void pad2d_ref(const float *x_data, + Tensor *y, + std::string mode, + int pad_h0, + int pad_h1, + int pad_w0, + int pad_w1, + float pad_value) { + auto *out_data = y->mutable_data(); + auto output_dims = y->dims(); + int n = output_dims[0]; + int c = output_dims[1]; + int h = output_dims[2]; + int w = output_dims[3]; + int pad_mode; + if (mode == "constant") { + pad_mode = 0; + } else if (mode == "reflect") { + pad_mode = 2; + } else if (mode == "edge") { + pad_mode = 1; + } else { + LOG(FATAL) << "Unknown mode type"; + } + int in_w = w - pad_w0 - pad_w1; + int in_h = h - pad_h0 - pad_h1; + int spatial_size_out = w * h; + int spatial_size_in = in_w * in_h; +#pragma omp parallel for + for (int i = 0; i < n * c; ++i) { + const float *din_batch = x_data + i * spatial_size_in; + float *dout_batch = out_data + i * spatial_size_out; + int in_y = 0; + int in_x = 0; + for (int y = 0; y < h; ++y) { + for (int x = 0; x < w; ++x) { + switch (pad_mode) { + case 0: + in_y = y - pad_h0; + in_x = x - pad_w0; + dout_batch[y * w + x] = + (in_x >= 0 && in_x < in_w) && (in_y >= 0 && in_y < in_h) + ? din_batch[in_y * in_w + in_x] + : pad_value; + break; + case 1: + in_x = std::min(std::max(pad_w0, x), in_w + pad_w0 - 1) - pad_w0; + in_y = std::min(std::max(pad_h0, y), in_h + pad_h0 - 1) - pad_h0; + dout_batch[y * w + x] = din_batch[in_y * in_w + in_x]; + break; + case 2: + in_y = y - pad_h0; + in_x = x - pad_w0; + in_y = std::max(in_y, -in_y); + in_y = std::min(in_y, 2 * in_h - in_y - 2); + in_x = std::max(in_x, -in_x); + in_x = std::min(in_x, 2 * in_w - in_x - 2); + dout_batch[y * w + x] = din_batch[in_y * in_w + in_x]; + break; + default: + LOG(ERROR) << "ERROR: unknown pad mode:" << pad_mode; + } + } + } + } +} + +#define LOOP_TEST +// #define PRINT_RESULT +TEST(pad2d_image2d, compute) { + LOG(INFO) << "main steps of test: host -> layout(buf2img) -> " + "pad2d(img) -> " + "layout(img2buf) " + "-> host"; + +#ifdef LOOP_TEST + for (int n : {1, 3}) { + for (auto c : {1, 3}) { + for (int h : {12, 112}) { + for (int w : {12, 112}) { + for (int pad_h0 : {0, 1, 2}) { + for (int pad_h1 : {0, 1, 2}) { + for (int pad_w0 : {0, 1, 2}) { + for (int pad_w1 : {0, 1, 2}) { + for (float pad_value : {10.f}) { + for (std::string pad_mode : + {"constant", "reflect", "edge"}) { +#else + const int n = 1; + const int c = 3; + const int h = 12; + const int w = 112; + const int pad_h0 = 1; + const int pad_h1 = 2; + const int pad_w0 = 1; + const int pad_w1 = 2; + const float pad_value = 10.f; + std::string pad_mode = "reflect"; +#endif // LOOP_TEST + + LOG(INFO) << "======== input shape[n,c,h,w]:" << n << " " + << c << " " << h << " " << w; + LOG(INFO) << "======== pad_h0: " << pad_h0 + << ", pad_h1: " << pad_h1 + << ", pad_w0: " << pad_w0 + << ", pad_w1: " << pad_w1 + << ", pad_value: " << pad_value + << ", pad_mode: " << pad_mode; + // set layout kernels + auto buf_to_img_kernels = KernelRegistry::Global().Create( + "layout", + TARGET(kOpenCL), + PRECISION(kAny), + DATALAYOUT(kImageDefault)); + auto img_to_buf_kernels = + KernelRegistry::Global().Create("layout", + TARGET(kOpenCL), + PRECISION(kAny), + DATALAYOUT(kNCHW)); + auto pad2d_img_kernels = KernelRegistry::Global().Create( + "pad2d", + TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault)); + ASSERT_FALSE(buf_to_img_kernels.empty()); + ASSERT_FALSE(buf_to_img_kernels.empty()); + ASSERT_FALSE(pad2d_img_kernels.empty()); + + auto buf_to_img_kernel = + std::move(buf_to_img_kernels.front()); + auto img_to_buf_kernel = + std::move(img_to_buf_kernels.front()); + auto pad2d_img_kernel = + std::move(pad2d_img_kernels.front()); + LOG(INFO) << "get 1st kernel: " + << buf_to_img_kernel->doc(); + LOG(INFO) << "get 2nd kernel: " + << img_to_buf_kernel->doc(); + LOG(INFO) << "get 3rd kernel: " + << pad2d_img_kernel->doc(); + + // set tensors about op param + LOG(INFO) << "set tensors about op param"; + // layout(buf->img): x -> pad2d_in + // pad2d(img): pad2d_in -> pad2d_out + // layout(img->buf): pad2d_out -> y + lite::Tensor x, y, pad2d_in, pad2d_out, y_ref; + operators::LayoutParam BufferToImageParam; + operators::LayoutParam ImageToBufferParam; + BufferToImageParam.x = &x; + BufferToImageParam.y = &pad2d_in; + ImageToBufferParam.x = &pad2d_out; + ImageToBufferParam.y = &y; + operators::Pad2dParam Pad2dParam; + Pad2dParam.X = &pad2d_in; + Pad2dParam.Out = &pad2d_out; + Pad2dParam.paddings = {pad_h0, pad_h1, pad_w0, pad_w1}; + Pad2dParam.pad_value = pad_value; + Pad2dParam.mode = pad_mode; + + int64_t out_h = h + pad_h0 + pad_h1; + int64_t out_w = w + pad_w0 + pad_w1; + const DDim x_dim = + DDim(std::vector{n, c, h, w}); + const DDim y_dim = DDim( + std::vector{n, c, out_h, out_w}); + x.Resize(x_dim); + y.Resize(y_dim); + pad2d_in.Resize(x_dim); + pad2d_out.Resize(y_dim); + y_ref.Resize(y_dim); + auto pad2d_image2d_shape = + paddle::lite::kernels::opencl::InitImageDimInfoWith( + x_dim); + + // initialize tensors + LOG(INFO) << "initialize tensors"; + auto *x_data = + x.mutable_data(TARGET(kOpenCL)); + auto *y_data = + y.mutable_data(TARGET(kOpenCL)); + auto *y_data_ref = + y_ref.mutable_data(TARGET(kARM)); + auto *mapped_x = + static_cast(TargetWrapperCL::Map( + x_data, 0, sizeof(float) * x_dim.production())); + auto *mapped_y = + static_cast(TargetWrapperCL::Map( + y_data, 0, sizeof(float) * y_dim.production())); + std::default_random_engine engine; + std::uniform_real_distribution dist(-1, 1); + for (int i = 0; i < x_dim.production(); ++i) { + mapped_x[i] = dist(engine); + } + auto *pad2d_in_data = + pad2d_in.mutable_data( + pad2d_image2d_shape["width"], + pad2d_image2d_shape["height"]); + auto *pad2d_out_data = + pad2d_out.mutable_data(y_dim[3], + y_dim[2]); + + // set context and kernel args + LOG(INFO) << "set context and kernel args"; + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + + buf_to_img_kernel->SetParam(BufferToImageParam); + std::unique_ptr buf_to_img_context( + new KernelContext); + context->As().CopySharedTo( + &(buf_to_img_context->As())); + buf_to_img_kernel->SetContext( + std::move(buf_to_img_context)); + + img_to_buf_kernel->SetParam(ImageToBufferParam); + std::unique_ptr img_to_buf_context( + new KernelContext); + context->As().CopySharedTo( + &(img_to_buf_context->As())); + img_to_buf_kernel->SetContext( + std::move(img_to_buf_context)); + + pad2d_img_kernel->SetParam(Pad2dParam); + std::unique_ptr pad2d_img_context( + new KernelContext); + context->As().CopySharedTo( + &(pad2d_img_context->As())); + pad2d_img_kernel->SetContext( + std::move(pad2d_img_context)); + + // run kernels + LOG(INFO) << "run kernel: buf_to_img_kernel"; + buf_to_img_kernel->Launch(); + LOG(INFO) << "run kernel: pad2d_img_kernel"; + pad2d_img_kernel->Launch(); + LOG(INFO) << "run kernel: img_to_buf_kernel"; + img_to_buf_kernel->Launch(); + + // wait for opencl + auto *wait_list = + context->As().cl_wait_list(); + auto *out_ptr = + ImageToBufferParam.y->data(); + auto it = wait_list->find(out_ptr); + + if (it != wait_list->end()) { + VLOG(4) << "--- Find the sync event for the target cl " + "tensor. ---"; + auto &event = *(it->second); + event.wait(); + } else { + LOG(FATAL) + << "Could not find the sync event for the target " + "cl tensor."; + } + + // compute ref cpu + pad2d_ref(mapped_x, + &y_ref, + pad_mode, + pad_h0, + pad_h1, + pad_w0, + pad_w1, + pad_value); +// result +#ifdef PRINT_RESULT + LOG(INFO) + << "---- print kernel result (input -> output) ----"; + for (int eidx = 0; eidx < x_dim.production(); ++eidx) { + std::cout << mapped_x[eidx] << " "; + } + std::cout << std::endl; + for (int eidx = 0; eidx < y_dim.production(); ++eidx) { + std::cout << mapped_y[eidx] << " "; + } + std::cout << std::endl; + for (int eidx = 0; eidx < y_dim.production(); ++eidx) { + std::cout << y_data_ref[eidx] << " "; + } + std::cout << std::endl; +#endif // PRINT_RESULT + // check result: compare kernel output and cpu + // output(y_data_ref) + for (int eidx = 0; eidx < y_dim.production(); eidx++) { + EXPECT_NEAR(y_data_ref[eidx], mapped_y[eidx], 1e-3); + if (abs(y_data_ref[eidx] - mapped_y[eidx]) > 1e-3) { + LOG(FATAL) << "1st diff in this case at eidx[from 0]:" + << eidx << " / " << y_dim.production() + << ", y_data_ref[" << eidx + << "]:" << y_data_ref[eidx] + << ", mapped_y[" << eidx + << "]:" << mapped_y[eidx]; + break; + } + } + + // free + LOG(INFO) << "free: unmap x, y"; + TargetWrapperCL::Unmap(x_data, mapped_x); + TargetWrapperCL::Unmap(y_data, mapped_y); +#ifdef LOOP_TEST + } // pad_mode + } // pad_value + } // pad_w1 + } // pad_w0 + } // pad_h1 + } // pad_h0 + } // w + } // h + } // c + } // n +#else +// nothing to do. +#endif +} + +} // namespace lite +} // namespace paddle + +// pad2d image2d fp32 +USE_LITE_KERNEL(layout, kOpenCL, kAny, kImageDefault, NCHW_to_ImageDefault); +USE_LITE_KERNEL(layout, kOpenCL, kAny, kNCHW, ImageDefault_to_NCHW); + +// pad image2d fp16 +USE_LITE_KERNEL(pad2d, kOpenCL, kFP16, kImageDefault, ImageDefault);