From e13bde79dfca90b6f7d3ec8f1f16899010192d1d Mon Sep 17 00:00:00 2001 From: xiaogang Date: Mon, 9 Mar 2020 16:25:40 +0800 Subject: [PATCH] Opencl eltwisesub (#3116) * feat: add opencl elementwise_sub op & ut --- .../cl_kernel/image/elementwise_sub_kernel.cl | 85 +++++ lite/kernels/opencl/CMakeLists.txt | 10 +- .../opencl/elementwise_sub_image_compute.cc | 173 +++++++++++ .../opencl/elementwise_sub_image_compute.h | 53 ++++ .../elementwise_sub_image_compute_test.cc | 292 ++++++++++++++++++ ...lementwise_sub_activation_image_compute.cc | 69 +++++ 6 files changed, 681 insertions(+), 1 deletion(-) create mode 100644 lite/backends/opencl/cl_kernel/image/elementwise_sub_kernel.cl create mode 100644 lite/kernels/opencl/elementwise_sub_image_compute.cc create mode 100644 lite/kernels/opencl/elementwise_sub_image_compute.h create mode 100644 lite/kernels/opencl/elementwise_sub_image_compute_test.cc create mode 100644 lite/kernels/opencl/fusion_elementwise_sub_activation_image_compute.cc diff --git a/lite/backends/opencl/cl_kernel/image/elementwise_sub_kernel.cl b/lite/backends/opencl/cl_kernel/image/elementwise_sub_kernel.cl new file mode 100644 index 0000000000..6ed6af298f --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/elementwise_sub_kernel.cl @@ -0,0 +1,85 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+// NOTE(review): the include target was missing from the reviewed text.
+// cl_common.h is assumed because CL_DTYPE4, READ_IMG_TYPE, WRITE_IMG_TYPE and
+// activation_type4 are used below but not defined in this file - confirm.
+#include <cl_common.h>
+
+// Element-wise subtraction kernels operating on 2D images. Each work-item
+// processes the pixel at (get_global_id(0), get_global_id(1)).
+//
+// All reads use integer pixel coordinates, so every sampler is created with
+// CLK_NORMALIZED_COORDS_FALSE: the OpenCL C specification leaves the result
+// of an integer-coordinate read through a normalized-coordinate sampler
+// undefined.
+
+// outputImage = activation(input - bias); both operands share one image shape.
+__kernel void elementwise_sub(__read_only image2d_t input,
+                              __read_only image2d_t bias,
+                              __write_only image2d_t outputImage) {
+  const sampler_t sampler =
+      CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+
+  int2 coords;
+  coords.x = get_global_id(0);
+  coords.y = get_global_id(1);
+
+  CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords);
+  CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords);
+  CL_DTYPE4 output = activation_type4(in - biase);
+
+  WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output);
+}
+
+// outputImage = activation(input - b), where b is the first component of the
+// bias pixel at column (x % w) of bias row 0, broadcast to all four lanes.
+// `w` is supplied by the host as the last dimension of the X tensor.
+//
+// NOTE(review): the bias is indexed exactly like in width_sub (x % w). Verify
+// that this is the intended addressing for a per-channel operand.
+__kernel void channel_sub(__read_only image2d_t input,
+                          __read_only image2d_t bias,
+                          __write_only image2d_t outputImage,
+                          int w) {
+  const sampler_t sampler =
+      CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+
+  int2 coords;
+  coords.x = get_global_id(0);
+  coords.y = get_global_id(1);
+
+  int2 coords_bias;
+  coords_bias.x = coords.x % w;
+  coords_bias.y = 0;
+
+  CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords);
+  CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias);
+  // Same activation hook as elementwise_sub, so a fused build flag is honoured
+  // by every kernel in this file.
+  CL_DTYPE4 output = activation_type4(in - (CL_DTYPE4)(biase.x));
+
+  WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output);
+}
+
+// outputImage = activation(input - b), where b is the first component of the
+// bias pixel at column (x % w) of bias row 0, i.e. one scalar per position
+// along the tensor width, broadcast to all four lanes.
+__kernel void width_sub(__read_only image2d_t input,
+                        __read_only image2d_t bias,
+                        __write_only image2d_t outputImage,
+                        int w) {
+  const sampler_t sampler =
+      CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
+
+  int2 coords;
+  coords.x = get_global_id(0);
+  coords.y = get_global_id(1);
+
+  int2 coords_bias;
+  coords_bias.x = coords.x % w;
+  coords_bias.y = 0;
+
+  CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, coords);
+  CL_DTYPE4 biase = READ_IMG_TYPE(CL_DTYPE_CHAR, bias, sampler, coords_bias);
+  CL_DTYPE4 output = activation_type4(in - (CL_DTYPE4)(biase.x));
+
+  WRITE_IMG_TYPE(CL_DTYPE_CHAR, outputImage, coords, output);
+}
diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt index 5d00e05f69..b6a4059c06 100644 --- a/lite/kernels/opencl/CMakeLists.txt +++ b/lite/kernels/opencl/CMakeLists.txt @@ -9,10 +9,14 @@ set(cl_kernel_deps op_params cl_runtime cl_context cl_wrapper cl_target_wrapper ##################### # basic add_kernel(elementwise_add_opencl OPENCL basic SRCS elementwise_add_image_compute.cc DEPS ${cl_kernel_deps}) +add_kernel(elementwise_sub_opencl OPENCL basic SRCS elementwise_sub_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(elementwise_mul_opencl OPENCL basic SRCS elementwise_mul_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(fusion_elementwise_add_activation_opencl OPENCL basic SRCS fusion_elementwise_add_activation_image_compute.cc DEPS elementwise_add_opencl ${cl_kernel_deps}) +add_kernel(fusion_elementwise_sub_activation_opencl + OPENCL basic SRCS fusion_elementwise_sub_activation_image_compute.cc + DEPS elementwise_sub_opencl ${cl_kernel_deps}) add_kernel(pool_opencl OPENCL basic SRCS pool_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(activation_opencl OPENCL basic SRCS activation_image_compute.cc DEPS ${cl_kernel_deps}) @@ -66,7 +70,11 @@ lite_cc_test(test_layout_image_opencl SRCS layout_image_compute_test.cc DEPS layout_opencl op_registry program context) lite_cc_test(test_elementwise_add_image_opencl SRCS elementwise_add_image_compute_test.cc - DEPS elementwise_add_opencl fusion_elementwise_add_activation_opencl op_registry program context) + DEPS elementwise_add_opencl fusion_elementwise_add_activation_opencl op_registry program context + ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) +lite_cc_test(test_elementwise_sub_image_opencl SRCS elementwise_sub_image_compute_test.cc + DEPS elementwise_sub_opencl fusion_elementwise_sub_activation_opencl op_registry program context + ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl)
lite_cc_test(test_grid_sampler_image_opencl SRCS grid_sampler_image_compute_test.cc DEPS grid_sampler_opencl op_registry program context) diff --git a/lite/kernels/opencl/elementwise_sub_image_compute.cc b/lite/kernels/opencl/elementwise_sub_image_compute.cc new file mode 100644 index 0000000000..4cc7f21f8c --- /dev/null +++ b/lite/kernels/opencl/elementwise_sub_image_compute.cc @@ -0,0 +1,173 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "lite/kernels/opencl/elementwise_sub_image_compute.h"
+#include <memory>
+#include <string>
+#include "lite/backends/opencl/cl_include.h"
+#include "lite/core/op_registry.h"
+#include "lite/utils/replace_stl/stream.h"
+
+// NOTE(review): template arguments (As<...>, data<...>, static_cast<...>) and
+// the standard include targets were missing from the reviewed text and have
+// been restored by analogy with the other OpenCL image kernels - confirm.
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace opencl {
+
+namespace {
+
+// Returns the index of the first X dimension that Y is aligned with. A
+// negative axis means "align Y with the trailing dimensions of X", the same
+// rule the CPU reference in the unit test applies.
+int ResolveAxis(int axis, int x_rank, int y_rank) {
+  return axis < 0 ? x_rank - y_rank : axis;
+}
+
+// True when a 1-D Y can be broadcast along `axis` of X by one of the
+// broadcast kernels (last dimension -> width_sub, third-from-last ->
+// channel_sub). Signed arithmetic keeps the test well defined for rank < 3.
+bool IsSupportedBroadcastAxis(int axis, int x_rank) {
+  return axis == x_rank - 1 || axis == x_rank - 3;
+}
+
+}  // namespace
+
+// Picks the OpenCL kernel matching the shape of Y and compiles it.
+//   * 4-D Y                         -> "elementwise_sub"
+//   * 1-D Y on the last axis of X   -> "width_sub"
+//   * 1-D Y on axis rank(X) - 3     -> "channel_sub"
+// Any other combination is a fatal error.
+void ElementwiseSubImageCompute::PrepareForRun() {
+  ele_param_ = param_.get_mutable<param_t>();
+  auto* x = ele_param_->X;
+  auto* y = ele_param_->Y;
+
+  const int x_rank = static_cast<int>(x->dims().size());
+  const int y_rank = static_cast<int>(y->dims().size());
+  const int axis = ResolveAxis(ele_param_->axis, x_rank, y_rank);
+
+  if (y_rank == 4) {
+    kernel_func_name_ = "elementwise_sub";  // y: ImageDefault
+  } else if (y_rank == 1 && axis == x_rank - 1) {
+    kernel_func_name_ = "width_sub";  // y: ImageDefault
+  } else if (y_rank == 1 && axis == x_rank - 3) {
+    kernel_func_name_ = "channel_sub";  // y: ImageFolder
+  } else {
+    LOG(FATAL) << "ElementwiseSubImage doesn't support axis:"
+               << ele_param_->axis
+               << ", x->dims().size():" << x->dims().size()
+               << ", y->dims.size():" << y->dims().size();
+  }
+  VLOG(4) << "kernel_func_name_:" << kernel_func_name_;
+
+  auto& context = ctx_->As<OpenCLContext>();
+  context.cl_context()->AddKernel(
+      kernel_func_name_, "image/elementwise_sub_kernel.cl", build_options_);
+}
+
+// Binds the X / Y / Out images (plus the tensor width for the broadcast
+// kernels) and enqueues one work-item per pixel of the X image. The completion
+// event is recorded in the context's wait list under the output image.
+void ElementwiseSubImageCompute::Run() {
+  auto& context = ctx_->As<OpenCLContext>();
+  CHECK(context.cl_context() != nullptr);
+
+  auto* x = ele_param_->X;
+  auto* y = ele_param_->Y;
+  auto* out = ele_param_->Out;
+
+  const int x_rank = static_cast<int>(x->dims().size());
+  const int y_rank = static_cast<int>(y->dims().size());
+  const int axis = ResolveAxis(ele_param_->axis, x_rank, y_rank);
+
+  VLOG(4) << "x->target():" << TargetToStr(x->target());
+  VLOG(4) << "y->target():" << TargetToStr(y->target());
+  VLOG(4) << "out->target():" << TargetToStr(out->target());
+  VLOG(4) << "x->dims():" << x->dims();
+  VLOG(4) << "y->dims():" << y->dims();
+  VLOG(4) << "out->dims():" << out->dims();
+  VLOG(4) << "axis:" << ele_param_->axis;
+
+  // Validate the operand shapes before touching any OpenCL object.
+  const bool same_shape = (y_rank == 4);
+  const bool broadcast =
+      (y_rank == 1) && IsSupportedBroadcastAxis(axis, x_rank);
+  if (!same_shape && !broadcast) {
+    LOG(FATAL) << "ElementwiseSubImage doesn't support axis:"
+               << ele_param_->axis
+               << ", x->dims().size():" << x->dims().size()
+               << ", y->dims.size():" << y->dims().size();
+  }
+
+  paddle::lite::CLImageConverterDefault default_convertor;
+  auto x_img_shape =
+      default_convertor.InitImageDimInfoWith(x->dims());  // w, h
+  auto x_img_width = x_img_shape[0];
+  auto x_img_height = x_img_shape[1];
+  auto out_img_shape =
+      default_convertor.InitImageDimInfoWith(out->dims());  // w, h
+  auto y_img_shape = default_convertor.InitImageDimInfoWith(y->dims());
+
+  auto* x_img = x->data<half_t, cl::Image2D>();
+  auto* y_img = y->data<half_t, cl::Image2D>();
+  auto* out_img = out->mutable_data<half_t, cl::Image2D>(out_img_shape[0],
+                                                         out_img_shape[1]);
+
+  VLOG(4) << "x_img_shape[w,h]:" << x_img_width << " " << x_img_height;
+  VLOG(4) << "y_img_shape[w,h]:" << y_img_shape[0] << " " << y_img_shape[1];
+  VLOG(4) << "out_img_shape[w,h]:" << out_img_shape[0] << " "
+          << out_img_shape[1];
+
+  // Kernels are cached under "<function name><build options>".
+  STL::stringstream kernel_key;
+  kernel_key << kernel_func_name_ << build_options_;
+  auto kernel = context.cl_context()->GetKernel(kernel_key.str());
+
+  // Arguments 0..2 are shared by all three kernels.
+  int arg_idx = 0;
+  cl_int status = kernel.setArg(arg_idx, *x_img);
+  CL_CHECK_FATAL(status);
+  status = kernel.setArg(++arg_idx, *y_img);
+  CL_CHECK_FATAL(status);
+  status = kernel.setArg(++arg_idx, *out_img);
+  CL_CHECK_FATAL(status);
+
+  if (broadcast) {
+    // The broadcast kernels additionally need the width of the X tensor.
+    const int tensor_w = static_cast<int>(x->dims()[x_rank - 1]);
+    VLOG(4) << "tensor_w:" << tensor_w;
+    status = kernel.setArg(++arg_idx, tensor_w);
+    CL_CHECK_FATAL(status);
+  }
+
+  auto global_work_size =
+      cl::NDRange{static_cast<cl::size_type>(x_img_width),
+                  static_cast<cl::size_type>(x_img_height)};
+  VLOG(4) << "global_work_size:[2D]:" << x_img_width << " " << x_img_height;
+  status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
+      kernel,
+      cl::NullRange,
+      global_work_size,
+      cl::NullRange,
+      nullptr,
+      event_.get());
+  CL_CHECK_FATAL(status);
+  context.cl_wait_list()->emplace(out_img, event_);
+}
+
+}  // namespace opencl
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+namespace ocl = paddle::lite::kernels::opencl;
+
+// TODO(ysh329): May need a fix.
+// "Y" may come from a constant value such as a conv bias (kARM; it then needs
+// cl_image_converter to run on the CPU), or from another branch just like "X"
+// (kOpenCL; nothing to do). These two situations need different handling when
+// the passes pick a kernel; for now the target of "Y" is set to kOpenCL.
+REGISTER_LITE_KERNEL(elementwise_sub,
+                     kOpenCL,
+                     kFP16,
+                     kImageDefault,
+                     ocl::ElementwiseSubImageCompute,
+                     def)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kOpenCL),
+                                      PRECISION(kFP16),
+                                      DATALAYOUT(kImageDefault))})
+    .BindInput("Y",
+               {LiteType::GetTensorTy(TARGET(kOpenCL),
+                                      PRECISION(kFP16),
+                                      DATALAYOUT(kImageDefault))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kOpenCL),
+                                       PRECISION(kFP16),
+                                       DATALAYOUT(kImageDefault))})
+    .Finalize();
diff --git a/lite/kernels/opencl/elementwise_sub_image_compute.h b/lite/kernels/opencl/elementwise_sub_image_compute.h
new file mode 100644
index 0000000000..48386b083e
--- /dev/null
+++ b/lite/kernels/opencl/elementwise_sub_image_compute.h
@@ -0,0 +1,53 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include <memory>
+#include <string>
+#include "lite/backends/opencl/cl_half.h"
+#include "lite/backends/opencl/cl_include.h"
+#include "lite/core/kernel.h"
+#include "lite/operators/op_params.h"
+#include "lite/utils/cp_logging.h"
+
+// NOTE(review): the two standard include targets and the KernelLite /
+// shared_ptr template arguments were missing from the reviewed text. They are
+// restored to match the kernel's registration (kOpenCL, kFP16, kImageDefault)
+// - confirm.
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace opencl {
+
+// OpenCL image kernel computing Out = X - Y.
+//
+// PrepareForRun() selects and compiles the OpenCL function matching the shape
+// of Y; Run() binds the images and enqueues it. Members are protected so that
+// the fused "sub + activation" kernel can derive from this class and adjust
+// the build options.
+class ElementwiseSubImageCompute
+    : public KernelLite<TARGET(kOpenCL),
+                        PRECISION(kFP16),
+                        DATALAYOUT(kImageDefault)> {
+ public:
+  using param_t = operators::ElementwiseParam;
+
+  void PrepareForRun() override;
+
+  void Run() override;
+
+  std::string doc() const override {
+    return "ElementwiseSub using cl::Image2D, kFP16";
+  }
+
+ protected:
+  // Operator parameters; assigned in PrepareForRun().
+  param_t* ele_param_{nullptr};
+  // Name of the OpenCL function to launch.
+  std::string kernel_func_name_{"elementwise_sub"};
+  // Compiler options; also part of the key used to look the kernel up.
+  std::string build_options_{"-DCL_DTYPE_half"};
+  // Completion event of the most recent enqueue, shared with the wait list.
+  std::shared_ptr<cl::Event> event_{new cl::Event};
+};
+
+}  // namespace opencl
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/opencl/elementwise_sub_image_compute_test.cc b/lite/kernels/opencl/elementwise_sub_image_compute_test.cc
new file mode 100644
index 0000000000..0593747547
--- /dev/null
+++ b/lite/kernels/opencl/elementwise_sub_image_compute_test.cc
@@ -0,0 +1,292 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include <algorithm>
+#include <cmath>
+#include <memory>
+#include <string>
+#include <vector>
+#include "lite/backends/opencl/target_wrapper.h"
+#include "lite/core/op_registry.h"
+#include "lite/core/tensor.h"
+
+// NOTE(review): include targets and template arguments were missing from the
+// reviewed text; they are restored by analogy with the elementwise_add image
+// test (float host data, half_t image data) - confirm.
+
+namespace paddle {
+namespace lite {
+
+// Fills x[0, length) with `set_value`, or with the element index when
+// `set_value` is -1 (the default).
+template <typename dtype>
+void fill_data(dtype *x, const int length, int set_value = -1) {
+  for (int idx = 0; idx < length; ++idx) {
+    x[idx] = (set_value == -1) ? idx : set_value;
+  }
+}
+
+// CPU reference for the element-wise op: out = x - y, optionally followed by
+// ReLU.
+//
+// A negative `axis` is resolved to rank(x) - rank(y). Only elt_type "sub" is
+// supported; anything else is a fatal error.
+template <typename dtype>
+void elementwise_compute_ref(const dtype *x_data,
+                             const dtype *y_data,
+                             dtype *out_data,
+                             const DDim &x_dims,
+                             const DDim &y_dims,
+                             int axis,
+                             const std::string elt_type,
+                             bool use_relu = false) {
+  const int x_rank = static_cast<int>(x_dims.size());
+  const int y_rank = static_cast<int>(y_dims.size());
+  if (axis < 0) {
+    axis = x_rank - y_rank;
+  }
+  int batch = 1;
+  int channels = 1;
+  int num = 1;
+  for (int i = 0; i < axis; ++i) {
+    batch *= x_dims[i];
+  }
+  for (int i = 0; i < y_rank; ++i) {
+    channels *= y_dims[i];
+  }
+  for (int i = y_rank + axis; i < x_rank; ++i) {
+    num *= x_dims[i];
+  }
+  VLOG(4) << "axis:" << axis;
+  VLOG(4) << "batch:" << batch;
+  VLOG(4) << "cahnnels:" << channels;
+  VLOG(4) << "num:" << num;
+
+  const dtype zero = static_cast<dtype>(0);
+  // do elementwise subtraction
+  if (elt_type == "sub" && axis == 1 && y_rank == 1) {
+    // NOTE(review): for axis == 1 the reference cycles through y over the
+    // flattened x, which is what the channel_sub kernel computes - verify
+    // that this is the intended semantics for a per-channel operand.
+    const auto y_size = y_dims.production();
+    for (int i = 0; i < x_dims.production(); ++i) {
+      const dtype diff = x_data[i] - y_data[i % y_size];
+      out_data[i] = use_relu ? std::max(diff, zero) : diff;
+    }
+  } else if (elt_type == "sub") {
+    for (int i = 0; i < batch; ++i) {
+      for (int j = 0; j < channels; ++j) {
+        const int offset = (i * channels + j) * num;
+        const dtype *din_ptr = x_data + offset;
+        const dtype diny_data = y_data[j];
+        dtype *dout_ptr = out_data + offset;
+        for (int k = 0; k < num; ++k) {
+          const dtype diff = din_ptr[k] - diny_data;
+          dout_ptr[k] = use_relu ? std::max(diff, zero) : diff;
+        }
+      }
+    }
+  } else {
+    LOG(FATAL) << "unsupported Elementwise type: " << elt_type << std::endl;
+  }
+}
+
+// #define PRINT_RESULT
+// image
+TEST(elementwise_sub_image, compute) {
+  LOG(INFO) << "main steps of test: host -> layout(buf2img on cpu) -> "
+               "elementwise_sub(img) -> "
+               "layout(img2buf on cpu) "
+               "-> host";
+
+  // Kernel selection exercised by the four cases below:
+  // --------------------------------------------------------
+  // 1. elementwise_sub: y_dim.size() == 4
+  // 2. elementwise_sub built with -DRELU (fusion_elementwise_sub_activation
+  //    op): y_dim.size() == 4 && act_type == "relu"
+  // 3. width_sub: y_dim.size() == 1 && x_dim.size() == 4 && axis == 3
+  // 4. channel_sub: y_dim.size() == 1 && x_dim.size() == 4 && axis == 1
+
+  // dims
+  const int n = 1;
+  const int c = 3;
+  const int h = 2;
+  const int w = 2;
+
+  const DDim x_dim = DDim(std::vector<DDim::value_type>{n, c, h, w});
+  auto out_dim = x_dim;
+  // y_dim / axis / relu_flag
+  std::vector<DDim> y_dim_v{DDim(std::vector<DDim::value_type>{n, c, h, w}),
+                            DDim(std::vector<DDim::value_type>{n, c, h, w}),
+                            DDim(std::vector<DDim::value_type>{w}),
+                            DDim(std::vector<DDim::value_type>{w})};
+  std::vector<int> axis_v{-1, -1, 3, 1};
+  std::vector<bool> relu_flag_v{false, true, false, false};
+  CHECK(y_dim_v.size() == axis_v.size() && axis_v.size() == relu_flag_v.size())
+      << "y_dim_v.size() == axis_v.size() == relu_flag_v.size() should be "
+         "same, and be corresponding "
+         "one by one";
+
+  // start loop
+  for (size_t case_idx = 0; case_idx < y_dim_v.size(); ++case_idx) {
+    auto y_dim = y_dim_v[case_idx];
+    auto axis = axis_v[case_idx];
+    const bool relu_flag = relu_flag_v[case_idx];
+    LOG(INFO) << "================== elementwise_sub, case_idx:" << case_idx + 1
+              << "/" << y_dim_v.size() << " ===================";
+    LOG(INFO) << "x_dim:" << x_dim;
+    LOG(INFO) << "y_dim:" << y_dim;
+    LOG(INFO) << "out_dim:" << out_dim;
+    LOG(INFO) << "axis:" << axis;
+    LOG(INFO) << "relu_flag:" << relu_flag;
+
+    // tensor
+    VLOG(4) << "set tensors about op param";
+    lite::Tensor elesub_x, elesub_y, elesub_out;
+    elesub_x.Resize(x_dim);
+    elesub_y.Resize(y_dim);
+    elesub_out.Resize(out_dim);
+
+    // initialize tensors
+    VLOG(4) << "initialize tensors";
+    paddle::lite::CLImageConverterDefault default_convertor;
+    // x
+    std::vector<float> x_v(x_dim.production());
+    fill_data<float>(x_v.data(), x_v.size());  // fill with index value
+    auto x_img_shape = default_convertor.InitImageDimInfoWith(x_dim);  // w, h
+    auto x_img_w = x_img_shape[0];
+    auto x_img_h = x_img_shape[1];
+    std::vector<half_t> x_img_v(x_img_w * x_img_h * 4);  // 4: RGBA
+    default_convertor.NCHWToImage(x_v.data(), x_img_v.data(), x_dim);
+    elesub_x.mutable_data<half_t, cl::Image2D>(
+        x_img_w, x_img_h, x_img_v.data());
+
+    // y
+    // A same-shape Y is filled with a constant instead of the index: with
+    // X == Y every expected value would be 0, so neither the operand order
+    // nor the ReLU clamp would be observable. With Y == 2 the differences
+    // cover negative, zero and positive values.
+    std::vector<float> y_v(y_dim.production());
+    const int y_fill_value = (y_dim.size() == 4) ? 2 : -1;
+    fill_data<float>(y_v.data(), y_v.size(), y_fill_value);
+    auto y_img_shape = default_convertor.InitImageDimInfoWith(y_dim);  // w, h
+    auto y_img_w = y_img_shape[0];
+    auto y_img_h = y_img_shape[1];
+    std::vector<half_t> y_img_v(y_img_w * y_img_h * 4);  // 4: RGBA
+    default_convertor.NCHWToImage(y_v.data(), y_img_v.data(), y_dim);
+    elesub_y.mutable_data<half_t, cl::Image2D>(
+        y_img_w, y_img_h, y_img_v.data());
+
+    // out
+    auto out_img_shape =
+        default_convertor.InitImageDimInfoWith(out_dim);  // w, h
+    auto out_img_w = out_img_shape[0];
+    auto out_img_h = out_img_shape[1];
+    elesub_out.mutable_data<half_t, cl::Image2D>(out_img_w, out_img_h);
+
+    std::vector<half_t> out_img_v(out_img_w * out_img_h * 4);
+    fill_data<half_t>(
+        out_img_v.data(), out_img_v.size(), 0);  // fill with zero value
+
+    std::vector<float> out_v(out_dim.production());
+
+    // operator param
+    operators::FusionElementwiseActivationParam
+        fuseElesubParam;  // used only when relu_flag is true
+    fuseElesubParam.X = &elesub_x;
+    fuseElesubParam.Y = &elesub_y;
+    fuseElesubParam.Out = &elesub_out;
+    fuseElesubParam.axis = axis;
+    fuseElesubParam.act_type = relu_flag ? "relu" : "";
+
+    operators::ElementwiseParam elesubParam;
+    elesubParam.X = &elesub_x;
+    elesubParam.Y = &elesub_y;
+    elesubParam.Out = &elesub_out;
+    elesubParam.axis = axis;
+
+    // set kernel
+    // The fused op has its own registration; creating "elementwise_sub" for
+    // the relu case would silently test the plain kernel instead.
+    const std::string op_type =
+        relu_flag ? "fusion_elementwise_sub_activation" : "elementwise_sub";
+    auto elesub_img_kernels =
+        KernelRegistry::Global().Create(op_type,
+                                        TARGET(kOpenCL),
+                                        PRECISION(kFP16),
+                                        DATALAYOUT(kImageDefault));
+    ASSERT_FALSE(elesub_img_kernels.empty());
+
+    auto elesub_img_kernel = std::move(elesub_img_kernels.front());
+    VLOG(4) << "get elesub kernel: " << elesub_img_kernel->doc();
+
+    // set context and kernel args
+    VLOG(4) << "set context and kernel args";
+    std::unique_ptr<KernelContext> context(new KernelContext);
+    context->As<OpenCLContext>().InitOnce();
+
+    // Pass each kernel its own parameter type. Selecting between the two
+    // objects with ?: would convert the fused parameter to its base class and
+    // drop act_type.
+    if (relu_flag) {
+      elesub_img_kernel->SetParam(fuseElesubParam);
+    } else {
+      elesub_img_kernel->SetParam(elesubParam);
+    }
+    std::unique_ptr<KernelContext> elesub_img_context(new KernelContext);
+    context->As<OpenCLContext>().CopySharedTo(
+        &(elesub_img_context->As<OpenCLContext>()));
+    elesub_img_kernel->SetContext(std::move(elesub_img_context));
+
+    // run kernel
+    VLOG(4) << "run kernel";
+    elesub_img_kernel->Launch();
+
+    // download gpu result to cpu
+    const size_t cl_image2d_row_pitch{0};
+    const size_t cl_image2d_slice_pitch{0};
+    TargetWrapperCL::ImgcpySync(out_img_v.data(),
+                                elesub_out.data<half_t, cl::Image2D>(),
+                                out_img_w,
+                                out_img_h,
+                                cl_image2d_row_pitch,
+                                cl_image2d_slice_pitch,
+                                IoDirection::DtoH);
+    default_convertor.ImageToNCHW(
+        out_img_v.data(), out_v.data(), out_img_shape, out_dim);
+
+    // compute cpu reference
+    std::vector<float> out_ref(out_dim.production());
+    elementwise_compute_ref<float>(x_v.data(),
+                                   y_v.data(),
+                                   out_ref.data(),
+                                   x_dim,
+                                   y_dim,
+                                   axis,
+                                   "sub",
+                                   relu_flag);
+
+#ifdef PRINT_RESULT  // enable to check value of x and y
+    for (int eidx = 0; eidx < out_dim.production(); eidx++) {
+      auto value = out_v[eidx];
+      auto ref_value = out_ref[eidx];
+      LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx << " / "
+                << out_dim.production() << ", x_v[" << eidx << "]:" << x_v[eidx]
+                << ", value[" << eidx << "]:" << value << ", ref_value[" << eidx
+                << "]:" << ref_value;
+    }
+
+    for (size_t i = 0; i < y_v.size(); i++) {
+      LOG(INFO) << "y_v[" << i << "]:" << y_v[i];
+    }
+#endif
+
+    for (int eidx = 0; eidx < out_dim.production(); eidx++) {
+      auto value = out_v[eidx];
+      auto ref_value = out_ref[eidx];
+      EXPECT_NEAR(value, ref_value, 1e-6);
+      // std::fabs: an unqualified abs may resolve to the integer overload and
+      // truncate sub-unit differences to zero.
+      if (std::fabs(value - ref_value) > 1e-6) {
+        LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx << " / "
+                  << out_dim.production() << ", value[" << eidx << "]:" << value
+                  << ", ref_value[" << eidx << "]:" << ref_value;
+        break;
+      }
+    }
+  }
+}
+
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(elementwise_sub, kOpenCL, kFP16, kImageDefault, def);
+USE_LITE_KERNEL(
+    fusion_elementwise_sub_activation, kOpenCL, kFP16, kImageDefault, def);
diff --git a/lite/kernels/opencl/fusion_elementwise_sub_activation_image_compute.cc b/lite/kernels/opencl/fusion_elementwise_sub_activation_image_compute.cc
new file mode 100644
index 0000000000..c335d49f65
--- /dev/null
+++ b/lite/kernels/opencl/fusion_elementwise_sub_activation_image_compute.cc
@@ -0,0 +1,69 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include "lite/backends/opencl/cl_half.h"
+#include "lite/backends/opencl/cl_include.h"
+#include "lite/core/op_registry.h"
+#include "lite/kernels/opencl/elementwise_sub_image_compute.h"
+
+// NOTE(review): the template arguments of As<>, get_mutable<> and
+// static_cast<> were missing from the reviewed text and have been restored -
+// confirm.
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace opencl {
+
+// Fused Out = relu(X - Y).
+//
+// Reuses ElementwiseSubImageCompute::Run() and only changes how the OpenCL
+// program is built: the "elementwise_sub" function is compiled with -DRELU.
+class FusionElementwiseSubActivationImageCompute
+    : public ElementwiseSubImageCompute {
+ public:
+  using param_t = operators::FusionElementwiseActivationParam;
+
+  // Validates the fused parameters, then compiles the kernel with -DRELU.
+  // Fatal when the activation is not "relu" or when Y is not 4-D.
+  void PrepareForRun() override {
+    // Validate before spending time on an OpenCL build.
+    ele_param_ = param_.get_mutable<param_t>();
+    const auto& act_t = static_cast<param_t*>(ele_param_)->act_type;
+    VLOG(4) << "act: " << act_t;
+    if (act_t != "relu") {
+      LOG(FATAL) << "Unsupported Activation type: " << act_t;
+    }
+
+    // This class always launches kernel_func_name_'s default,
+    // "elementwise_sub", which takes exactly three image arguments. A
+    // broadcast (1-D) Y would make the inherited Run() bind a fourth argument
+    // and fail with an opaque OpenCL error, so reject it here with a clear
+    // message.
+    if (ele_param_->Y->dims().size() != 4) {
+      LOG(FATAL) << "FusionElementwiseSubActivationImage only supports 4-D Y"
+                 << ", y->dims().size():" << ele_param_->Y->dims().size();
+    }
+
+    // Append the flag once so that a repeated call keeps the kernel key
+    // stable.
+    const std::string relu_option = " -DRELU";
+    if (build_options_.find(relu_option) == std::string::npos) {
+      build_options_ += relu_option;
+    }
+
+    auto& context = ctx_->As<OpenCLContext>();
+    context.cl_context()->AddKernel(
+        kernel_func_name_, "image/elementwise_sub_kernel.cl", build_options_);
+  }
+};
+
+}  // namespace opencl
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+namespace ocl = paddle::lite::kernels::opencl;
+
+REGISTER_LITE_KERNEL(fusion_elementwise_sub_activation,
+                     kOpenCL,
+                     kFP16,
+                     kImageDefault,
+                     ocl::FusionElementwiseSubActivationImageCompute,
+                     def)
+    .BindInput("X",
+               {LiteType::GetTensorTy(TARGET(kOpenCL),
+                                      PRECISION(kFP16),
+                                      DATALAYOUT(kImageDefault))})
+    .BindInput("Y",
+               {LiteType::GetTensorTy(TARGET(kOpenCL),
+                                      PRECISION(kFP16),
+                                      DATALAYOUT(kImageDefault))})
+    .BindOutput("Out",
+                {LiteType::GetTensorTy(TARGET(kOpenCL),
+                                       PRECISION(kFP16),
+                                       DATALAYOUT(kImageDefault))})
+    .Finalize();
-- GitLab