From 44b09240435f6f48de3b23e63fa6df9c6a5341b3 Mon Sep 17 00:00:00 2001 From: ZhenWang Date: Tue, 25 Jun 2019 20:33:16 +0800 Subject: [PATCH] add pool2d opencl support. --- paddle/fluid/lite/api/paddle_use_kernels.h | 1 + .../fluid/lite/kernels/host/feed_compute.cc | 2 - .../fluid/lite/kernels/opencl/CMakeLists.txt | 9 +- .../opencl/elementwise_add_compute_test.cc | 12 +- .../fluid/lite/kernels/opencl/pool_compute.cc | 67 +++++ .../lite/kernels/opencl/pool_compute_test.cc | 123 +++++++++ paddle/fluid/lite/opencl/CMakeLists.txt | 4 +- paddle/fluid/lite/opencl/cl_caller.cc | 58 +++++ paddle/fluid/lite/opencl/cl_caller.h | 5 + paddle/fluid/lite/opencl/cl_image_converter.h | 41 +-- paddle/fluid/lite/opencl/cl_test.cc | 79 +++++- .../opencl/{cl_wrapper.cc => cl_wrapper.cxx} | 234 +++++++++--------- 12 files changed, 487 insertions(+), 148 deletions(-) create mode 100644 paddle/fluid/lite/kernels/opencl/pool_compute.cc create mode 100644 paddle/fluid/lite/kernels/opencl/pool_compute_test.cc rename paddle/fluid/lite/opencl/{cl_wrapper.cc => cl_wrapper.cxx} (85%) diff --git a/paddle/fluid/lite/api/paddle_use_kernels.h b/paddle/fluid/lite/api/paddle_use_kernels.h index 797acd7aa90..b5a727d53f0 100644 --- a/paddle/fluid/lite/api/paddle_use_kernels.h +++ b/paddle/fluid/lite/api/paddle_use_kernels.h @@ -64,4 +64,5 @@ USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host); #ifdef LITE_WITH_OPENCL USE_LITE_KERNEL(elementwise_add, kOpenCL, kFloat, kNCHW, def); +USE_LITE_KERNEL(pool2d, kOpenCL, kFloat, kNCHW, def); #endif diff --git a/paddle/fluid/lite/kernels/host/feed_compute.cc b/paddle/fluid/lite/kernels/host/feed_compute.cc index 7bbd648c20d..1c944e5e02d 100644 --- a/paddle/fluid/lite/kernels/host/feed_compute.cc +++ b/paddle/fluid/lite/kernels/host/feed_compute.cc @@ -31,8 +31,6 @@ class FeedCompute VLOG(4) << "col " << param.col; const lite::Tensor &feed_item = (*param.feed_list)[0]; param.out->ShareDataWith(feed_item); - VLOG(4) << "FEED input " << feed_item << " col " << param.col; - VLOG(4) << "FEED output " << *param.out; } }; diff --git a/paddle/fluid/lite/kernels/opencl/CMakeLists.txt b/paddle/fluid/lite/kernels/opencl/CMakeLists.txt index 179882628af..8cdaf3389d7 100644 --- a/paddle/fluid/lite/kernels/opencl/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/opencl/CMakeLists.txt @@ -5,12 +5,19 @@ endif() set(cl_kernel_deps op_params_lite cl_caller cl_engine cl_context cl_wrapper) cc_library(elementwise_add_opencl SRCS elementwise_add_compute.cc DEPS ${cl_kernel_deps}) +cc_library(pool_opencl SRCS pool_compute.cc DEPS ${cl_kernel_deps}) lite_cc_test(test_elementwise_add_opencl SRCS elementwise_add_compute_test.cc DEPS elementwise_add_opencl op_registry_lite program_lite context_lite ) +lite_cc_test(test_pool_opencl SRCS pool_compute_test.cc DEPS pool_opencl + op_registry_lite program_lite + context_lite + ) + set(opencl_kernels elementwise_add_opencl - CACHE INTERNAL "") + pool_opencl + CACHE INTERNAL "opencl_kernels") diff --git a/paddle/fluid/lite/kernels/opencl/elementwise_add_compute_test.cc b/paddle/fluid/lite/kernels/opencl/elementwise_add_compute_test.cc index f82d8477d55..1040c8bd547 100644 --- a/paddle/fluid/lite/kernels/opencl/elementwise_add_compute_test.cc +++ b/paddle/fluid/lite/kernels/opencl/elementwise_add_compute_test.cc @@ -40,23 +40,23 @@ TEST(elementwise_add, init) { kernel->SetParam(param); kernel->SetContext(std::move(context)); - X.Resize({1, 1, 1, 10}); - Y.Resize({1, 1, 1, 10}); - Out.Resize({1, 1, 1, 10}); + X.Resize({4, 3, 10, 10}); + Y.Resize({4, 3, 10, 10}); + 
Out.Resize({4, 3, 10, 10});
 
   auto* x_data = X.mutable_data<float>();
   auto* y_data = Y.mutable_data<float>();
   auto* out_data = Out.mutable_data<float>();
 
-  for (int i = 0; i < 10; i++) {
+  for (int i = 0; i < 4 * 3 * 10 * 10; i++) {
     x_data[i] = 1.1 * i;
     y_data[i] = 2.3 * i;
   }
 
   kernel->Launch();
 
-  for (int i = 0; i < 10; i++) {
-    EXPECT_NEAR(out_data[i], 3.4 * i, 1e-6);
+  for (int i = 0; i < 4 * 3 * 10 * 10; i++) {
+    EXPECT_NEAR(out_data[i], static_cast<float>(3.4 * i), 1e-6);
   }
 }
 
diff --git a/paddle/fluid/lite/kernels/opencl/pool_compute.cc b/paddle/fluid/lite/kernels/opencl/pool_compute.cc
new file mode 100644
index 00000000000..2514e8d929b
--- /dev/null
+++ b/paddle/fluid/lite/kernels/opencl/pool_compute.cc
@@ -0,0 +1,67 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include "paddle/fluid/lite/core/kernel.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+#include "paddle/fluid/lite/operators/op_params.h"
+// NOTE ugly here, hide these.
+#include "paddle/fluid/lite/opencl/cl_caller.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace opencl {
+
+class PoolCompute
+    : public KernelLite<TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
+ public:
+  using param_t = operators::PoolParam;
+
+  void Run() override {
+    auto& param = *param_.get_mutable<param_t>();
+    auto& in_dims = param.x->dims();
+    auto& out_dims = param.output->dims();
+    const std::string pooling_type = param.pooling_type;
+    bool global_pooling = param.global_pooling;
+    std::vector<int>& paddings = param.paddings;
+    std::vector<int>& strides = param.strides;
+    std::vector<int>& ksize = param.ksize;
+    if (global_pooling) {
+      for (size_t i = 0; i < ksize.size(); ++i) {
+        paddings[i] = 0;
+        ksize[i] = static_cast<int>(in_dims[i + 2]);
+      }
+    }
+
+    auto& context = ctx_->As<OpenCLContext>();
+    CHECK(context.cl_helper() != nullptr);
+
+    pool(context.cl_helper(), pooling_type, paddings[0], paddings[1],
+         strides[0], strides[1], ksize[0], ksize[1],
+         static_cast<const float*>(param.x->raw_data()), in_dims,
+         param.output->mutable_data<float>(), out_dims);
+  }
+};
+
+}  // namespace opencl
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(pool2d, kOpenCL, kFloat, kNCHW,
+                     paddle::lite::kernels::opencl::PoolCompute, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
+    .Finalize();
diff --git a/paddle/fluid/lite/kernels/opencl/pool_compute_test.cc b/paddle/fluid/lite/kernels/opencl/pool_compute_test.cc
new file mode 100644
index 00000000000..fde3caae84e
--- /dev/null
+++ b/paddle/fluid/lite/kernels/opencl/pool_compute_test.cc
@@ -0,0 +1,123 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include <random>
+#include "paddle/fluid/lite/core/compatible_tensor.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+
+void pool_avg(const int padding_height, const int padding_width,
+              const int stride_height, const int stride_width,
+              const int ksize_height, const int ksize_width,
+              const float* input_data, const DDim& in_dim, float* output_data,
+              const DDim& out_dim) {
+  const int batch_size = in_dim[0];
+  const int input_height = in_dim[2];
+  const int input_width = in_dim[3];
+  const int output_channels = out_dim[1];
+  const int output_height = out_dim[2];
+  const int output_width = out_dim[3];
+
+  const size_t input_spatial_size = input_height * input_width;
+  const size_t output_spatial_size = output_height * output_width;
+
+  for (int i = 0; i < batch_size; i++) {
+    for (int c = 0; c < output_channels; ++c) {
+      int channel = i * output_channels + c;
+      const float* input_ptr = input_data + channel * input_spatial_size;
+      float* output_ptr = output_data + channel * output_spatial_size;
+
+      for (int ph = 0; ph < output_height; ++ph) {
+        int hstart = ph * stride_height - padding_height;
+        int hend = std::min(hstart + ksize_height, input_height);
+        hstart = std::max(hstart, 0);
+        for (int pw = 0; pw < output_width; ++pw) {
+          int wstart = pw * stride_width - padding_width;
+          int wend = std::min(wstart + ksize_width, input_width);
+          wstart = std::max(wstart, 0);
+
+          float val = 0.f;
+          int count = 0;
+          for (int h = hstart; h < hend; ++h) {
+            for (int w = wstart; w < wend; ++w) {
+              val += input_ptr[h * input_width + w];
+              ++count;
+            }
+          }
+          output_ptr[ph * output_width + pw] =
+              (count > 0) ? 
val * (1.f / count) : 0.f;
+        }
+      }
+    }
+  }
+}
+
+TEST(pool2d, init) {
+  LOG(INFO) << "to get kernel ...";
+  auto kernels = KernelRegistry::Global().Create(
+      "pool2d", TARGET(kOpenCL), PRECISION(kFloat), DATALAYOUT(kNCHW));
+  ASSERT_FALSE(kernels.empty());
+
+  auto kernel = std::move(kernels.front());
+
+  LOG(INFO) << "get kernel";
+
+  lite::Tensor x, out;
+  operators::PoolParam param;
+  param.x = &x;
+  param.output = &out;
+  param.global_pooling = true;
+  param.pooling_type = "avg";
+  param.paddings = std::vector<int>{0, 0};
+  param.strides = std::vector<int>{1, 1};
+  param.ksize = std::vector<int>{7, 7};
+
+  std::unique_ptr<KernelContext> context(new KernelContext);
+  context->As<OpenCLContext>().InitOnce();
+
+  kernel->SetParam(param);
+  kernel->SetContext(std::move(context));
+
+  const DDim in_dim = DDim(std::vector<int64_t>{4, 1024, 7, 7});
+  const DDim out_dim = DDim(std::vector<int64_t>{4, 1024, 1, 1});
+  x.Resize(in_dim);
+  out.Resize(out_dim);
+
+  auto* x_data = x.mutable_data<float>();
+  auto* out_data = out.mutable_data<float>();
+
+  std::default_random_engine engine;
+  std::uniform_real_distribution<float> dist(-5, 5);
+
+  for (int i = 0; i < 4 * 1024 * 7 * 7; i++) {
+    x_data[i] = dist(engine);
+  }
+
+  kernel->Launch();
+
+  std::unique_ptr<float[]> out_ref(new float[4 * 1024 * 1 * 1]);
+  pool_avg(0, 0, 1, 1, 7, 7, x_data, in_dim, out_ref.get(), out_dim);
+
+  for (int i = 0; i < 4 * 1024 * 1 * 1; i++) {
+    EXPECT_NEAR(out_data[i], out_ref[i], 1e-6);
+  }
+}
+
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_KERNEL(pool2d, kOpenCL, kFloat, kNCHW, def);
diff --git a/paddle/fluid/lite/opencl/CMakeLists.txt b/paddle/fluid/lite/opencl/CMakeLists.txt
index 2017346f751..2d77ff27681 100644
--- a/paddle/fluid/lite/opencl/CMakeLists.txt
+++ b/paddle/fluid/lite/opencl/CMakeLists.txt
@@ -2,8 +2,8 @@ if (NOT LITE_WITH_OPENCL)
   return()
 endif()
 
-cc_library(cl_wrapper SRCS cl_wrapper.cc)
-cc_library(cl_tool SRCS cl_tool.cc)
+cc_library(cl_wrapper SRCS cl_wrapper.cxx)
+cc_library(cl_tool SRCS cl_tool.cc DEPS cl_wrapper)
 target_compile_options(cl_tool BEFORE PUBLIC -Wno-ignored-qualifiers)
 cc_library(cl_engine SRCS cl_engine.cc DEPS cl_tool)
 cc_library(cl_context SRCS cl_context.cc DEPS cl_engine)
diff --git a/paddle/fluid/lite/opencl/cl_caller.cc b/paddle/fluid/lite/opencl/cl_caller.cc
index a56540feb72..50394df3883 100644
--- a/paddle/fluid/lite/opencl/cl_caller.cc
+++ b/paddle/fluid/lite/opencl/cl_caller.cc
@@ -19,6 +19,7 @@ limitations under the License. 
*/
 #include "paddle/fluid/lite/opencl/cl_helper.h"
 #include "paddle/fluid/lite/opencl/cl_image.h"
 #include "paddle/fluid/lite/opencl/cl_tool.h"
+#include "paddle/fluid/lite/utils/string.h"
 
 namespace paddle {
 namespace lite {
@@ -94,5 +95,62 @@ void elementwise_add(CLHelper* helper, const float* in, const DDim& in_dim,
   CopyImageData(helper, out_image, out);
 }
 
+void pool(CLHelper* helper, const std::string pooling_type, const int pad_h,
+          const int pad_w, const int stride_h, const int stride_w,
+          const int ksize_h, const int ksize_w, const float* in,
+          const DDim& in_dim, float* out, const DDim& out_dim) {
+  auto kernel =
+      helper->GetKernel(string_format("pool_%s", pooling_type.c_str()));
+  CLImage in_image;
+  in_image.set_tensor_data(in, in_dim);
+  in_image.InitNormalCLImage(helper->OpenCLContext());
+  VLOG(3) << " --- Input image: " << in_image << " --- ";
+  CLImage out_image;
+  out_image.InitEmptyImage(helper->OpenCLContext(), out_dim);
+  auto global_work_size = helper->DefaultWorkSize(out_image);
+  auto* in_converter =
+      dynamic_cast<CLImageConverterNormal*>(in_image.image_converter());
+  auto* out_converter =
+      dynamic_cast<CLImageConverterNormal*>(out_image.image_converter());
+  const int in_height = in_converter->HeightOfOneBlock();
+  const int in_width = in_converter->WidthOfOneBlock();
+  const int out_height = out_converter->HeightOfOneBlock();
+  const int out_width = out_converter->WidthOfOneBlock();
+  cl_int status;
+  status = kernel.setArg(0, in_height);
+  CL_CHECK_ERRORS(status);
+  status = kernel.setArg(1, in_width);
+  CL_CHECK_ERRORS(status);
+  status = kernel.setArg(2, out_height);
+  CL_CHECK_ERRORS(status);
+  status = kernel.setArg(3, out_width);
+  CL_CHECK_ERRORS(status);
+  status = kernel.setArg(4, pad_h);
+  CL_CHECK_ERRORS(status);
+  status = kernel.setArg(5, pad_w);
+  CL_CHECK_ERRORS(status);
+  status = kernel.setArg(6, stride_h);
+  CL_CHECK_ERRORS(status);
+  status = kernel.setArg(7, stride_w);
+  CL_CHECK_ERRORS(status);
+  status = kernel.setArg(8, ksize_h);
+  CL_CHECK_ERRORS(status);
+  status = kernel.setArg(9, ksize_w);
+  CL_CHECK_ERRORS(status);
+  status = kernel.setArg(10, *in_image.cl_image());
+  CL_CHECK_ERRORS(status);
+  status = kernel.setArg(11, *out_image.cl_image());
+  CL_CHECK_ERRORS(status);
+
+  status = helper->OpenCLCommandQueue().enqueueNDRangeKernel(
+      kernel, cl::NullRange, global_work_size, cl::NullRange, nullptr, nullptr);
+  CL_CHECK_ERRORS(status);
+
+  status = helper->OpenCLCommandQueue().finish();
+  CL_CHECK_ERRORS(status);
+  VLOG(3) << " --- Out image: " << out_image << " --- ";
+  CopyImageData(helper, out_image, out);
+}
+
 }  // namespace lite
 }  // namespace paddle
diff --git a/paddle/fluid/lite/opencl/cl_caller.h b/paddle/fluid/lite/opencl/cl_caller.h
index a55724b5cef..0d53574e17a 100644
--- a/paddle/fluid/lite/opencl/cl_caller.h
+++ b/paddle/fluid/lite/opencl/cl_caller.h
@@ -31,5 +31,10 @@ void elementwise_add(CLHelper* helper, const float* in, const DDim& in_dim,
                      const float* bias, const DDim& bias_dim, float* out,
                      const DDim& out_dim);
 
+void pool(CLHelper* helper, const std::string pooling_type, const int pad_h,
+          const int pad_w, const int stride_h, const int stride_w,
+          const int ksize_h, const int ksize_w, const float* in,
+          const DDim& in_dim, float* out, const DDim& out_dim);
+
 }  // namespace lite
 }  // namespace paddle
diff --git a/paddle/fluid/lite/opencl/cl_image_converter.h b/paddle/fluid/lite/opencl/cl_image_converter.h
index 874f292e0f4..1eab93f61be 100644
--- a/paddle/fluid/lite/opencl/cl_image_converter.h
+++ b/paddle/fluid/lite/opencl/cl_image_converter.h
@@ -33,18 +33,19 @@ 
class CLImageConverterBase { class CLImageConverterDefault : public CLImageConverterBase { public: - DDim InitImageDimInfoWith(const DDim &tensor_dim); - void NCHWToImage(float *nchw, float *image, const DDim &tensor_dim); + DDim InitImageDimInfoWith(const DDim &tensor_dim) override; + void NCHWToImage(float *nchw, float *image, const DDim &tensor_dim) override; void ImageToNCHW(float *image, float *tensor, const DDim &image_dim, - const DDim &tensor_dim); + const DDim &tensor_dim) override; }; class CLImageConverterFolder : public CLImageConverterBase { public: - DDim InitImageDimInfoWith(const DDim &tensor_dim); - void NCHWToImage(float *tensor, float *image, const DDim &tensor_dim); + DDim InitImageDimInfoWith(const DDim &tensor_dim) override; + void NCHWToImage(float *tensor, float *image, + const DDim &tensor_dim) override; void ImageToNCHW(float *image, float *tensor, const DDim &image_dim, - const DDim &tensor_dim); + const DDim &tensor_dim) override; /* * width of original tensor @@ -66,10 +67,11 @@ class CLImageConverterFolder : public CLImageConverterBase { class CLImageConverterNormal : public CLImageConverterBase { public: - DDim InitImageDimInfoWith(const DDim &tensor_dim); - void NCHWToImage(float *tensor, float *image, const DDim &tensor_dim); + DDim InitImageDimInfoWith(const DDim &tensor_dim) override; + void NCHWToImage(float *tensor, float *image, + const DDim &tensor_dim) override; void ImageToNCHW(float *image, float *tensor, const DDim &image_dim, - const DDim &tensor_dim); + const DDim &tensor_dim) override; /* * width of original tensor @@ -90,24 +92,27 @@ class CLImageConverterNormal : public CLImageConverterBase { }; class CLImageConverterNWBlock : public CLImageConverterBase { - DDim InitImageDimInfoWith(const DDim &tensor_dim); - void NCHWToImage(float *tensor, float *image, const DDim &tensor_dim); + DDim InitImageDimInfoWith(const DDim &tensor_dim) override; + void NCHWToImage(float *tensor, float *image, + const DDim &tensor_dim) override; void ImageToNCHW(float *image, float *tensor, const DDim &image_dim, - const DDim &tensor_dim); + const DDim &tensor_dim) override; }; class CLImageConverterDWBlock : public CLImageConverterBase { - DDim InitImageDimInfoWith(const DDim &tensor_dim); - void NCHWToImage(float *tensor, float *image, const DDim &tensor_dim); + DDim InitImageDimInfoWith(const DDim &tensor_dim) override; + void NCHWToImage(float *tensor, float *image, + const DDim &tensor_dim) override; void ImageToNCHW(float *image, float *tensor, const DDim &image_dim, - const DDim &tensor_dim); + const DDim &tensor_dim) override; }; class CLImageConverterWinoTransWeight : public CLImageConverterBase { public: - DDim InitImageDimInfoWith(const DDim &tensor_dim); - void NCHWToImage(float *tensor, float *image, const DDim &tensor_dim); + DDim InitImageDimInfoWith(const DDim &tensor_dim) override; + void NCHWToImage(float *tensor, float *image, + const DDim &tensor_dim) override; void ImageToNCHW(float *image, float *tensor, const DDim &image_dim, - const DDim &tensor_dim); + const DDim &tensor_dim) override; }; } // namespace lite diff --git a/paddle/fluid/lite/opencl/cl_test.cc b/paddle/fluid/lite/opencl/cl_test.cc index 57192b79d72..ea02b0c4fed 100644 --- a/paddle/fluid/lite/opencl/cl_test.cc +++ b/paddle/fluid/lite/opencl/cl_test.cc @@ -160,12 +160,11 @@ TEST(cl_test, channel_add_test) { for (int i = 0; i < 4 * 16 * 256 * 512; i += stride) { std::cout << out[i] << " "; } + std::cout << std::endl; for (int i = 0; i < 4 * 16 * 256 * 512; i++) { 
EXPECT_NEAR(out[i], out_ref[i], 1e-6);
   }
-
-  std::cout << std::endl;
 }
 
 TEST(cl_test, elementwise_add_test) {
@@ -205,12 +204,86 @@ TEST(cl_test, elementwise_add_test) {
   for (int i = 0; i < 4 * 16 * 256 * 512; i += stride) {
     std::cout << out[i] << " ";
   }
+  std::cout << std::endl;
 
   for (int i = 0; i < 4 * 16 * 256 * 512; i++) {
     EXPECT_NEAR(out[i], out_ref[i], 1e-6);
   }
+}
 
-  std::cout << std::endl;
+void pool_avg(const int padding_height, const int padding_width,
+              const int stride_height, const int stride_width,
+              const int ksize_height, const int ksize_width,
+              const float* input_data, const DDim& in_dim, float* output_data,
+              const DDim& out_dim) {
+  const int batch_size = in_dim[0];
+  const int input_height = in_dim[2];
+  const int input_width = in_dim[3];
+  const int output_channels = out_dim[1];
+  const int output_height = out_dim[2];
+  const int output_width = out_dim[3];
+
+  const size_t input_spatial_size = input_height * input_width;
+  const size_t output_spatial_size = output_height * output_width;
+
+  for (int i = 0; i < batch_size; i++) {
+    for (int c = 0; c < output_channels; ++c) {
+      int channel = i * output_channels + c;
+      const float* input_ptr = input_data + channel * input_spatial_size;
+      float* output_ptr = output_data + channel * output_spatial_size;
+
+      for (int ph = 0; ph < output_height; ++ph) {
+        int hstart = ph * stride_height - padding_height;
+        int hend = std::min(hstart + ksize_height, input_height);
+        hstart = std::max(hstart, 0);
+        for (int pw = 0; pw < output_width; ++pw) {
+          int wstart = pw * stride_width - padding_width;
+          int wend = std::min(wstart + ksize_width, input_width);
+          wstart = std::max(wstart, 0);
+
+          float val = 0.f;
+          int count = 0;
+          for (int h = hstart; h < hend; ++h) {
+            for (int w = wstart; w < wend; ++w) {
+              val += input_ptr[h * input_width + w];
+              ++count;
+            }
+          }
+          output_ptr[ph * output_width + pw] =
+              (count > 0) ? val * (1.f / count) : 0.f;
+        }
+      }
+    }
+  }
+}
+
+TEST(cl_test, pool_test) {
+  std::default_random_engine engine;
+  std::uniform_real_distribution<float> dist(-5, 5);
+
+  const DDim in_dim = DDim(std::vector<int64_t>{4, 1024, 7, 7});
+  std::unique_ptr<float[]> in_data(new float[4 * 1024 * 7 * 7]);
+  for (int i = 0; i < 4 * 1024 * 7 * 7; i++) {
+    in_data[i] = dist(engine);
+  }
+
+  const DDim out_dim = DDim(std::vector<int64_t>{4, 1024, 1, 1});
+  std::unique_ptr<float[]> out(new float[4 * 1024 * 1 * 1]);
+  std::unique_ptr<float[]> out_ref(new float[4 * 1024 * 1 * 1]);
+
+  bool status = InitOpenCLEngine(FLAGS_cl_path);
+  CHECK(status) << "Fail to initialize OpenCL engine.";
+  std::unique_ptr<CLContext> context(new CLContext);
+  std::unique_ptr<CLHelper> helper(new CLHelper(context.get()));
+  helper->AddKernel("pool_max", "pool_kernel.cl");
+  helper->AddKernel("pool_avg", "pool_kernel.cl");
+  pool(helper.get(), "avg", 0, 0, 1, 1, 7, 7, in_data.get(), in_dim, out.get(),
+       out_dim);
+  pool_avg(0, 0, 1, 1, 7, 7, in_data.get(), in_dim, out_ref.get(), out_dim);
+
+  for (int i = 0; i < 4 * 1024 * 1 * 1; i++) {
+    EXPECT_NEAR(out[i], out_ref[i], 1e-6);
+  }
 }
 
 }  // namespace lite
diff --git a/paddle/fluid/lite/opencl/cl_wrapper.cc b/paddle/fluid/lite/opencl/cl_wrapper.cxx
similarity index 85%
rename from paddle/fluid/lite/opencl/cl_wrapper.cc
rename to paddle/fluid/lite/opencl/cl_wrapper.cxx
index 52c68bdc969..4979d03e504 100644
--- a/paddle/fluid/lite/opencl/cl_wrapper.cc
+++ b/paddle/fluid/lite/opencl/cl_wrapper.cxx
@@ -1,16 +1,18 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright 2018 The MACE Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// This file is borrowed from MACE, and we will refactor it +// in the near future. #include #include @@ -157,58 +159,58 @@ class OpenCLLibrary final { using clGetImageInfoFunc = cl_int (*)(cl_mem, cl_image_info, size_t, void *, size_t *); -#define PADDLE_CL_DEFINE_FUNC_PTR(func) func##Func func = nullptr - - PADDLE_CL_DEFINE_FUNC_PTR(clGetPlatformIDs); - PADDLE_CL_DEFINE_FUNC_PTR(clGetPlatformInfo); - PADDLE_CL_DEFINE_FUNC_PTR(clBuildProgram); - PADDLE_CL_DEFINE_FUNC_PTR(clEnqueueNDRangeKernel); - PADDLE_CL_DEFINE_FUNC_PTR(clSetKernelArg); - PADDLE_CL_DEFINE_FUNC_PTR(clReleaseKernel); - PADDLE_CL_DEFINE_FUNC_PTR(clCreateProgramWithSource); - PADDLE_CL_DEFINE_FUNC_PTR(clCreateBuffer); - PADDLE_CL_DEFINE_FUNC_PTR(clCreateImage); - PADDLE_CL_DEFINE_FUNC_PTR(clCreateImage2D); - PADDLE_CL_DEFINE_FUNC_PTR(clCreateUserEvent); - PADDLE_CL_DEFINE_FUNC_PTR(clRetainKernel); - PADDLE_CL_DEFINE_FUNC_PTR(clCreateKernel); - PADDLE_CL_DEFINE_FUNC_PTR(clGetProgramInfo); - PADDLE_CL_DEFINE_FUNC_PTR(clFlush); - PADDLE_CL_DEFINE_FUNC_PTR(clFinish); - PADDLE_CL_DEFINE_FUNC_PTR(clReleaseProgram); - PADDLE_CL_DEFINE_FUNC_PTR(clRetainContext); - PADDLE_CL_DEFINE_FUNC_PTR(clGetContextInfo); - PADDLE_CL_DEFINE_FUNC_PTR(clCreateProgramWithBinary); - PADDLE_CL_DEFINE_FUNC_PTR(clCreateCommandQueue); - PADDLE_CL_DEFINE_FUNC_PTR(clCreateCommandQueueWithProperties); - PADDLE_CL_DEFINE_FUNC_PTR(clReleaseCommandQueue); - PADDLE_CL_DEFINE_FUNC_PTR(clEnqueueMapBuffer); - PADDLE_CL_DEFINE_FUNC_PTR(clEnqueueMapImage); - PADDLE_CL_DEFINE_FUNC_PTR(clRetainProgram); - PADDLE_CL_DEFINE_FUNC_PTR(clGetProgramBuildInfo); - PADDLE_CL_DEFINE_FUNC_PTR(clEnqueueReadBuffer); - PADDLE_CL_DEFINE_FUNC_PTR(clEnqueueReadImage); - PADDLE_CL_DEFINE_FUNC_PTR(clEnqueueWriteBuffer); - PADDLE_CL_DEFINE_FUNC_PTR(clWaitForEvents); - PADDLE_CL_DEFINE_FUNC_PTR(clReleaseEvent); - PADDLE_CL_DEFINE_FUNC_PTR(clCreateContext); - PADDLE_CL_DEFINE_FUNC_PTR(clCreateContextFromType); - PADDLE_CL_DEFINE_FUNC_PTR(clReleaseContext); - PADDLE_CL_DEFINE_FUNC_PTR(clRetainCommandQueue); - PADDLE_CL_DEFINE_FUNC_PTR(clEnqueueUnmapMemObject); - PADDLE_CL_DEFINE_FUNC_PTR(clRetainMemObject); - PADDLE_CL_DEFINE_FUNC_PTR(clReleaseMemObject); - PADDLE_CL_DEFINE_FUNC_PTR(clGetDeviceInfo); - PADDLE_CL_DEFINE_FUNC_PTR(clGetDeviceIDs); - PADDLE_CL_DEFINE_FUNC_PTR(clRetainDevice); - PADDLE_CL_DEFINE_FUNC_PTR(clReleaseDevice); - 
PADDLE_CL_DEFINE_FUNC_PTR(clRetainEvent); - PADDLE_CL_DEFINE_FUNC_PTR(clGetKernelWorkGroupInfo); - PADDLE_CL_DEFINE_FUNC_PTR(clGetEventInfo); - PADDLE_CL_DEFINE_FUNC_PTR(clGetEventProfilingInfo); - PADDLE_CL_DEFINE_FUNC_PTR(clGetImageInfo); - -#undef PADDLE_CL_DEFINE_FUNC_PTR +#define CL_DEFINE_FUNC_PTR(func) func##Func func = nullptr + + CL_DEFINE_FUNC_PTR(clGetPlatformIDs); + CL_DEFINE_FUNC_PTR(clGetPlatformInfo); + CL_DEFINE_FUNC_PTR(clBuildProgram); + CL_DEFINE_FUNC_PTR(clEnqueueNDRangeKernel); + CL_DEFINE_FUNC_PTR(clSetKernelArg); + CL_DEFINE_FUNC_PTR(clReleaseKernel); + CL_DEFINE_FUNC_PTR(clCreateProgramWithSource); + CL_DEFINE_FUNC_PTR(clCreateBuffer); + CL_DEFINE_FUNC_PTR(clCreateImage); + CL_DEFINE_FUNC_PTR(clCreateImage2D); + CL_DEFINE_FUNC_PTR(clCreateUserEvent); + CL_DEFINE_FUNC_PTR(clRetainKernel); + CL_DEFINE_FUNC_PTR(clCreateKernel); + CL_DEFINE_FUNC_PTR(clGetProgramInfo); + CL_DEFINE_FUNC_PTR(clFlush); + CL_DEFINE_FUNC_PTR(clFinish); + CL_DEFINE_FUNC_PTR(clReleaseProgram); + CL_DEFINE_FUNC_PTR(clRetainContext); + CL_DEFINE_FUNC_PTR(clGetContextInfo); + CL_DEFINE_FUNC_PTR(clCreateProgramWithBinary); + CL_DEFINE_FUNC_PTR(clCreateCommandQueue); + CL_DEFINE_FUNC_PTR(clCreateCommandQueueWithProperties); + CL_DEFINE_FUNC_PTR(clReleaseCommandQueue); + CL_DEFINE_FUNC_PTR(clEnqueueMapBuffer); + CL_DEFINE_FUNC_PTR(clEnqueueMapImage); + CL_DEFINE_FUNC_PTR(clRetainProgram); + CL_DEFINE_FUNC_PTR(clGetProgramBuildInfo); + CL_DEFINE_FUNC_PTR(clEnqueueReadBuffer); + CL_DEFINE_FUNC_PTR(clEnqueueReadImage); + CL_DEFINE_FUNC_PTR(clEnqueueWriteBuffer); + CL_DEFINE_FUNC_PTR(clWaitForEvents); + CL_DEFINE_FUNC_PTR(clReleaseEvent); + CL_DEFINE_FUNC_PTR(clCreateContext); + CL_DEFINE_FUNC_PTR(clCreateContextFromType); + CL_DEFINE_FUNC_PTR(clReleaseContext); + CL_DEFINE_FUNC_PTR(clRetainCommandQueue); + CL_DEFINE_FUNC_PTR(clEnqueueUnmapMemObject); + CL_DEFINE_FUNC_PTR(clRetainMemObject); + CL_DEFINE_FUNC_PTR(clReleaseMemObject); + CL_DEFINE_FUNC_PTR(clGetDeviceInfo); + CL_DEFINE_FUNC_PTR(clGetDeviceIDs); + CL_DEFINE_FUNC_PTR(clRetainDevice); + CL_DEFINE_FUNC_PTR(clReleaseDevice); + CL_DEFINE_FUNC_PTR(clRetainEvent); + CL_DEFINE_FUNC_PTR(clGetKernelWorkGroupInfo); + CL_DEFINE_FUNC_PTR(clGetEventInfo); + CL_DEFINE_FUNC_PTR(clGetEventProfilingInfo); + CL_DEFINE_FUNC_PTR(clGetImageInfo); + +#undef CL_DEFINE_FUNC_PTR private: void *handle_ = nullptr; @@ -285,7 +287,7 @@ void *OpenCLLibrary::LoadFromPath(const std::string &path) { return nullptr; } -#define PADDLE_CL_ASSIGN_FROM_DLSYM(func) \ +#define CL_ASSIGN_FROM_DLSYM(func) \ do { \ void *ptr = dlsym(handle, #func); \ if (ptr == nullptr) { \ @@ -296,56 +298,56 @@ void *OpenCLLibrary::LoadFromPath(const std::string &path) { VLOG(3) << "Loaded " << #func << " from " << path; \ } while (false) - PADDLE_CL_ASSIGN_FROM_DLSYM(clGetPlatformIDs); - PADDLE_CL_ASSIGN_FROM_DLSYM(clGetPlatformInfo); - PADDLE_CL_ASSIGN_FROM_DLSYM(clBuildProgram); - PADDLE_CL_ASSIGN_FROM_DLSYM(clEnqueueNDRangeKernel); - PADDLE_CL_ASSIGN_FROM_DLSYM(clSetKernelArg); - PADDLE_CL_ASSIGN_FROM_DLSYM(clReleaseKernel); - PADDLE_CL_ASSIGN_FROM_DLSYM(clCreateProgramWithSource); - PADDLE_CL_ASSIGN_FROM_DLSYM(clCreateBuffer); - PADDLE_CL_ASSIGN_FROM_DLSYM(clCreateImage); - PADDLE_CL_ASSIGN_FROM_DLSYM(clCreateImage2D); - PADDLE_CL_ASSIGN_FROM_DLSYM(clCreateUserEvent); - PADDLE_CL_ASSIGN_FROM_DLSYM(clRetainKernel); - PADDLE_CL_ASSIGN_FROM_DLSYM(clCreateKernel); - PADDLE_CL_ASSIGN_FROM_DLSYM(clGetProgramInfo); - PADDLE_CL_ASSIGN_FROM_DLSYM(clFlush); - PADDLE_CL_ASSIGN_FROM_DLSYM(clFinish); - 
PADDLE_CL_ASSIGN_FROM_DLSYM(clReleaseProgram); - PADDLE_CL_ASSIGN_FROM_DLSYM(clRetainContext); - PADDLE_CL_ASSIGN_FROM_DLSYM(clGetContextInfo); - PADDLE_CL_ASSIGN_FROM_DLSYM(clCreateProgramWithBinary); - PADDLE_CL_ASSIGN_FROM_DLSYM(clCreateCommandQueue); - PADDLE_CL_ASSIGN_FROM_DLSYM(clCreateCommandQueueWithProperties); - PADDLE_CL_ASSIGN_FROM_DLSYM(clReleaseCommandQueue); - PADDLE_CL_ASSIGN_FROM_DLSYM(clEnqueueMapBuffer); - PADDLE_CL_ASSIGN_FROM_DLSYM(clEnqueueMapImage); - PADDLE_CL_ASSIGN_FROM_DLSYM(clRetainProgram); - PADDLE_CL_ASSIGN_FROM_DLSYM(clGetProgramBuildInfo); - PADDLE_CL_ASSIGN_FROM_DLSYM(clEnqueueReadBuffer); - PADDLE_CL_ASSIGN_FROM_DLSYM(clEnqueueReadImage); - PADDLE_CL_ASSIGN_FROM_DLSYM(clEnqueueWriteBuffer); - PADDLE_CL_ASSIGN_FROM_DLSYM(clWaitForEvents); - PADDLE_CL_ASSIGN_FROM_DLSYM(clReleaseEvent); - PADDLE_CL_ASSIGN_FROM_DLSYM(clCreateContext); - PADDLE_CL_ASSIGN_FROM_DLSYM(clCreateContextFromType); - PADDLE_CL_ASSIGN_FROM_DLSYM(clReleaseContext); - PADDLE_CL_ASSIGN_FROM_DLSYM(clRetainCommandQueue); - PADDLE_CL_ASSIGN_FROM_DLSYM(clEnqueueUnmapMemObject); - PADDLE_CL_ASSIGN_FROM_DLSYM(clRetainMemObject); - PADDLE_CL_ASSIGN_FROM_DLSYM(clReleaseMemObject); - PADDLE_CL_ASSIGN_FROM_DLSYM(clGetDeviceInfo); - PADDLE_CL_ASSIGN_FROM_DLSYM(clGetDeviceIDs); - PADDLE_CL_ASSIGN_FROM_DLSYM(clRetainDevice); - PADDLE_CL_ASSIGN_FROM_DLSYM(clReleaseDevice); - PADDLE_CL_ASSIGN_FROM_DLSYM(clRetainEvent); - PADDLE_CL_ASSIGN_FROM_DLSYM(clGetKernelWorkGroupInfo); - PADDLE_CL_ASSIGN_FROM_DLSYM(clGetEventInfo); - PADDLE_CL_ASSIGN_FROM_DLSYM(clGetEventProfilingInfo); - PADDLE_CL_ASSIGN_FROM_DLSYM(clGetImageInfo); - -#undef PADDLE_CL_ASSIGN_FROM_DLSYM + CL_ASSIGN_FROM_DLSYM(clGetPlatformIDs); + CL_ASSIGN_FROM_DLSYM(clGetPlatformInfo); + CL_ASSIGN_FROM_DLSYM(clBuildProgram); + CL_ASSIGN_FROM_DLSYM(clEnqueueNDRangeKernel); + CL_ASSIGN_FROM_DLSYM(clSetKernelArg); + CL_ASSIGN_FROM_DLSYM(clReleaseKernel); + CL_ASSIGN_FROM_DLSYM(clCreateProgramWithSource); + CL_ASSIGN_FROM_DLSYM(clCreateBuffer); + CL_ASSIGN_FROM_DLSYM(clCreateImage); + CL_ASSIGN_FROM_DLSYM(clCreateImage2D); + CL_ASSIGN_FROM_DLSYM(clCreateUserEvent); + CL_ASSIGN_FROM_DLSYM(clRetainKernel); + CL_ASSIGN_FROM_DLSYM(clCreateKernel); + CL_ASSIGN_FROM_DLSYM(clGetProgramInfo); + CL_ASSIGN_FROM_DLSYM(clFlush); + CL_ASSIGN_FROM_DLSYM(clFinish); + CL_ASSIGN_FROM_DLSYM(clReleaseProgram); + CL_ASSIGN_FROM_DLSYM(clRetainContext); + CL_ASSIGN_FROM_DLSYM(clGetContextInfo); + CL_ASSIGN_FROM_DLSYM(clCreateProgramWithBinary); + CL_ASSIGN_FROM_DLSYM(clCreateCommandQueue); + CL_ASSIGN_FROM_DLSYM(clCreateCommandQueueWithProperties); + CL_ASSIGN_FROM_DLSYM(clReleaseCommandQueue); + CL_ASSIGN_FROM_DLSYM(clEnqueueMapBuffer); + CL_ASSIGN_FROM_DLSYM(clEnqueueMapImage); + CL_ASSIGN_FROM_DLSYM(clRetainProgram); + CL_ASSIGN_FROM_DLSYM(clGetProgramBuildInfo); + CL_ASSIGN_FROM_DLSYM(clEnqueueReadBuffer); + CL_ASSIGN_FROM_DLSYM(clEnqueueReadImage); + CL_ASSIGN_FROM_DLSYM(clEnqueueWriteBuffer); + CL_ASSIGN_FROM_DLSYM(clWaitForEvents); + CL_ASSIGN_FROM_DLSYM(clReleaseEvent); + CL_ASSIGN_FROM_DLSYM(clCreateContext); + CL_ASSIGN_FROM_DLSYM(clCreateContextFromType); + CL_ASSIGN_FROM_DLSYM(clReleaseContext); + CL_ASSIGN_FROM_DLSYM(clRetainCommandQueue); + CL_ASSIGN_FROM_DLSYM(clEnqueueUnmapMemObject); + CL_ASSIGN_FROM_DLSYM(clRetainMemObject); + CL_ASSIGN_FROM_DLSYM(clReleaseMemObject); + CL_ASSIGN_FROM_DLSYM(clGetDeviceInfo); + CL_ASSIGN_FROM_DLSYM(clGetDeviceIDs); + CL_ASSIGN_FROM_DLSYM(clRetainDevice); + CL_ASSIGN_FROM_DLSYM(clReleaseDevice); + 
CL_ASSIGN_FROM_DLSYM(clRetainEvent); + CL_ASSIGN_FROM_DLSYM(clGetKernelWorkGroupInfo); + CL_ASSIGN_FROM_DLSYM(clGetEventInfo); + CL_ASSIGN_FROM_DLSYM(clGetEventProfilingInfo); + CL_ASSIGN_FROM_DLSYM(clGetImageInfo); + +#undef CL_ASSIGN_FROM_DLSYM return handle; } -- GitLab