diff --git a/lite/backends/opencl/cl_kernel/cl_common.h b/lite/backends/opencl/cl_kernel/cl_common.h index 8f60ea45031e0d5c08bd7f436fbeae9c1d752c7f..f193ab82d78fcd21165100658e9a0edefdbd5e0a 100644 --- a/lite/backends/opencl/cl_kernel/cl_common.h +++ b/lite/backends/opencl/cl_kernel/cl_common.h @@ -14,8 +14,17 @@ limitations under the License. */ #pragma once +///////////////////////////////// +// fp16 enabled, MAX_VALUE, MIN_VALUE +///////////////////////////////// #pragma OPENCL EXTENSION cl_khr_fp16 : enable +#define MAX_VALUE FLT_MAX +#define MIN_VALUE -FLT_MAX + +///////////////////////////////// +// CL_DTYPE_float / CL_DTYPE_half +///////////////////////////////// // Data type: pass one of macros on host: [CL_DTYPE_float, CL_DYPE_half] #ifdef CL_DTYPE_float #define CL_DTYPE float @@ -27,14 +36,23 @@ limitations under the License. */ #define CL_DTYPE_CHAR h #endif +///////////////////////////////// +// GET_VEC_TYPE +///////////////////////////////// // Note: macro name replacement need twice parser #define GET_VEC_TYPE(type__, size__) type__##size__ #define VECTORIZED_TYPE(type__, size__) GET_VEC_TYPE(type__, size__) #define CL_DTYPE4 VECTORIZED_TYPE(CL_DTYPE, 4) +///////////////////////////////// +// CONVERT_TYPE_TO +///////////////////////////////// #define _CONVERT_TYPE_TO(value, type) convert_##type(value) #define CONVERT_TYPE_TO(value, type) _CONVERT_TYPE_TO(value, type) +///////////////////////////////// +// WRITE_IMG_TYPE / READ_IMG_TYPE +///////////////////////////////// #define _WRITE_IMG_TYPE(type_char, img, pos, value) \ write_image##type_char(img, pos, value) #define WRITE_IMG_TYPE(type_char, img, pos, value) \ @@ -45,6 +63,9 @@ limitations under the License. */ #define READ_IMG_TYPE(type_char, img, sampler, pos) \ _READ_IMG_TYPE(type_char, img, sampler, pos) +///////////////////////////////// +// activation / activation_type4 +///////////////////////////////// inline CL_DTYPE activation(CL_DTYPE in #ifdef PRELU , @@ -61,6 +82,7 @@ inline CL_DTYPE activation(CL_DTYPE in #endif return output; } + inline CL_DTYPE4 activation_type4(CL_DTYPE4 in #ifdef PRELU , diff --git a/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_kernel.cl b/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_kernel.cl index 8bb7be6a42627a02a7fe1b6accd19fdf23dbe34e..1e3586b7fde8d79fe49327185c623ac613cd080d 100755 --- a/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/depthwise_conv2d_kernel.cl @@ -15,7 +15,7 @@ limitations under the License. */ #include -__kernel void depth_conv_3x3(__private const int global_size_dim0, +__kernel void depth_conv2d_3x3(__private const int global_size_dim0, __private const int global_size_dim1, __private const int global_size_dim2, __read_only image2d_t input, @@ -172,7 +172,7 @@ __kernel void depth_conv_3x3(__private const int global_size_dim0, -__kernel void depth_conv_3x3s1(__private const int ou_ch_blk, +__kernel void depth_conv2d_3x3s1(__private const int ou_ch_blk, __private const int ou_w_blk, __private const int ou_nh, __read_only image2d_t input, diff --git a/lite/backends/opencl/cl_kernel/image/pool_kernel.cl b/lite/backends/opencl/cl_kernel/image/pool_kernel.cl index 4a7c53c980697f6b413bdbf63fc68039a5bfcb3a..775166261d01dc639cd5af8cee49f7e7fb30cb19 100644 --- a/lite/backends/opencl/cl_kernel/image/pool_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/pool_kernel.cl @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include -#define MIN_VALUE -FLT_MAX __kernel void pool_max(__read_only image2d_t input, __write_only image2d_t output, diff --git a/lite/backends/opencl/cl_kernel/image/relu_kernel.cl b/lite/backends/opencl/cl_kernel/image/relu_kernel.cl index a99ac79d32bcedb48354d2e179ef6c8c1ff7f997..43a27067c2f2c418d314f9bce95bccbbb51a9be0 100644 --- a/lite/backends/opencl/cl_kernel/image/relu_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/relu_kernel.cl @@ -24,7 +24,7 @@ __kernel void relu(__read_only image2d_t input, CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; - CL_DTYPE4 in = read_imagef(input, sampler, (int2)(x, y)); + CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x, y)); in = max((CL_DTYPE4)(0.0f), in); - write_imagef(output, (int2)(x, y), in); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in); } diff --git a/lite/backends/opencl/target_wrapper.cc b/lite/backends/opencl/target_wrapper.cc index 0dece6c582e9839df55c797346a9e6219e291241..310567baa539697f6a67b59f6c0e5f29ce46a80e 100644 --- a/lite/backends/opencl/target_wrapper.cc +++ b/lite/backends/opencl/target_wrapper.cc @@ -24,6 +24,8 @@ static cl_channel_type GetCLChannelType(const PrecisionType type) { switch (type) { case PRECISION(kFloat): return CL_FLOAT; + case PRECISION(kFP16): + return CL_HALF_FLOAT; case PRECISION(kInt32): return CL_SIGNED_INT32; case PRECISION(kInt8): @@ -79,11 +81,11 @@ void *TargetWrapperCL::MallocImage(const size_t cl_image2d_width, return cl_image; } -template <> -void *TargetWrapperCL::MallocImage(const size_t cl_image2d_width, - const size_t cl_image2d_height, - void *host_ptr) { - cl::ImageFormat img_format(CL_RGBA, GetCLChannelType(PRECISION(kInt8))); +template <> // use int16_t represents half float +void *TargetWrapperCL::MallocImage(const size_t cl_image2d_width, + const size_t cl_image2d_height, + void *host_ptr) { + cl::ImageFormat img_format(CL_RGBA, GetCLChannelType(PRECISION(kFP16))); cl_int status; cl::Image2D *cl_image = new cl::Image2D(CLRuntime::Global()->context(), diff --git a/lite/core/tensor.cc b/lite/core/tensor.cc index 1c7db871c7b525d6e4944fd0d669e81bcaff7f2a..ecfdcf3d1107953f1c41ea57b6f12187b29686c6 100644 --- a/lite/core/tensor.cc +++ b/lite/core/tensor.cc @@ -104,6 +104,12 @@ const cl::Image2D *TensorLite::data() const { if (nullptr == buffer_->data()) return nullptr; return static_cast(buffer_->data()); } + +template <> // use int16_t represent half float +const cl::Image2D *TensorLite::data() const { + if (nullptr == buffer_->data()) return nullptr; + return static_cast(buffer_->data()); +} #endif } // namespace lite diff --git a/lite/core/tensor.h b/lite/core/tensor.h index ca2e0e9a9850a93498e9dd986af33f725343d25b..a1141c613e29326a5f9ffb2fdc1427e3fbe84481 100644 --- a/lite/core/tensor.h +++ b/lite/core/tensor.h @@ -253,6 +253,9 @@ bool TensorCompareWith(const TensorT &a, const TensorT &b) { #ifdef LITE_WITH_OPENCL template <> const cl::Image2D *TensorLite::data() const; + +template <> // use int16_t represent half float +const cl::Image2D *TensorLite::data() const; #endif } // namespace lite diff --git a/lite/kernels/opencl/depthwise_conv2d_compute.cc b/lite/kernels/opencl/depthwise_conv2d_compute.cc index ac1b1e715da55ee32699908e3deb091dac7f5de6..5d573c14f7f4ea9b3768b5c8fdfe30c1b6c84c99 100644 --- a/lite/kernels/opencl/depthwise_conv2d_compute.cc +++ b/lite/kernels/opencl/depthwise_conv2d_compute.cc @@ -30,6 +30,10 @@ class DepthwiseConv2dCompute public: using param_t = operators::ConvParam; + std::string doc() const override { + return "DepthwiseConv2d using cl::Buffer, kFloat"; + } + void PrepareForRun() override { const auto& param = *param_.get_mutable(); if (param.fuse_relu) { @@ -110,16 +114,22 @@ class DepthwiseConv2dCompute } private: - std::string kernel_func_name_{"depthwise_conv2d"}; + std::string kernel_func_name_{"depthwise_conv2d_3x3"}; std::string build_options_{"-DCL_DTYPE=float"}; std::shared_ptr event_{new cl::Event}; }; class DepthwiseConv2dComputeFP16Image - : public KernelLite { + : public KernelLite { public: using param_t = operators::ConvParam; + std::string doc() const override { + return "DepthwiseConv2d using cl::Image2D/kImageDefault, kFP16"; + } + void PrepareForRun() override { const auto& param = *param_.get_mutable(); if (param.fuse_relu) { @@ -143,16 +153,16 @@ class DepthwiseConv2dComputeFP16Image auto& context = ctx_->As(); CHECK(context.cl_context() != nullptr); - auto* input_img = param.x->data(); - auto* filter_img = param.filter->data(); + auto* input_img = param.x->data(); + auto* filter_img = param.filter->data(); auto* bias_img = param.bias == nullptr ? static_cast(nullptr) - : param.bias->data(); + : param.bias->data(); auto image_shape = InitImageDimInfoWith(output_dims); - auto* output_img = param.output->mutable_data( + auto* output_img = param.output->mutable_data( image_shape["width"], image_shape["height"]); STL::stringstream kernel_key; @@ -164,19 +174,19 @@ class DepthwiseConv2dComputeFP16Image int nh = output_dims[0] * output_dims[2]; auto global_work_size = cl::NDRange(c_block, w, nh); - LOG(INFO) << "setArg"; - LOG(INFO) << "c_block = " << c_block; - LOG(INFO) << "w = " << w; - LOG(INFO) << "nh = " << nh; + VLOG(4) << "setArg"; + VLOG(4) << "c_block = " << c_block; + VLOG(4) << "w = " << w; + VLOG(4) << "nh = " << nh; - LOG(INFO) << "strides = " << strides[0]; - LOG(INFO) << "offset = " << offset; - LOG(INFO) << "dilations = " << dilations[0]; - LOG(INFO) << "input_c_block = " << input_c_block; - LOG(INFO) << "x_dims[3] = " << x_dims[3]; - LOG(INFO) << "x_dims[2] = " << x_dims[2]; - LOG(INFO) << "output_dims[3] = " << output_dims[3]; - LOG(INFO) << "output_dims[2] = " << output_dims[2]; + VLOG(4) << "strides = " << strides[0]; + VLOG(4) << "offset = " << offset; + VLOG(4) << "dilations = " << dilations[0]; + VLOG(4) << "input_c_block = " << input_c_block; + VLOG(4) << "x_dims[3] = " << x_dims[3]; + VLOG(4) << "x_dims[2] = " << x_dims[2]; + VLOG(4) << "output_dims[3] = " << output_dims[3]; + VLOG(4) << "output_dims[2] = " << output_dims[2]; cl_int status; int arg_idx = 0; @@ -221,16 +231,22 @@ class DepthwiseConv2dComputeFP16Image } private: - std::string kernel_func_name_{"depth_conv_3x3"}; - std::string build_options_{"-DCL_DTYPE_float"}; + std::string kernel_func_name_{"depth_conv2d_3x3"}; + std::string build_options_{"-DCL_DTYPE_half"}; std::shared_ptr event_{new cl::Event}; }; class DepthwiseConv2d3x3s1ComputeFP16Image - : public KernelLite { + : public KernelLite { public: using param_t = operators::ConvParam; + std::string doc() const override { + return "DepthwiseConv2d3x3s1 using cl::Image2D/kImageDefault, kFP16"; + } + void PrepareForRun() override { const auto& param = *param_.get_mutable(); if (param.fuse_relu) { @@ -252,16 +268,16 @@ class DepthwiseConv2d3x3s1ComputeFP16Image auto& context = ctx_->As(); CHECK(context.cl_context() != nullptr); - auto* input_img = param.x->data(); - auto* filter_img = param.filter->data(); + auto* input_img = param.x->data(); + auto* filter_img = param.filter->data(); auto* bias_img = param.bias == nullptr ? static_cast(nullptr) - : param.bias->data(); + : param.bias->data(); auto image_shape = InitImageDimInfoWith(output_dims); - auto* output_img = param.output->mutable_data( + auto* output_img = param.output->mutable_data( image_shape["width"], image_shape["height"]); STL::stringstream kernel_key; @@ -320,8 +336,8 @@ class DepthwiseConv2d3x3s1ComputeFP16Image } private: - std::string kernel_func_name_{"depth_conv_3x3s1"}; - std::string build_options_{"-DCL_DTYPE_float"}; + std::string kernel_func_name_{"depth_conv2d_3x3s1"}; + std::string build_options_{"-DCL_DTYPE_half"}; std::shared_ptr event_{new cl::Event}; }; @@ -345,24 +361,24 @@ REGISTER_LITE_KERNEL(depthwise_conv2d, REGISTER_LITE_KERNEL( depthwise_conv2d, kOpenCL, - kFloat, - kNHWC, + kFP16, + kImageDefault, paddle::lite::kernels::opencl::DepthwiseConv2dComputeFP16Image, image2d) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kOpenCL), - PRECISION(kFloat), - DATALAYOUT(kNHWC))}) + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kOpenCL), - PRECISION(kFloat), - DATALAYOUT(kNHWC))}) + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kOpenCL), - PRECISION(kFloat), - DATALAYOUT(kNHWC))}) + PRECISION(kFP16), + DATALAYOUT(kImageNW))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kOpenCL), - PRECISION(kFloat), - DATALAYOUT(kNHWC))}) + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) .Finalize(); diff --git a/lite/kernels/opencl/depthwise_conv2d_compute_test.cc b/lite/kernels/opencl/depthwise_conv2d_compute_test.cc index 6ab78d4f254fd593ffb421c8b4c25d256b5ba8bc..c52aa87a73c8f9cbd91851c96162cde817f299b4 100644 --- a/lite/kernels/opencl/depthwise_conv2d_compute_test.cc +++ b/lite/kernels/opencl/depthwise_conv2d_compute_test.cc @@ -90,7 +90,7 @@ void depth_conv(const T* input_data, } } -TEST(depthwise_conv2d, compute_buffer) { +TEST(depthwise_conv2d_buffer_fp32, compute) { LOG(INFO) << "to get kernel ..."; auto kernels = KernelRegistry::Global().Create("depthwise_conv2d", TARGET(kOpenCL), @@ -177,12 +177,12 @@ TEST(depthwise_conv2d, compute_buffer) { TargetWrapperCL::Unmap(input_data, mapped_input); } -TEST(depthwise_conv2d, compute_image2d) { +TEST(depthwise_conv2d_image2d_fp16, compute) { LOG(INFO) << "to get kernel ..."; auto kernels = KernelRegistry::Global().Create("depthwise_conv2d", TARGET(kOpenCL), - PRECISION(kFloat), - DATALAYOUT(kNHWC)); + PRECISION(kFP16), + DATALAYOUT(kImageDefault)); ASSERT_FALSE(kernels.empty()); auto kernel = std::move(kernels.front()); @@ -231,7 +231,7 @@ TEST(depthwise_conv2d, compute_image2d) { 4); // 4 : RGBA default_converter->NCHWToImage( input_v.data(), input_image_data.data(), input.dims()); - auto* input_image = input.mutable_data( + auto* input_image = input.mutable_data( input_image_shape[0], input_image_shape[1], input_image_data.data()); LOG(INFO) << "prepare kernel"; @@ -244,7 +244,7 @@ TEST(depthwise_conv2d, compute_image2d) { 4); // 4 : RGBA nw_converter->NCHWToImage( filter_v.data(), filter_image_data.data(), filter.dims()); - auto* filter_image = filter.mutable_data( + auto* filter_image = filter.mutable_data( filter_image_shape[0], filter_image_shape[1], filter_image_data.data()); LOG(INFO) << "launch"; @@ -253,13 +253,13 @@ TEST(depthwise_conv2d, compute_image2d) { default_converter->InitImageDimInfoWith(output.dims()); LOG(INFO) << "output_image_shape = " << output_image_shape[0] << " " << output_image_shape[1]; - auto* output_image = output.mutable_data( + auto* output_image = output.mutable_data( output_image_shape[0], output_image_shape[1]); kernel->Launch(); auto* wait_list = context->As().cl_wait_list(); - auto* out_ptr = param.output->data(); + auto* out_ptr = param.output->data(); auto it = wait_list->find(out_ptr); if (it != wait_list->end()) { VLOG(4) << "--- Find the sync event for the target cl tensor. ---"; @@ -308,4 +308,4 @@ TEST(depthwise_conv2d, compute_image2d) { } // namespace paddle USE_LITE_KERNEL(depthwise_conv2d, kOpenCL, kFloat, kNCHW, def); -USE_LITE_KERNEL(depthwise_conv2d, kOpenCL, kFloat, kNHWC, image2d); +USE_LITE_KERNEL(depthwise_conv2d, kOpenCL, kFP16, kImageDefault, image2d); diff --git a/lite/kernels/opencl/pool_compute.cc b/lite/kernels/opencl/pool_compute.cc index 8cdc127a375f1a537e64b54317562ce38a697c2a..fca2cbe96d56b65e5f33acacff20c781b3400ed0 100644 --- a/lite/kernels/opencl/pool_compute.cc +++ b/lite/kernels/opencl/pool_compute.cc @@ -31,6 +31,8 @@ class PoolCompute public: using param_t = operators::PoolParam; + std::string doc() const override { return "Pool using cl::Buffer, kFloat"; } + void PrepareForRun() override { const auto& param = *param_.get_mutable(); kernel_func_name_ += param.pooling_type; @@ -114,15 +116,18 @@ class PoolCompute private: std::string kernel_func_name_{"pool_"}; - std::string build_options_{"-DCL_DTYPE=float"}; + std::string build_options_{"-DCL_DTYPE_float"}; std::shared_ptr event_{new cl::Event}; }; -class PoolComputeImage2D - : public KernelLite { +class PoolComputeImage2D : public KernelLite { public: using param_t = operators::PoolParam; + std::string doc() const override { return "Pool using cl::Image2D, kFloat"; } + void PrepareForRun() override { const auto& param = *param_.get_mutable(); kernel_func_name_ += param.pooling_type; @@ -237,15 +242,15 @@ REGISTER_LITE_KERNEL(pool2d, REGISTER_LITE_KERNEL(pool2d, kOpenCL, kFloat, - kNHWC, + kImageDefault, paddle::lite::kernels::opencl::PoolComputeImage2D, image2d) .BindInput("X", {LiteType::GetTensorTy(TARGET(kOpenCL), PRECISION(kFloat), - DATALAYOUT(kNHWC))}) + DATALAYOUT(kImageDefault))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kOpenCL), PRECISION(kFloat), - DATALAYOUT(kNHWC))}) + DATALAYOUT(kImageDefault))}) .Finalize(); diff --git a/lite/kernels/opencl/pool_compute_test.cc b/lite/kernels/opencl/pool_compute_test.cc index 269c31d3bd15f7204f735efc4ef2a7e5bd8abd20..f97c758469ece8f2eaf59ebbb1b5065d71641616 100644 --- a/lite/kernels/opencl/pool_compute_test.cc +++ b/lite/kernels/opencl/pool_compute_test.cc @@ -73,15 +73,14 @@ void pool_avg(const int padding_height, } } -TEST(pool2d, compute_buffer) { +TEST(pool2d_buffer_fp32, compute) { LOG(INFO) << "to get kernel ..."; auto kernels = KernelRegistry::Global().Create( "pool2d", TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW)); ASSERT_FALSE(kernels.empty()); auto kernel = std::move(kernels.front()); - - LOG(INFO) << "get kernel"; + LOG(INFO) << "get kernel:" << kernel->doc(); lite::Tensor x, out; operators::PoolParam param; @@ -143,15 +142,15 @@ TEST(pool2d, compute_buffer) { TargetWrapperCL::Unmap(out_data, mapped_out); } -TEST(pool2d, compute_image2d) { +TEST(pool2d_image2d_fp32, compute) { LOG(INFO) << "to get kernel ..."; auto kernels = KernelRegistry::Global().Create( - "pool2d", TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNHWC)); + "pool2d", TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kImageDefault)); ASSERT_FALSE(kernels.empty()); auto kernel = std::move(kernels.front()); - LOG(INFO) << "get kernel"; + LOG(INFO) << "get kernel:" << kernel->doc(); lite::Tensor x, out; operators::PoolParam param; @@ -194,14 +193,14 @@ TEST(pool2d, compute_image2d) { default_converter->NCHWToImage(input_v.data(), x_image_data.data(), in_dim); auto* x_image = x.mutable_data( x_image_shape[0], x_image_shape[1], x_image_data.data()); - LOG(INFO) << "x_image" << x_image; + LOG(INFO) << "x_image:" << x_image; DDim out_image_shape = default_converter->InitImageDimInfoWith(out_dim); LOG(INFO) << "out_image_shape = " << out_image_shape[0] << " " << out_image_shape[1]; auto* out_image = out.mutable_data(out_image_shape[0], out_image_shape[1]); - LOG(INFO) << "out_image" << out_image; + LOG(INFO) << "out_image:" << out_image; kernel->Launch(); auto* wait_list = context->As().cl_wait_list(); @@ -241,4 +240,4 @@ TEST(pool2d, compute_image2d) { } // namespace paddle USE_LITE_KERNEL(pool2d, kOpenCL, kFloat, kNCHW, def); -USE_LITE_KERNEL(pool2d, kOpenCL, kFloat, kNHWC, image2d); +USE_LITE_KERNEL(pool2d, kOpenCL, kFloat, kImageDefault, image2d); diff --git a/lite/kernels/opencl/relu_compute.cc b/lite/kernels/opencl/relu_compute.cc index addf628bb0f0bd5c5031f2a8ee19ba167d7032ac..c5272fa14ac1af25ca44d611a59ed04016d771d0 100644 --- a/lite/kernels/opencl/relu_compute.cc +++ b/lite/kernels/opencl/relu_compute.cc @@ -29,7 +29,7 @@ class ReluCompute public: using param_t = operators::ActivationParam; - std::string doc() const override { return "Relu using cl::Buffer"; } + std::string doc() const override { return "Relu using cl::Buffer, kFloat"; } void PrepareForRun() override { auto& context = ctx_->As(); context.cl_context()->AddKernel( @@ -85,7 +85,7 @@ class ReluComputeFloatImageDefault using param_t = operators::ActivationParam; std::string doc() const override { - return "Relu using cl::Image2D(ImageDefault/RGBA)"; + return "Relu using cl::Image2D(ImageDefault/RGBA), kFloat"; } void PrepareForRun() override { @@ -146,6 +146,80 @@ class ReluComputeFloatImageDefault std::shared_ptr event_{new cl::Event}; }; +class ReluComputeFP16ImageDefault + : public KernelLite { + public: + using param_t = operators::ActivationParam; + + std::string doc() const override { + return "Relu using cl::Image2D(ImageDefault/RGBA), kFP16"; + } + + void PrepareForRun() override { + auto& context = ctx_->As(); + context.cl_context()->AddKernel( + kernel_func_name_, "image/relu_kernel.cl", build_options_); + } + + void Run() override { + auto& param = *param_.get_mutable(); + const auto& x_dims = param.X->dims(); + auto* x_buf = + param.X->data(); // use int16_t represents half float + auto image_shape = InitImageDimInfoWith(x_dims); + auto* out_buf = + param.Out->mutable_data( // use int16_t + // represents half float + image_shape["width"], + image_shape["height"]); + const auto& y_dims = param.Out->dims(); // useless: check dim only + + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + + int arg_idx = 0; + cl_int status = kernel.setArg(arg_idx, *x_buf); + CL_CHECK_FATAL(status); + status = kernel.setArg(++arg_idx, *out_buf); + CL_CHECK_FATAL(status); + + VLOG(4) << TargetToStr(param.X->target()); + VLOG(4) << TargetToStr(param.Out->target()); + VLOG(4) << "image_shape(w,h):" << image_shape["width"] << " " + << image_shape["height"]; + VLOG(4) << "x_dims[" << x_dims.size() << "D]:" << x_dims[0] << " " + << x_dims[1] << " " << x_dims[2] << " " << x_dims[3]; + VLOG(4) << "y_dims[" << y_dims.size() << "D]:" << y_dims[0] << " " + << y_dims[1] << " " << y_dims[2] << " " << y_dims[3]; + + auto global_work_size = + cl::NDRange{static_cast(image_shape["width"]), + static_cast(image_shape["height"])}; + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + // TODO(ysh329): io_copy(device->host) jammed if emplace to `cl_wait_list` + // context.cl_wait_list()->emplace(out_buf, event_); + context.cl_context()->GetCommandQueue().finish(); + } + + private: + std::string kernel_func_name_{"relu"}; + std::string build_options_{"-DCL_DTYPE_half -DRELU"}; + std::shared_ptr event_{new cl::Event}; +}; + } // namespace opencl } // namespace kernels } // namespace lite @@ -177,3 +251,19 @@ REGISTER_LITE_KERNEL( PRECISION(kFloat), DATALAYOUT(kImageDefault))}) .Finalize(); + +REGISTER_LITE_KERNEL(relu, + kOpenCL, + kFP16, + kImageDefault, + paddle::lite::kernels::opencl::ReluComputeFP16ImageDefault, + ImageDefault) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .Finalize(); diff --git a/lite/kernels/opencl/relu_compute_test.cc b/lite/kernels/opencl/relu_compute_test.cc index 45a60f81e11d2f1e3cdac3e58b950368d90aeb80..3745f3a8f7d8ab1d5e8f49d1c2b1ba8ff0c0a30d 100644 --- a/lite/kernels/opencl/relu_compute_test.cc +++ b/lite/kernels/opencl/relu_compute_test.cc @@ -93,7 +93,7 @@ TEST(opencl_relu_buffer, compute) { // #define LOOP_TEST // #define PRINT_RESULT -TEST(relu_image2d, compute) { +TEST(relu_image2d_fp32, compute) { LOG(INFO) << "main steps of test: host -> layout(buf2img) -> relu(img) -> " "layout(img2buf) " "-> host"; @@ -247,13 +247,170 @@ TEST(relu_image2d, compute) { #endif } +TEST(relu_image2d_fp16, compute) { + LOG(INFO) << "main steps of test: host -> layout(buf2img) -> relu(img) -> " + "layout(img2buf) " + "-> host"; + +#ifdef LOOP_TEST + for (int n = 1; n <= 100; n += 33) { + for (auto c : {1, 3}) { + for (int h = 12; h <= 100; h += 13) { + for (int w = 12; w <= 100; w += 25) { +#else + const int n = 1; + const int c = 2; + const int h = 3; + const int w = 4; +#endif // LOOP_TEST + + LOG(INFO) << "======== input shape[n,c,h,w]:" << n << " " << c << " " + << h << " " << w << " ========"; + // set layout kernels + auto buf_to_img_kernels = + KernelRegistry::Global().Create("layout", + TARGET(kOpenCL), + PRECISION(kAny), + DATALAYOUT(kImageDefault)); + auto img_to_buf_kernels = KernelRegistry::Global().Create( + "layout", TARGET(kOpenCL), PRECISION(kAny), DATALAYOUT(kNCHW)); + auto relu_img_kernels = + KernelRegistry::Global().Create("relu", + TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault)); + ASSERT_FALSE(buf_to_img_kernels.empty()); + ASSERT_FALSE(buf_to_img_kernels.empty()); + ASSERT_FALSE(relu_img_kernels.empty()); + + auto buf_to_img_kernel = std::move(buf_to_img_kernels.front()); + auto img_to_buf_kernel = std::move(img_to_buf_kernels.front()); + auto relu_img_kernel = std::move(relu_img_kernels.front()); + LOG(INFO) << "get 1st kernel: " << buf_to_img_kernel->doc(); + LOG(INFO) << "get 2nd kernel: " << img_to_buf_kernel->doc(); + LOG(INFO) << "get 3rd kernel: " << relu_img_kernel->doc(); + + // set tensors about op param + LOG(INFO) << "set tensors about op param"; + // layout(buf->img): x -> relu_in + // relu(img): relu_in -> relu_out + // layout(img->buf): relu_out -> y + lite::Tensor x, y, relu_in, relu_out, y_ref; + operators::LayoutParam BufferToImageParam; + operators::LayoutParam ImageToBufferParam; + BufferToImageParam.x = &x; + BufferToImageParam.y = &relu_in; + ImageToBufferParam.x = &relu_out; + ImageToBufferParam.y = &y; + operators::ActivationParam ReluParam; + ReluParam.X = &relu_in; + ReluParam.Out = &relu_out; + + const DDim x_dim = DDim(std::vector{n, c, h, w}); + x.Resize(x_dim); + y.Resize(x_dim); + relu_in.Resize(x_dim); + relu_out.Resize(x_dim); + y_ref.Resize(x_dim); + auto relu_image2d_shape = + paddle::lite::kernels::opencl::InitImageDimInfoWith(x_dim); + + // initialize tensors + LOG(INFO) << "initialize tensors"; + auto *x_data = x.mutable_data(TARGET(kOpenCL)); + auto *y_data = y.mutable_data(TARGET(kOpenCL)); + auto *y_data_ref = y_ref.mutable_data(TARGET(kARM)); + auto *mapped_x = static_cast(TargetWrapperCL::Map( + x_data, 0, sizeof(float) * x_dim.production())); + auto *mapped_y = static_cast(TargetWrapperCL::Map( + y_data, 0, sizeof(float) * x_dim.production())); + for (int i = 0; i < x_dim.production(); ++i) { + mapped_x[i] = static_cast(i) - x_dim.production() / 2; + mapped_y[i] = static_cast(0); + } + auto *relu_in_data = relu_in.mutable_data( + relu_image2d_shape["width"], relu_image2d_shape["height"]); + auto *relu_out_data = relu_out.mutable_data( + relu_image2d_shape["width"], relu_image2d_shape["height"]); + + // set context and kernel args + LOG(INFO) << "set context and kernel args"; + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + + buf_to_img_kernel->SetParam(BufferToImageParam); + std::unique_ptr buf_to_img_context(new KernelContext); + context->As().CopySharedTo( + &(buf_to_img_context->As())); + buf_to_img_kernel->SetContext(std::move(buf_to_img_context)); + + img_to_buf_kernel->SetParam(ImageToBufferParam); + std::unique_ptr img_to_buf_context(new KernelContext); + context->As().CopySharedTo( + &(img_to_buf_context->As())); + img_to_buf_kernel->SetContext(std::move(img_to_buf_context)); + + relu_img_kernel->SetParam(ReluParam); + std::unique_ptr relu_img_context(new KernelContext); + context->As().CopySharedTo( + &(relu_img_context->As())); + relu_img_kernel->SetContext(std::move(relu_img_context)); + + // run kernels + LOG(INFO) << "run kernel: buf_to_img_kernel"; + buf_to_img_kernel->Launch(); + LOG(INFO) << "run kernel: relu_img_kernel"; + relu_img_kernel->Launch(); + LOG(INFO) << "run kernel: img_to_buf_kernel"; + img_to_buf_kernel->Launch(); + + // compute ref cpu + relu_compute_ref(mapped_x, x_dim, y_data_ref); +// result +#ifdef PRINT_RESULT + LOG(INFO) << "---- print kernel result (input -> output) ----"; + for (int eidx = 0; eidx < x_dim.production(); ++eidx) { + std::cout << mapped_x[eidx] << " -> " << mapped_y[eidx] + << std::endl; + } +#endif // PRINT_RESULT + + // check result: compare kernel output and cpu output(y_data_ref) + for (int eidx = 0; eidx < x_dim.production(); eidx++) { + EXPECT_NEAR(y_data_ref[eidx], mapped_y[eidx], 1e-6); + if (abs(y_data_ref[eidx] - mapped_y[eidx]) > 1e-6) { + LOG(INFO) << "1st diff in this case at eidx[from 0]:" << eidx + << " / " << x_dim.production() << ", y_data_ref[" + << eidx << "]:" << y_data_ref[eidx] << ", mapped_y[" + << eidx << "]:" << mapped_y[eidx]; + break; + } + } + + // free + LOG(INFO) << "free: unmap x, y"; + TargetWrapperCL::Unmap(x_data, mapped_x); + TargetWrapperCL::Unmap(y_data, mapped_y); +#ifdef LOOP_TEST + } // w + } // h + } // c + } // n +#else +// nothing to do. +#endif +} + } // namespace lite } // namespace paddle // relu buffer // USE_LITE_KERNEL(relu, kOpenCL, kFloat, kNCHW, def); -// relu image2d +// relu image2d fp32 USE_LITE_KERNEL(layout, kOpenCL, kAny, kImageDefault, NCHW_to_ImageDefault); USE_LITE_KERNEL(layout, kOpenCL, kAny, kNCHW, ImageDefault_to_NCHW); USE_LITE_KERNEL(relu, kOpenCL, kFloat, kImageDefault, ImageDefault); + +// relu image2d fp16 +USE_LITE_KERNEL(relu, kOpenCL, kFP16, kImageDefault, ImageDefault);