From 2da7cba95fd45d3041df280b3dfde883f38dd64e Mon Sep 17 00:00:00 2001
From: xiebaiyuan
Date: Tue, 23 Jun 2020 16:17:50 +0800
Subject: [PATCH] [OPENCL] softmax with test, test=develop

---
 .../opencl/cl_kernel/image/softmax_kernel.cl  |  58 +++++
 lite/kernels/opencl/CMakeLists.txt            |   4 +
 lite/kernels/opencl/softmax_image_compute.cc  | 157 +++++++++++++
 .../opencl/softmax_image_compute_test.cc      | 213 ++++++++++++++++++
 4 files changed, 432 insertions(+)
 create mode 100644 lite/backends/opencl/cl_kernel/image/softmax_kernel.cl
 create mode 100644 lite/kernels/opencl/softmax_image_compute.cc
 create mode 100644 lite/kernels/opencl/softmax_image_compute_test.cc

diff --git a/lite/backends/opencl/cl_kernel/image/softmax_kernel.cl b/lite/backends/opencl/cl_kernel/image/softmax_kernel.cl
new file mode 100644
index 0000000000..b179741b4e
--- /dev/null
+++ b/lite/backends/opencl/cl_kernel/image/softmax_kernel.cl
@@ -0,0 +1,58 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include <cl_common.h>
+
+__kernel void softmax(__read_only image2d_t input_image,
+                      __write_only image2d_t output_image,
+                      __private const int out_W) {
+  const int out_c = get_global_id(0);
+  const int out_w = get_global_id(1);
+  const int out_nh = get_global_id(2);
+
+  const int in_c = out_c;
+  const int in_w = out_w;
+  const int in_nh = out_nh;
+
+  int2 input_pos;
+  int2 output_pos;
+
+  input_pos.x = in_c * out_W + in_w;
+  input_pos.y = in_nh;
+
+  output_pos.x = out_c * out_W + out_w;
+  output_pos.y = out_nh;
+
+  const sampler_t sampler =
+      CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+
+  CL_DTYPE4 input_max = 0.0f;  // running max along the W axis (numerical stability)
+  CL_DTYPE4 input_tmp;
+  for (int i = 0; i < out_W; i++) {
+    input_tmp = READ_IMG_TYPE(
+        CL_DTYPE_CHAR, input_image, sampler, (int2)(in_c * out_W + i, in_nh));
+    input_max = max(input_max, input_tmp);
+  }
+
+  CL_DTYPE4 sum = (CL_DTYPE4)0.0f;  // sum of exp(x - max) along the W axis
+  for (int i = 0; i < out_W; i++) {
+    input_tmp = READ_IMG_TYPE(
+        CL_DTYPE_CHAR, input_image, sampler, (int2)(in_c * out_W + i, in_nh));
+    sum += exp(input_tmp - input_max);
+  }
+
+  CL_DTYPE4 input =
+      READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, input_pos);
+  CL_DTYPE4 output = exp(input - input_max) / sum;
+  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, output_pos, output);
+}
diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt
index 600d0d2255..9920cde091 100644
--- a/lite/kernels/opencl/CMakeLists.txt
+++ b/lite/kernels/opencl/CMakeLists.txt
@@ -36,6 +36,7 @@ add_kernel(pad2d_opencl OPENCL basic SRCS pad2d_image_compute.cc DEPS ${cl_kerne
 add_kernel(box_coder_opencl OPENCL basic SRCS box_coder_image_compute.cc DEPS ${cl_kernel_deps})
 add_kernel(pixel_shuffle_opencl OPENCL basic SRCS pixel_shuffle_image_compute.cc DEPS ${cl_kernel_deps})
 add_kernel(expand_opencl OPENCL basic SRCS expand_image_compute.cc DEPS ${cl_kernel_deps})
+add_kernel(softmax_opencl OPENCL basic SRCS softmax_image_compute.cc DEPS ${cl_kernel_deps})
 
 # extra
 # wait to add ...
@@ -82,6 +83,9 @@ lite_cc_test(test_pixel_shuffle_image_opencl SRCS pixel_shuffle_image_compute_te
 lite_cc_test(test_expand_image_opencl SRCS expand_image_compute_test.cc
     DEPS expand_opencl op_registry program context)
 
+lite_cc_test(test_softmax_image_opencl SRCS softmax_image_compute_test.cc
+    DEPS softmax_opencl op_registry program context)
+
 lite_cc_test(test_elementwise_add_image_opencl SRCS elementwise_add_image_compute_test.cc
     DEPS elementwise_add_opencl fusion_elementwise_add_activation_opencl op_registry program context)
 lite_cc_test(test_elementwise_sub_image_opencl SRCS elementwise_sub_image_compute_test.cc
diff --git a/lite/kernels/opencl/softmax_image_compute.cc b/lite/kernels/opencl/softmax_image_compute.cc
new file mode 100644
index 0000000000..d97d42e023
--- /dev/null
+++ b/lite/kernels/opencl/softmax_image_compute.cc
@@ -0,0 +1,157 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+#include "lite/backends/opencl/cl_half.h"
+#include "lite/backends/opencl/cl_include.h"
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+#include "lite/kernels/opencl/image_helper.h"
+#include "lite/operators/op_params.h"
+#include "lite/utils/replace_stl/stream.h"
+#include "lite/utils/string.h"
+#ifdef LITE_WITH_PROFILE
+#include "lite/core/profile/profiler.h"
+#endif
+#include "lite/backends/opencl/cl_utility.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace opencl {
+
+class SoftmaxComputeImage2D : public KernelLite<TARGET(kOpenCL),
+                                                PRECISION(kFP16),
+                                                DATALAYOUT(kImageDefault)> {
+ public:
+  using param_t = operators::SoftmaxParam;
+
+  std::string doc() const override {
+    return "Softmax using cl::Image2D, kFP16";
+  }
+
+  void PrepareForRun() override {
+    VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
+
+    auto& context = ctx_->As<OpenCLContext>();
+    context.cl_context()->AddKernel(kernel_func_name_,
+                                    "image/softmax_kernel.cl",
+                                    build_options_,
+                                    time_stamp_);
+
+    STL::stringstream kernel_key;
+    kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
+    kernel_ = context.cl_context()->GetKernel(kernel_key.str());
+  }
+
+  void ReInitWhenNeeded() override {
+    VLOG(1) << "ReInitWhenNeeded: " << kernel_func_name_;
+    softmax_param_ = param_.get_mutable<param_t>();
+    auto x_dims = softmax_param_->x->dims();
+    auto out_dims = softmax_param_->output->dims();
+    VLOG(1) << "x_dims: " << x_dims;
+    VLOG(1) << "out_dims: " << out_dims;
+    VLOG(1) << "axis: " << softmax_param_->axis;
+    CHECK_EQ(out_dims.size(), 4) << "Softmax only supports out_dims.size() == 4: "
+                                 << out_dims;
+    if ((!first_epoch_for_reinit_ && x_dims != last_x_dims_) ||
+        first_epoch_for_reinit_) {
+      last_x_dims_ = x_dims;
+      first_epoch_for_reinit_ = false;
+      // compute image shape
+      paddle::lite::CLImageConverterDefault default_convertor;
+      out_img_shape_ = default_convertor.InitImageDimInfoWith(
+          softmax_param_->output->dims());
+      VLOG(1) << "out_img_shape_: " << out_img_shape_[0] << " "
+              << out_img_shape_[1];
+
+      // compute global work size: one work-item per (channel block, w, n*h)
+      auto image_width = out_dims[3] * ((out_dims[1] + 3) / 4);
+      size_t work_size_0 = image_width / out_dims[3];
+      size_t work_size_1 = out_dims[3];
+      size_t work_size_2 = out_dims[0] * out_dims[2];
+      global_work_size_ = cl::NDRange{work_size_0, work_size_1, work_size_2};
+      VLOG(1) << "global_work_size_: " << global_work_size_[0] << " "
+              << global_work_size_[1] << " " << global_work_size_[2];
+    }
+  }
+
+  void Run() override {
+    auto* x_img = softmax_param_->x->data<half_t, cl::Image2D>();
+    auto* out_img = softmax_param_->output->mutable_data<half_t, cl::Image2D>(
+        out_img_shape_[0], out_img_shape_[1]);
+    auto out_dims = softmax_param_->output->dims();
+    int out_w = out_dims[3];
+    auto& context = ctx_->As<OpenCLContext>();
+    CHECK(context.cl_context() != nullptr);
+    auto kernel = kernel_;
+    cl_int status;
+    status = kernel.setArg(0, *x_img);
+    CL_CHECK_FATAL(status);
+    status = kernel.setArg(1, *out_img);
+    CL_CHECK_FATAL(status);
+    status = kernel.setArg(2, out_w);
+    CL_CHECK_FATAL(status);
+
+    status = EnqueueNDRangeKernel(context,
+                                  kernel,
+                                  cl::NullRange,
+                                  global_work_size_,
+                                  cl::NullRange,
+                                  nullptr,
+                                  event_);
+    CL_CHECK_FATAL(status);
+  }
+
+#ifdef LITE_WITH_PROFILE
+  void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
+    ch->kernel_func_name = kernel_func_name_;
+    ch->cl_event =
+        event_;  // `event_` defined in `kernel.h`, valid after kernel::Run
+  }
+#endif
+
+ private:
+  std::string kernel_func_name_{"softmax"};
+  std::string build_options_{"-DCL_DTYPE_half"};
+  std::string time_stamp_{GetTimeStamp()};
+
+  param_t* softmax_param_{nullptr};
+  cl::Kernel kernel_;
+  bool first_epoch_for_reinit_{true};
+  DDim last_x_dims_;
+  DDim out_img_shape_ = DDim(std::vector<DDim::value_type>(
+      {static_cast<DDim::value_type>(1), static_cast<DDim::value_type>(1)}));
+  cl::NDRange global_work_size_ = cl::NDRange{
+      static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
+};
+
+}  // namespace opencl
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(softmax,
+                     kOpenCL,
+                     kFP16,
+                     kImageDefault,
+                     paddle::lite::kernels::opencl::SoftmaxComputeImage2D,
+                     image2d)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kOpenCL),
+                                      PRECISION(kFP16),
+                                      DATALAYOUT(kImageDefault))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kOpenCL),
+                                       PRECISION(kFP16),
+                                       DATALAYOUT(kImageDefault))})
+    .Finalize();
diff --git a/lite/kernels/opencl/softmax_image_compute_test.cc b/lite/kernels/opencl/softmax_image_compute_test.cc
new file mode 100644
index 0000000000..c1fe9f7010
--- /dev/null
+++ b/lite/kernels/opencl/softmax_image_compute_test.cc
@@ -0,0 +1,213 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+
+#include <random>
+#include "lite/backends/opencl/target_wrapper.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/tensor.h"
+#include "lite/kernels/opencl/test_helper.h"
+
+#define FP16_MAX_DIFF (5e-1)
+
+namespace paddle {
+namespace lite {
+
+template <typename dtype>
+void softmax_compute_ref(const operators::SoftmaxParam& param) {
+  const dtype* x_data = param.x->mutable_data<dtype>();
+  dtype* output_data = param.output->mutable_data<dtype>();
+  DDim x_dims = param.x->dims();
+  ASSERT_EQ(x_dims.data(), param.output->dims().data());
+  auto x_rank = x_dims.size();
+  int axis = param.axis;
+  if (axis < 0) {
+    axis += x_rank;
+  }
+  int axis_size = x_dims[axis];
+  int outer_num = x_dims.Slice(0, axis).production();
+  int inner_num = x_dims.Slice(axis + 1, x_rank).production();
+  int compute_size = outer_num * inner_num;
+  for (int i = 0; i < compute_size; i++) {
+    int idx_inner = i % inner_num;
+    int idx_outer = (i / inner_num) * axis_size;
+    int start = idx_outer * inner_num + idx_inner;
+    int offset;
+
+    offset = start;
+    dtype max_data = std::numeric_limits<dtype>::lowest();
+    for (int j = 0; j < axis_size; j++) {
+      max_data = x_data[offset] > max_data ? x_data[offset] : max_data;
+      offset += inner_num;
+    }
+
+    offset = start;
+    dtype sum_data = (dtype)0;
+    for (int j = 0; j < axis_size; j++) {
+      output_data[offset] = exp(x_data[offset] - max_data);
+      sum_data += output_data[offset];
+      offset += inner_num;
+    }
+
+    offset = start;
+    for (int j = 0; j < axis_size; j++) {
+      output_data[offset] /= sum_data;
+      offset += inner_num;
+    }
+  }
+}
+
+TEST(softmax_image2d, compute) {
+#if 1
+  for (auto n : {1, 3}) {
+    for (auto c : {1, 4}) {
+      for (auto h : {5, 1}) {
+        for (auto w : {1, 6}) {
+          for (auto axis : {/*-2,*/ -1 /*, 0, 1, 2*/}) {
+#else
+  for (auto n : {1, 3, 4, 11}) {
+    for (auto c : {1, 3, 11, 4}) {
+      for (auto h : {3, 1, 11, 4}) {
+        for (auto w : {1, 3, 4, 12}) {
+          for (auto axis : {-4, -3, -2, -1, 0, 1, 2, 3}) {
+#endif
+            LOG(INFO) << "create kernel ...";
+            auto kernels =
+                KernelRegistry::Global().Create("softmax",
+                                                TARGET(kOpenCL),
+                                                PRECISION(kFP16),
+                                                DATALAYOUT(kImageDefault));
+            ASSERT_FALSE(kernels.empty());
+            // prepare opencl kernel params
+            auto kernel = std::move(kernels.front());
+            LOG(INFO) << "prepare to test kernel ====> " << kernel->doc();
+            LOG(INFO) << "n=" << n << " c=" << c << " h=" << h << " w=" << w;
+            operators::SoftmaxParam param;
+            lite::Tensor x;
+            lite::Tensor output;
+
+            operators::SoftmaxParam param_ref;
+            lite::Tensor x_ref;
+            lite::Tensor output_ref;
+
+            auto in_dim = DDim(std::vector<int64_t>({n, c, h, w}));
+            auto out_dim = DDim(std::vector<int64_t>({n, c, h, w}));
+            x.Resize(in_dim);
+            x_ref.Resize(in_dim);
+
+            output.Resize(out_dim);
+            output_ref.Resize(out_dim);
+
+            param.x = &x;
+            param.axis = axis;
+            param.output = &output;
+
+            param_ref.x = &x_ref;
+            param_ref.axis = axis;
+            param_ref.output = &output_ref;
+            auto* x_ref_data = x_ref.mutable_data<float>();
+
+            std::unique_ptr<KernelContext> context(new KernelContext);
+            context->As<OpenCLContext>().InitOnce();
+
+            kernel->SetParam(param);
+            std::unique_ptr<KernelContext> softmax_context(new KernelContext);
+            context->As<OpenCLContext>().CopySharedTo(
+                &(softmax_context->As<OpenCLContext>()));
+
+            kernel->SetContext(std::move(softmax_context));
+
+            std::default_random_engine engine;
+            std::uniform_real_distribution<float> dist(-2, 2);
+            std::vector<float> input_v(n * c * h * w);
+
+            int index = 0;
+            for (auto& i : input_v) {
+              x_ref_data[index] = index;
+              i = index++;
+            }
+            VLOG(1) << "input_v ..... ";
+            for (size_t i = 0; i < input_v.size(); i++) {
+              VLOG(10) << input_v[i];
+            }
+
+            LOG(INFO) << "prepare input";
+            CLImageConverterDefault* default_converter =
+                new CLImageConverterDefault();
+            DDim x_image_shape = default_converter->InitImageDimInfoWith(
+                DDim(std::vector<int64_t>({n, c, h, w})));
+            LOG(INFO) << "x_image_shape = " << x_image_shape[0] << " "
+                      << x_image_shape[1];
+            std::vector<half_t> x_image_data(x_image_shape.production() *
+                                             4);  // 4 : RGBA
+            default_converter->NCHWToImage(
+                input_v.data(), x_image_data.data(), in_dim);
+            auto* x_image = x.mutable_data<half_t, cl::Image2D>(
+                x_image_shape[0], x_image_shape[1], x_image_data.data());
+            VLOG(1) << "x_image_data ..... ";
+            for (size_t i = 0; i < x_image_data.size(); i++) {
+              VLOG(10) << Half2Float(x_image_data[i]);
+            }
+            DDim out_image_shape =
+                default_converter->InitImageDimInfoWith(out_dim);
+            LOG(INFO) << "out_image_shape = " << out_image_shape[0] << " "
+                      << out_image_shape[1];
+            auto* out_image = output.mutable_data<half_t, cl::Image2D>(
+                out_image_shape[0], out_image_shape[1]);
+            // run
+            kernel->Launch();
+            CLRuntime::Global()->command_queue().finish();
+
+            // handle output
+            const size_t cl_image2d_row_pitch{0};
+            const size_t cl_image2d_slice_pitch{0};
+            half_t* out_image_data =
+                new half_t[out_image_shape.production() * 4];
+            TargetWrapperCL::ImgcpySync(out_image_data,
+                                        out_image,
+                                        out_image_shape[0],
+                                        out_image_shape[1],
+                                        cl_image2d_row_pitch,
+                                        cl_image2d_slice_pitch,
+                                        IoDirection::DtoH);
+            VLOG(1) << "out_image_data ..... ";
+            for (size_t i = 0; i < out_image_shape.production() * 4; i++) {
+              VLOG(10) << Half2Float(out_image_data[i]);
+            }
+            std::vector<float> out_data(out_image_shape.production() * 4);
+            default_converter->ImageToNCHW(
+                out_image_data, out_data.data(), out_image_shape, out_dim);
+
+            VLOG(1) << "out_data ..... ";
+            for (int i = 0; i < out_dim.production(); i++) {
+              VLOG(10) << out_data[i];
+            }
+
+            auto* output_ref_data = output_ref.mutable_data<float>();
+            softmax_compute_ref<float>(param_ref);
+
+            for (int i = 0; i < output.dims().production(); i++) {
+              EXPECT_NEAR(out_data[i], output_ref_data[i], 1e-2);
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(softmax, kOpenCL, kFP16, kImageDefault, image2d);
-- 
GitLab