From 732bb91baf1c9e6f6f6e90e1c325eb87376a47a7 Mon Sep 17 00:00:00 2001 From: dustybluebird Date: Mon, 22 Jun 2020 12:19:28 +0800 Subject: [PATCH] =?UTF-8?q?[OPENCL]add=20transpose=E3=80=81transpose2=20ke?= =?UTF-8?q?rnel,=20test=3Ddevelop?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../cl_kernel/image/transpose_kernel.cl | 160 +++++++ lite/kernels/opencl/CMakeLists.txt | 4 + .../kernels/opencl/transpose_image_compute.cc | 395 ++++++++++++++++++ .../opencl/transpose_image_compute_test.cc | 172 ++++++++ 4 files changed, 731 insertions(+) create mode 100644 lite/backends/opencl/cl_kernel/image/transpose_kernel.cl create mode 100644 lite/kernels/opencl/transpose_image_compute.cc create mode 100644 lite/kernels/opencl/transpose_image_compute_test.cc diff --git a/lite/backends/opencl/cl_kernel/image/transpose_kernel.cl b/lite/backends/opencl/cl_kernel/image/transpose_kernel.cl new file mode 100644 index 0000000000..b8533076b7 --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/transpose_kernel.cl @@ -0,0 +1,160 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +__kernel void transpose_4d(__read_only image2d_t input_image, + __write_only image2d_t output_image, + __private const int out_C, + __private const int out_H, + __private const int out_W, + __private const int in_W) { + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + const int out_n = 1; + const int out_h = out_nh % out_H; + const int out_c0 = out_c * 4; + const int out_c1 = out_c * 4 + 1; + const int out_c2 = out_c * 4 + 2; + const int out_c3 = out_c * 4 + 3; + + const int in_n = out_n; + const int in_c = out_w * 0.25; + const int in_h0 = out_c0; + const int in_h1 = out_c1; + const int in_h2 = out_c2; + const int in_h3 = out_c3; + const int in_w = out_h; + + int2 output_pos; + output_pos.x = out_c * out_W + out_w; + output_pos.y = out_nh; + + int2 input_pos0; + int2 input_pos1; + int2 input_pos2; + int2 input_pos3; + + input_pos0.x = in_W * in_c + in_w; + input_pos0.y = in_n * in_h0; + + input_pos1.x = in_W * in_c + in_w; + input_pos1.y = in_n * in_h1; + + input_pos2.x = in_W * in_c + in_w; + input_pos2.y = in_n * in_h2; + + input_pos3.x = in_W * in_c + in_w; + input_pos3.y = in_n * in_h3; + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + CL_DTYPE4 input0; + CL_DTYPE4 input1; + CL_DTYPE4 input2; + CL_DTYPE4 input3; + CL_DTYPE4 output; + input0 = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, input_pos0); + + if (out_w % 4 == 0) { + output.x = input0.x; + } else if (out_w % 4 == 1) { + output.x = input0.y; + } else if (out_w % 4 == 2) { + output.x = input0.z; + } else { + output.x = input0.w; + } + if (out_C - out_c * 4 >= 2) { + input1 = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, input_pos1); + if(out_w % 4 == 0) { + output.y = input1.x; + } else if(out_w % 4 == 1) { + output.y = input1.y; + } else if(out_w % 4 == 2) { + output.y = input1.z; + } else { + output.y = input1.w; + } + } else { + output.y = 0.0f; + } + + if (out_C - out_c * 4 >= 3) { + input2 = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, input_pos2); + if (out_w % 4 == 0){ + output.z = input2.x; + } else if (out_w % 4 == 1) { + output.z = input2.y; + } else if (out_w % 4 == 2) { + output.z = input2.z; + } else { + output.z = input2.w; + } + } else { + output.z = 0.0f; + } + + if (out_C - out_c * 4 >= 4) { + input3 = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, input_pos3); + if (out_w % 4 == 0) { + output.w = input3.x; + } else if (out_w % 4 == 1) { + output.w = input3.y; + } else if (out_w % 4 == 2) { + output.w = input3.z; + } else { + output.w = input3.w; + } + } else { + output.w = 0.0f; + } + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos, output); +} + +__kernel void transpose(__read_only image2d_t input_image, + __write_only image2d_t output_image, + __private const int out_C, + __private const int out_H, + __private const int out_W, + __private const int in_W) { + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2); + const int out_n = 1; + const int out_h = out_nh % out_H; + + const int in_n = 1; + const int in_c = out_c; + const int in_w = out_h; + const int in_h = out_w; + + int2 input_pos; + int2 output_pos; + input_pos.x = in_c * in_W + in_w; + input_pos.y = in_n * in_h; + + output_pos.x = out_c * out_W + out_w; + output_pos.y = out_n * out_h; + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + + CL_DTYPE4 input; + CL_DTYPE4 output; + input = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, input_pos); + + output = input; + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos, input); +} \ No newline at end of file diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt index 600d0d2255..81e1a4d756 100644 --- a/lite/kernels/opencl/CMakeLists.txt +++ b/lite/kernels/opencl/CMakeLists.txt @@ -21,6 +21,7 @@ add_kernel(fusion_elementwise_sub_activation_opencl add_kernel(pool_opencl OPENCL basic SRCS pool_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(activation_opencl OPENCL basic SRCS activation_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(reshape_opencl OPENCL basic SRCS reshape_image_compute.cc DEPS ${cl_kernel_deps}) +add_kernel(transpose_opencl OPENCL basic SRCS transpose_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(conv_opencl OPENCL basic SRCS conv_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(layout_opencl OPENCL basic SRCS layout_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(concat_opencl OPENCL basic SRCS concat_image_compute.cc DEPS ${cl_kernel_deps}) @@ -67,6 +68,9 @@ lite_cc_test(test_scale_image_opencl SRCS scale_image_compute_test.cc lite_cc_test(test_reshape_image_opencl SRCS reshape_image_compute_test.cc DEPS reshape_opencl op_registry program context) +lite_cc_test(test_transpose_image_opencl SRCS transpose_image_compute_test.cc + DEPS transpose_opencl layout_opencl op_registry program context) + lite_cc_test(test_concat_image_opencl SRCS concat_image_compute_test.cc DEPS concat_opencl layout_opencl op_registry program context) diff --git a/lite/kernels/opencl/transpose_image_compute.cc b/lite/kernels/opencl/transpose_image_compute.cc new file mode 100644 index 0000000000..31184092ef --- /dev/null +++ b/lite/kernels/opencl/transpose_image_compute.cc @@ -0,0 +1,395 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/opencl/cl_half.h" +#include "lite/backends/opencl/cl_include.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/kernels/opencl/image_helper.h" +#include "lite/operators/op_params.h" +#include "lite/utils/logging.h" +#include "lite/utils/replace_stl/stream.h" +#ifdef LITE_WITH_PROFILE +#include "lite/core/profile/profiler.h" +#endif +#include "lite/backends/opencl/cl_utility.h" + +#undef LITE_WITH_LOG + +namespace paddle { +namespace lite { +namespace kernels { +namespace opencl { + +// transpose operator +class TransposeComputeFloatImage + : public KernelLite { + public: + using param_t = operators::TransposeParam; + + void PrepareForRun() override { + auto& param = *param_.get_mutable(); + Tensor* const output = param.output; + const DDimLite& out_dims = output->dims(); + if (out_dims.size() == 4) { + kernel_func_name_ = "transpose_4d"; + } else { + kernel_func_name_ = "transpose"; + } + auto& context = ctx_->As(); + VLOG(1) << "kernel_func_name_:" << kernel_func_name_; + context.cl_context()->AddKernel(kernel_func_name_, + "image/transpose_kernel.cl", + build_options_, + time_stamp_); + } + +#ifdef LITE_WITH_PROFILE + void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) { + ch->kernel_func_name = kernel_func_name_; + ch->cl_event = + event_; // `event_` defined in `kernel.h`, valid after kernel::Run + } +#endif + + void Run() override { + auto& param = *param_.get_mutable(); + const Tensor* const x = param.x; + const auto x_dims = x->dims(); + const std::map& input_image_shape = + InitImageDimInfoWith(x_dims); + const int64_t& input_image_width = input_image_shape.at("width"); + const int64_t& input_image_height = input_image_shape.at("height"); + const cl::Image2D* const x_image = x->data(); + + Tensor* const output = param.output; + const DDimLite& out_dims = output->dims(); + VLOG(4) << "out_dims= " << out_dims; + const std::map& out_image_shape = + InitImageDimInfoWith(out_dims); + cl::Image2D* const out_image = output->mutable_data( + out_image_shape.at("width"), out_image_shape.at("height")); +#ifdef LITE_WITH_LOG + VLOG(4) << "out_dims= " << out_dims; +#endif + const std::vector& default_work_size = DefaultWorkSize( + out_dims, + DDim(std::vector{ + static_cast(out_image_shape.at("width")), + static_cast(out_image_shape.at("height"))})); + + int out_C = 0, out_H = 0, out_W = 0, in_W = 0; + if (param.output->dims().size() == 4) { + out_C = out_dims[1]; + out_H = out_dims[2]; + out_W = out_dims[3]; + in_W = x_dims[3]; + } else if (param.output->dims().size() == 3) { + out_C = out_dims[0]; + out_H = out_dims[1]; + out_W = out_dims[2]; + in_W = x_dims[2]; + } else if (param.output->dims().size() == 2) { + out_C = 1; + out_H = out_dims[0]; + out_W = out_dims[1]; + in_W = x_dims[1]; + } + +#ifdef LITE_WITH_LOG + VLOG(4) << "out_C=" << out_C; + VLOG(4) << "out_H=" << out_H; + VLOG(4) << "out_W=" << out_W; + VLOG(4) << "in_W=" << in_W; + VLOG(4) << "default_work_size= " << default_work_size[0] << ", " + << default_work_size[1] << ", " << default_work_size[2]; +#endif + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + +#ifdef LITE_WITH_LOG + VLOG(4) << TargetToStr(x->target()); + VLOG(4) << TargetToStr(param.output->target()); +#endif + + int arg_idx = 0; + cl_int status; + status = kernel.setArg(arg_idx, *x_image); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *out_image); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, out_C); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, out_H); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, out_W); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, in_W); + CL_CHECK_FATAL(status); + + auto global_work_size = + cl::NDRange{static_cast(default_work_size.data()[0]), + static_cast(default_work_size.data()[1]), + static_cast(default_work_size.data()[2])}; + + status = EnqueueNDRangeKernel(context, + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_); + CL_CHECK_FATAL(status); + } + + private: + std::string kernel_func_name_{"transpose"}; + std::string build_options_{"-DCL_DTYPE_half"}; + std::string time_stamp_{GetTimeStamp()}; +}; + +// transpose2 operator +class Transpose2ComputeFloatImage + : public KernelLite { + public: + using param_t = operators::TransposeParam; + + void PrepareForRun() override {} + +#ifdef LITE_WITH_PROFILE + void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {} +#endif + + bool IsShuffleChannel(const std::vector& axis) { + bool is_shuffle_channel = true; + if (axis.size() > 2 && axis[0] == 0 && axis[1] == 2 && axis[2] == 1) { + for (int i = 3; i < axis.size(); ++i) { + if (axis[i] != i) { + is_shuffle_channel = false; + break; + } + } + } else { + return false; + } + return is_shuffle_channel; + } + + template + void DeviceTensorToHostTensor(const Tensor* device_tensor, + Tensor* host_tensor) { + host_tensor->Resize(device_tensor->dims()); + Dtype* host_ptr = host_tensor->mutable_data(); + CLRuntime::Global()->command_queue().finish(); + CLImageConverterDefault default_converter; + auto device_tensor_image_dim = + default_converter.InitImageDimInfoWith(device_tensor->dims()); + half_t* image_data = new half_t[device_tensor_image_dim.production() * 4]; + TargetWrapperCL::ImgcpySync(image_data, + device_tensor->data(), + device_tensor_image_dim[0], + device_tensor_image_dim[1], + 0, + 0, + IoDirection::DtoH); + default_converter.ImageToNCHW( + image_data, host_ptr, device_tensor_image_dim, host_tensor->dims()); + delete[] image_data; + } + + template + void HostTensorToDeviceTensor(const Tensor* host_tensor, + Tensor* device_tensor) { + Dtype* host_ptr = const_cast(host_tensor->data()); + CLImageConverterDefault default_converter; + auto device_tensor_image_dim = + default_converter.InitImageDimInfoWith(device_tensor->dims()); + device_tensor->mutable_data( + device_tensor_image_dim[0], device_tensor_image_dim[1]); + half_t* image_data = new half_t[device_tensor->dims().production() * 4]; + default_converter.NCHWToImage(host_ptr, image_data, device_tensor->dims()); + + TargetWrapperCL::ImgcpySync( + device_tensor->mutable_data(), + image_data, + device_tensor_image_dim[0], + device_tensor_image_dim[1], + 0, + 0, + IoDirection::HtoD); + + delete[] image_data; + } + + template + void ShuffleChannelCompute(const operators::TransposeParam& param) { + const Tensor* input = param.x; + Tensor* input_tensor = new Tensor(); + DeviceTensorToHostTensor(input, input_tensor); + Dtype* input_ptr = input_tensor->mutable_data(); + + Tensor* output = param.output; + Tensor* output_tensor = new Tensor(); + output_tensor->Resize(output->dims()); + Dtype* output_ptr = output_tensor->mutable_data(); + + // input and output's shape dimension must >= 2 && <= 6. + const DDim& in_dim = input->dims(); + const DDim& out_dim = output->dims(); + size_t offset = 1; + for (int i = 3; i < param.axis.size(); ++i) { + offset *= in_dim[i]; + } +#pragma omp parallel for collapse(3) + for (int batch = 0; batch < out_dim[0]; ++batch) { + for (int c1 = 0; c1 < out_dim[1]; ++c1) { + for (int c2 = 0; c2 < out_dim[2]; ++c2) { + size_t out_offset = + ((batch * out_dim[1] + c1) * out_dim[2] + c2) * offset; + size_t in_offset = + ((batch * in_dim[1] + c2) * in_dim[2] + c1) * offset; + memcpy(output_ptr + out_offset, + input_ptr + in_offset, + offset * sizeof(Dtype)); + } + } + } + HostTensorToDeviceTensor(output_tensor, output); + delete input_tensor; + delete output_tensor; + } + + template + void Transpose2Compute(const operators::TransposeParam& param) { + const Tensor* input = param.x; + Tensor* input_tensor = new Tensor(); + DeviceTensorToHostTensor(input, input_tensor); + Dtype* input_ptr = input_tensor->mutable_data(); + + Tensor* output = param.output; + Tensor* output_tensor = new Tensor(); + output_tensor->Resize(output->dims()); + Dtype* output_ptr = output_tensor->mutable_data(); + + // input and output's shape dimension must >= 2 && <= 6. + const DDim& in_dim = input->dims(); + const DDim& out_dim = output->dims(); + + // precompute inverted output dim and strides + size_t rout_dim[6], strides[6]; + auto& axis = param.axis; + int permute = axis.size(); // permute must >=2 && <= 6. + for (int i = 0; i < permute; ++i) { + int k = permute - 1 - i; + strides[k] = 1; + for (int j = axis[i] + 1; j < permute; ++j) { + strides[k] *= in_dim[j]; + } + rout_dim[k] = out_dim[i]; + } + + // unroll the first 2 dimensions + int reamin_dim = 1; + for (int i = 2; i < out_dim.size(); ++i) { + reamin_dim *= out_dim[i]; + } + +#pragma omp parallel for collapse(2) + for (int batch = 0; batch < out_dim[0]; ++batch) { + for (int j = 0; j < out_dim[1]; ++j) { + size_t offset = batch * strides[permute - 1] + j * strides[permute - 2]; + Dtype* out_ptr = output_ptr + (batch * out_dim[1] + j) * reamin_dim; + int indics[4] = {0, 0, 0, 0}; + for (int k = 0; k < reamin_dim; ++k) { + out_ptr[k] = input_ptr[offset]; + indics[0] += 1; + offset += strides[0]; + for (int p = 0; p < permute - 3; ++p) { + if (indics[p] == rout_dim[p]) { + indics[p + 1] += 1; + indics[p] = 0; + offset += strides[p + 1]; + offset -= rout_dim[p] * strides[p]; + } else { + break; + } + } + } + } + } + HostTensorToDeviceTensor(output_tensor, output); + delete input_tensor; + delete output_tensor; + } + + void Run() override { + auto& param = *param_.get_mutable(); + const std::vector axis = param.axis; + + bool shuffle_channel = IsShuffleChannel(axis); + if (shuffle_channel) { + ShuffleChannelCompute(param); + } else { + Transpose2Compute(param); + } + } +}; + +} // namespace opencl +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(transpose, + kOpenCL, + kFP16, + kImageDefault, + paddle::lite::kernels::opencl::TransposeComputeFloatImage, + image2d) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .Finalize(); + +REGISTER_LITE_KERNEL(transpose2, + kOpenCL, + kFP16, + kImageDefault, + paddle::lite::kernels::opencl::Transpose2ComputeFloatImage, + image2d) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); + +#define LITE_WITH_LOG diff --git a/lite/kernels/opencl/transpose_image_compute_test.cc b/lite/kernels/opencl/transpose_image_compute_test.cc new file mode 100644 index 0000000000..9db9b3732d --- /dev/null +++ b/lite/kernels/opencl/transpose_image_compute_test.cc @@ -0,0 +1,172 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "lite/backends/opencl/target_wrapper.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" +#include "lite/kernels/opencl/test_helper.h" +#include "lite/operators/reshape_op.h" +#include "lite/utils/logging.h" + +#define FP16_MAX_DIFF (5e-1) + +namespace paddle { +namespace lite { +namespace kernels { +namespace opencl { + +static inline void TestWithKernel( + const std::unique_ptr& kernel) { + int64_t batch_size = 1; + int64_t ic = 2; + int64_t ih = 3; + int64_t iw = 4; + + int64_t oc = 3; + int64_t oh = 4; + int64_t ow = 2; + + lite::Tensor input, output; + operators::TransposeParam param; + + param.x = &input; + param.output = &output; + param.axis = std::vector({0, 2, 3, 1}); + const DDim input_dim = + lite::DDim{std::vector({batch_size, ic, ih, iw})}; + input.Resize(input_dim); + const DDim output_dim = + lite::DDim{std::vector({batch_size, oc, oh, ow})}; + param.output->Resize(output_dim); + + LOG(INFO) << "prepare kernel SetParam------"; + kernel->SetParam(param); + + size_t input_image_width = iw * ((ic + 3) / 4); + size_t input_image_height = ih * batch_size; + + size_t output_image_width = ow * ((oc + 3) / 4); + size_t output_image_height = oh * batch_size; + + const size_t cl_image2d_row_pitch{0}; + const size_t cl_image2d_slice_pitch{0}; + + std::vector input_v(batch_size * ic * ih * iw); + + LOG(INFO) << "gen input ..."; + + float* input_v_data = &input_v[0]; + auto index = 0; + for (auto& i : input_v) { + i = index++; + } + + paddle::lite::CLImageConverterDefault default_convertor; + + std::vector x_image_data(input_image_width * input_image_height * + 4); // 4 : RGBA + + LOG(INFO) << "set mapped input ..."; + default_convertor.NCHWToImage(input_v_data, x_image_data.data(), input_dim); + + auto* input_image = input.mutable_data( + input_image_width, input_image_height, x_image_data.data()); + + LOG(INFO) << "prepare kernel ready"; + + LOG(INFO) << "mutable output ..."; + CLImageConverterDefault default_converter; + DDim out_image_shape = default_converter.InitImageDimInfoWith(output_dim); + LOG(INFO) << "out_image_shape = " << out_image_shape[0] << " " + << out_image_shape[1]; + auto* out_image = output.mutable_data( + out_image_shape[0], out_image_shape[1]); + + LOG(INFO) << "kernel context ..."; + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + + std::unique_ptr transpose_context(new KernelContext); + context->As().CopySharedTo( + &(transpose_context->As())); + kernel->SetContext(std::move(transpose_context)); + + LOG(INFO) << "kernel launch ..."; + kernel->Launch(); + + CLRuntime::Global()->command_queue().finish(); + + half_t* out_image_data = new half_t[out_image_shape.production() * 4]; + TargetWrapperCL::ImgcpySync(out_image_data, + output.data(), + out_image_shape[0], + out_image_shape[1], + cl_image2d_row_pitch, + cl_image2d_slice_pitch, + IoDirection::DtoH); + float* out_data = new float[out_image_shape.production() * 4]; + default_converter.ImageToNCHW( + out_image_data, out_data, out_image_shape, output_dim); + + // check output data + index = 0; + auto hxw = ih * iw; + auto cxhxw = ic * hxw; + for (auto n = 0; n < batch_size; n++) { + for (auto h = 0; h < ih; h++) { + for (auto w = 0; w < iw; w++) { + for (auto c = 0; c < ic; c++) { + auto input_index = n * cxhxw + c * hxw + h * iw + w; + auto input_value = input_v_data[input_index]; + auto output_value = out_data[index]; + auto abs_diff = abs(input_value - output_value); + auto relative_diff = COMPUTE_RELATIVE_DIFF(input_value, output_value); + EXPECT_EQ( + (relative_diff <= FP16_MAX_DIFF) || (abs_diff <= FP16_MAX_DIFF), + true); + index++; + } + } + } + } +} + +TEST(transpose_opencl, compute) { + auto kernels = KernelRegistry::Global().Create("transpose", + TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault)); + ASSERT_FALSE(kernels.empty()); + auto kernel = std::move(kernels.front()); + TestWithKernel(kernel); +} + +TEST(transpose2_opencl, compute) { + auto kernels = KernelRegistry::Global().Create("transpose2", + TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault)); + ASSERT_FALSE(kernels.empty()); + auto kernel = std::move(kernels.front()); + TestWithKernel(kernel); +} + +} // namespace opencl +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(transpose, kOpenCL, kFP16, kImageDefault, image2d); -- GitLab