diff --git a/lite/backends/opencl/cl_kernel/image/bilinear_interp_kernel.cl b/lite/backends/opencl/cl_kernel/image/bilinear_interp_kernel.cl new file mode 100644 index 0000000000000000000000000000000000000000..9427692f1267d363222295b33b6834e28517d0a4 --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/bilinear_interp_kernel.cl @@ -0,0 +1,96 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + + +__kernel void bilinear_interp(__read_only image2d_t input, + __write_only image2d_t output, + __private const float scale_h, + __private const float scale_w, + __private const float align_delta, + __private const int in_dims_h, + __private const int in_dims_w, + __private const int out_dims_h, + __private const int out_dims_w){ + const int c = get_global_id(0); + const int w = get_global_id(1); + const int nh = get_global_id(2); + + int2 output_pos; + output_pos.x = c * out_dims_w + w; + output_pos.y = nh; + + // calculate center pixel's pos + int out_n = nh / out_dims_h; + int out_h = nh % out_dims_h; + float center_w = (w + align_delta) * scale_w - align_delta; + float center_h = (out_h + align_delta) * scale_h - align_delta; + + int floor_w = (int)center_w; + int floor_h = (int)center_h; + int ceil_w = floor_w + 1; + int ceil_h = floor_h + 1; + if (floor_w < 0){ + floor_w = 0; + } + if (floor_h < 0){ + floor_h = 0; + } + if (ceil_w > in_dims_w - 1) { + ceil_w = in_dims_w - 1; + } + if (ceil_h > in_dims_h - 1) { + ceil_h = in_dims_h- 1; + } + float wight0_w = center_w - floor_w; + float wight0_h = center_h - floor_h; + float wight1_w = 1.0 - wight0_w; + float wight1_h = 1.0 - wight0_h; + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + + // get left up pixel data + int2 left_up; + left_up.x = c * in_dims_w + floor_w; + left_up.y = out_n * in_dims_h + ceil_h; + CL_DTYPE4 left_up_data = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, left_up); + + + // get left down pixel data + int2 left_down; + left_down.x = c * in_dims_w + floor_w; + left_down.y = out_n * in_dims_h + floor_h; + CL_DTYPE4 left_down_data = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, left_down); + + // get right up pixel data + int2 right_up; + right_up.x = c * in_dims_w + ceil_w; + right_up.y = out_n * in_dims_h + ceil_h; + CL_DTYPE4 right_up_data = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, right_up); + + // get right down pixel's data + int2 right_down; + right_down.x = c * in_dims_w + ceil_w; + right_down.y = out_n * in_dims_h + floor_h; + CL_DTYPE4 right_down_data = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, right_down); + + // calculate output data + CL_DTYPE4 out = (left_down_data * wight1_w + right_down_data * wight0_w) * wight1_h + + (left_up_data * wight1_w + right_up_data * wight0_w) * wight0_h; + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, output_pos, out); +} diff --git a/lite/demo/cxx/makefiles/test_cv/Makefile.android.armv7 b/lite/demo/cxx/makefiles/test_cv/Makefile.android.armv7 index 39264d413595b357e0a96e9a9761aeb4ec82db90..d55de99adda1a76b020dffa79be8f6d536ad91d1 100644 --- a/lite/demo/cxx/makefiles/test_cv/Makefile.android.armv7 +++ b/lite/demo/cxx/makefiles/test_cv/Makefile.android.armv7 @@ -28,7 +28,7 @@ OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/armeabi-v7a/include CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include -#CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_full_api_shared $(SYSTEM_LIBS) +#CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS) ############################################################### # How to use one of static libaray: # diff --git a/lite/demo/cxx/makefiles/test_cv/Makefile.android.armv8 b/lite/demo/cxx/makefiles/test_cv/Makefile.android.armv8 index 774fca9f67638450dabf53e13cee51bd7af9712e..70d6bed52b84be7d050ef15ab483e8d06342c82d 100644 --- a/lite/demo/cxx/makefiles/test_cv/Makefile.android.armv8 +++ b/lite/demo/cxx/makefiles/test_cv/Makefile.android.armv8 @@ -28,7 +28,7 @@ OPENCV_INCLUDE = -I../../../third_party/${OPENCV_VERSION}/arm64-v8a/include CXX_INCLUDES = $(INCLUDES) ${OPENCV_INCLUDE} -I$(LITE_ROOT)/cxx/include -#CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_full_api_shared $(SYSTEM_LIBS) +#CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SYSTEM_LIBS) ############################################################### # How to use one of static libaray: # # `libpaddle_api_full_bundled.a` # diff --git a/lite/demo/cxx/test_cv/README.md b/lite/demo/cxx/test_cv/README.md index f9d22aeb5c1d4ea380f05c4eb71aa5327d80bfbe..21574a9bf9fd0ebb3ecf1663f49beed93fdf51bb 100644 --- a/lite/demo/cxx/test_cv/README.md +++ b/lite/demo/cxx/test_cv/README.md @@ -17,7 +17,7 @@ example: wget http://paddle-inference-dist.bj.bcebos.com/mobilenet_v1.tar.gz tar zxvf mobilenet_v1.tar.gz ./lite/tools/build.sh build_optimize_tool -./build.model_optimize_tool/lite/api/model_optimize_tool +./build.opt/lite/api/opt --optimize_out_type=naive_buffer --optimize_out=model_dir --model_dir=model_dir diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt index 05bea8d31f252ab40cae8523727a81a1432401c0..c11653f7212941c739f0e0b2152bd96d2fa1b11c 100644 --- a/lite/kernels/opencl/CMakeLists.txt +++ b/lite/kernels/opencl/CMakeLists.txt @@ -23,7 +23,7 @@ add_kernel(concat_opencl OPENCL basic SRCS concat_image_compute.cc DEPS ${cl_ker add_kernel(nearest_interp_opencl OPENCL basic SRCS nearest_interp_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(scale_opencl OPENCL basic SRCS scale_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(grid_sampler_opencl OPENCL basic SRCS grid_sampler_image_compute.cc DEPS ${cl_kernel_deps}) - +add_kernel(bilinear_interp_opencl OPENCL basic SRCS bilinear_interp_image_compute.cc DEPS ${cl_kernel_deps}) # extra # wait to add ... @@ -67,9 +67,11 @@ lite_cc_test(test_elementwise_add_image_opencl SRCS elementwise_add_image_comput DEPS elementwise_add_opencl fusion_elementwise_add_activation_opencl op_registry program context) lite_cc_test(test_grid_sampler_image_opencl SRCS grid_sampler_image_compute_test.cc - DEPS grid_sampler_opencl op_registry program context - ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) - + DEPS grid_sampler_opencl op_registry program context) + +lite_cc_test(test_bilinear_interp_image_opencl SRCS bilinear_interp_image_compute_test.cc + DEPS bilinear_interp_opencl op_registry program context) + ###################### # buffer kernel # ###################### diff --git a/lite/kernels/opencl/bilinear_interp_image_compute.cc b/lite/kernels/opencl/bilinear_interp_image_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..eeab8b043b3344b492fd9bafc3259e8d8ed08438 --- /dev/null +++ b/lite/kernels/opencl/bilinear_interp_image_compute.cc @@ -0,0 +1,177 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "lite/backends/opencl/cl_half.h" +#include "lite/backends/opencl/cl_image_converter.h" +#include "lite/backends/opencl/cl_include.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/kernels/opencl/image_helper.h" +#include "lite/operators/op_params.h" +#include "lite/utils/logging.h" +#include "lite/utils/replace_stl/stream.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace opencl { +class BilinearInterpImageCompute + : public KernelLite { + public: + using param_t = operators::InterpolateParam; + + std::string doc() const override { + return "BilinearInterp using cl::Image2D(ImageDefault/RGBA), kFP16"; + } + + void PrepareForRun() override { + bilinear_interp_param_ = param_.get_mutable(); + + auto& context = ctx_->As(); + context.cl_context()->AddKernel( + kernel_func_name_, "image/bilinear_interp_kernel.cl", build_options_); + VLOG(1) << "kernel_func_name_:" << kernel_func_name_; + } + + void Run() override { + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + + auto* x = bilinear_interp_param_->X; + auto* out = bilinear_interp_param_->Out; + float scale_h = 0.0; + float scale_w = 0.0; + auto out_dims = out->dims(); + auto in_dims = x->dims(); + + if (bilinear_interp_param_->align_corners) { + scale_h = (in_dims[2] - 1.0f) / (out_dims[2] - 1.0f); + scale_w = (in_dims[3] - 1.0f) / (out_dims[3] - 1.0f); + } else { + scale_h = in_dims[2] / static_cast(out_dims[2]); + scale_w = in_dims[3] / static_cast(out_dims[3]); + } + float align_delta = 0.0f; + if (!bilinear_interp_param_->align_corners && + bilinear_interp_param_->align_mode == 0) { + align_delta = 0.5f; + } + + int in_h = in_dims[2]; + int in_w = in_dims[3]; + int out_h = out_dims[2]; + int out_w = out_dims[3]; + + VLOG(4) << "x->target():" << TargetToStr(x->target()); + VLOG(4) << "out->target():" << TargetToStr(out->target()); + VLOG(4) << "x->dims():" << in_dims; + VLOG(4) << "out->dims():" << out_dims; + + auto out_image_shape = InitImageDimInfoWith(out_dims); + auto* x_img = x->data(); + // VLOG(4) << "x_image: " << x_img; + + auto* out_img = out->mutable_data( + out_image_shape["width"], out_image_shape["height"]); + // VLOG(4) << "out_image: " << out_img; + VLOG(4) << "out_image_shape[w,h]: " << out_image_shape["width"] << " " + << out_image_shape["height"]; + + VLOG(4) << "scale_h: " << scale_h << ", scale_w: " << scale_w + << ", align_delta: " << align_delta; + VLOG(4) << "in_h: " << in_h << ", in_w: " << in_w; + VLOG(4) << "out_h: " << out_h << ", out_w: " << out_w; + + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + + int arg_idx = 0; + auto default_work_size = + DefaultWorkSize(out_dims, + DDim(std::vector{ + static_cast(out_image_shape["width"]), + static_cast(out_image_shape["height"])})); + VLOG(4) << "default_work_size: " << default_work_size[0] << ", " + << default_work_size[1] << ", " << default_work_size[2]; + cl_int status = kernel.setArg(arg_idx++, *x_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, *out_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, scale_h); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, scale_w); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, align_delta); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, in_h); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, in_w); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, out_h); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, out_w); + CL_CHECK_FATAL(status); + + auto global_work_size = + cl::NDRange{static_cast(default_work_size[0]), + static_cast(default_work_size[1]), + static_cast(default_work_size[2])}; + + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + context.cl_wait_list()->emplace(out_img, event_); + + VLOG(4) << "global_work_size:[2D]:" << global_work_size[0] << " " + << global_work_size[1] << " " << global_work_size[2]; + } + + protected: + param_t* bilinear_interp_param_{nullptr}; + std::string kernel_func_name_{"bilinear_interp"}; + std::string build_options_{"-DCL_DTYPE_half"}; + std::shared_ptr event_{new cl::Event}; +}; + +} // namespace opencl +} // namespace kernels +} // namespace lite +} // namespace paddle + +namespace ocl = paddle::lite::kernels::opencl; +REGISTER_LITE_KERNEL(bilinear_interp, + kOpenCL, + kFP16, + kImageDefault, + ocl::BilinearInterpImageCompute, + ImageDefault) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .Finalize(); diff --git a/lite/kernels/opencl/bilinear_interp_image_compute_test.cc b/lite/kernels/opencl/bilinear_interp_image_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..dec202ef55d797ce270ef46c6f80cc8a3474936f --- /dev/null +++ b/lite/kernels/opencl/bilinear_interp_image_compute_test.cc @@ -0,0 +1,270 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "lite/backends/opencl/target_wrapper.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" +#include "lite/kernels/opencl/test_helper.h" + +#define FP16_MAX_DIFF (5e-1) + +namespace paddle { +namespace lite { +void bilinear_interp_ref(const float* din, + const DDim& x_dims, + float* dout, + const DDim& out_dims, + bool align_corners, + int align_mode) { + int batch_size = x_dims[0]; + int channel_size = x_dims[1]; + auto in_h = x_dims[2]; + auto in_w = x_dims[3]; + + int out_h = out_dims[2]; + int out_w = out_dims[3]; + + // copy from x if no change + if (in_h == out_h && in_w == out_w) { + memcpy(dout, din, sizeof(float) * x_dims.production()); + return; + } + + float ratio_h = 0.f; + float ratio_w = 0.f; + if (out_h > 1) { + ratio_h = (align_corners) ? static_cast(in_h - 1) / (out_h - 1) + : static_cast(in_h) / out_h; + } + if (out_w > 1) { + ratio_w = (align_corners) ? static_cast(in_w - 1) / (out_w - 1) + : static_cast(in_w) / out_w; + } + + // naive bilinear interpolation + bool align_flag = (align_mode == 0 && !align_corners); + + for (int n = 0; n < batch_size; n++) { + float* dout_data = dout + n * channel_size * out_h * out_w; + const float* din_data = din + n * channel_size * in_h * in_w; + for (int c = 0; c < channel_size; c++) { + float* dout_data_c = dout_data + c * out_h * out_w; + const float* din_data_c = din_data + c * in_h * in_w; + for (int h = 0; h < out_h; h++) { + float center_h = align_flag ? (ratio_h * (h + 0.5) - 0.5) : ratio_h * h; + int floor_h = static_cast(center_h); + int ceil_h = floor_h + 1; + floor_h = floor_h > 0 ? floor_h : 0; + ceil_h = ceil_h > in_h - 1 ? in_h - 1 : ceil_h; + float hs = center_h - floor_h; + float he = 1.0 - hs; + for (int w = 0; w < out_w; w++) { + float center_w = + align_flag ? (ratio_w * (w + 0.5) - 0.5) : ratio_w * w; + int floor_w = static_cast(center_w); + int ceil_w = floor_w + 1; + floor_w = floor_w > 0 ? floor_w : 0; + ceil_w = ceil_w > in_w - 1 ? in_w - 1 : ceil_w; + float ws = center_w - floor_w; + float we = 1.0 - ws; + float left_up = din_data_c[ceil_h * in_w + floor_w] * we * hs; + float left_down = din_data_c[floor_h * in_w + floor_w] * we * he; + float right_up = din_data_c[ceil_h * in_w + ceil_w] * ws * hs; + float right_down = din_data_c[floor_h * in_w + ceil_w] * ws * he; + dout_data_c[h * out_w + w] = + left_up + left_down + right_up + right_down; + } + } + } + } +} +// #define BILINEAR_FP16_LOOP_TEST +// #define BILINEAR_FP16_PRINT_RESULT +TEST(bilinear_interp_image2d, compute) { +#ifdef BILINEAR_FP16_LOOP_TEST + for (auto n : {1, 3}) { + for (auto c : {1, 3, 8, 23, 32}) { + for (auto h : {2, 20, 64, 112}) { + for (auto w : {2, 20, 64, 112}) { + for (auto out_h : {4, 32, 96, 224}) { + for (auto out_w : {4, 32, 96, 224}) { + for (auto align_corners : {true, false}) { + for (auto align_mode : {0, 1}) { +#else + const int n = 1; + const int c = 1; + const int h = 2; + const int w = 2; + const int out_h = 4; + const int out_w = 4; + const bool align_corners = true; + const int align_mode = 0; +#endif // BILINEAR_FP16_LOOP_TEST + + LOG(INFO) << "======== input shape[n,c,h,w]:" << n << " " << c + << " " << h << " " << w << " ========"; + LOG(INFO) << "======== parameters: out_h = " << out_h + << ", out_w = " << out_w; + LOG(INFO) << "align_corners: " << align_corners + << ", align_mode: " << align_mode; + + auto kernels = KernelRegistry::Global().Create( + "bilinear_interp", + TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault)); + ASSERT_FALSE(kernels.empty()); + auto kernel = std::move(kernels.front()); + LOG(INFO) << "get kernel:" << kernel->doc(); + + lite::Tensor x, out; + operators::InterpolateParam param; + param.X = &x; + param.Out = &out; + param.align_corners = align_corners; + param.align_mode = align_mode; + + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + + kernel->SetParam(param); + std::unique_ptr bilinear_context( + new KernelContext); + context->As().CopySharedTo( + &(bilinear_context->As())); + kernel->SetContext(std::move(bilinear_context)); + + const DDim in_dim = + DDim(std::vector{n, c, h, w}); + const DDim out_dim = + DDim(std::vector{n, c, out_h, out_w}); + x.Resize(in_dim); + out.Resize(out_dim); + + std::default_random_engine engine; + std::uniform_real_distribution dist(-1, 1); + int sum = n * c * h * w; + std::vector input_v(sum); + for (auto& i : input_v) { + i = dist(engine); + } + + LOG(INFO) << "prepare input"; + CLImageConverterDefault* default_converter = + new CLImageConverterDefault(); + DDim x_image_shape = + default_converter->InitImageDimInfoWith(in_dim); + LOG(INFO) << "x_image_shape = " << x_image_shape[0] << " " + << x_image_shape[1]; + std::vector x_image_data(x_image_shape.production() * + 4); // 4 : RGBA + default_converter->NCHWToImage( + input_v.data(), x_image_data.data(), in_dim); + auto* x_image = x.mutable_data( + x_image_shape[0], x_image_shape[1], x_image_data.data()); + // LOG(INFO) << "x_image:" << x_image; + + DDim out_image_shape = + default_converter->InitImageDimInfoWith(out_dim); + LOG(INFO) << "out_image_shape = " << out_image_shape[0] << " " + << out_image_shape[1]; + auto* out_image = out.mutable_data( + out_image_shape[0], out_image_shape[1]); + // LOG(INFO) << "out_image:" << out_image; + kernel->Launch(); + + auto* wait_list = context->As().cl_wait_list(); + auto* out_ptr = param.Out->data(); + auto it = wait_list->find(out_ptr); + if (it != wait_list->end()) { + VLOG(4) << "--- Find the sync event for the target cl " + "tensor. ---"; + auto& event = *(it->second); + event.wait(); + } else { + LOG(FATAL) << "Could not find the sync event for the " + "target cl tensor."; + } + + std::unique_ptr out_ref( + new float[out_dim.production()]); + bilinear_interp_ref(input_v.data(), + in_dim, + out_ref.get(), + out_dim, + align_corners, + align_mode); + + const size_t cl_image2d_row_pitch{0}; + const size_t cl_image2d_slice_pitch{0}; + half_t* out_image_data = + new half_t[40000]; // out_image_shape.production() * 4 + TargetWrapperCL::ImgcpySync(out_image_data, + out_image, + out_image_shape[0], + out_image_shape[1], + cl_image2d_row_pitch, + cl_image2d_slice_pitch, + IoDirection::DtoH); + float* out_data = new float[out_image_shape.production() * 4]; + default_converter->ImageToNCHW( + out_image_data, out_data, out_image_shape, out_dim); +// result +#ifdef BILINEAR_FP16_PRINT_RESULT + LOG(INFO) + << "---- print kernel result (input -> output) ----"; + for (int eidx = 0; eidx < in_dim.production(); ++eidx) { + std::cout << input_v[eidx] << " -> " << out_data[eidx] + << std::endl; + } +#endif // BILINEAR_FP16_PRINT_RESULT + for (int i = 0; i < out_dim.production(); i++) { + auto abs_diff = abs(out_data[i] - out_ref[i]); + auto relative_diff = + COMPUTE_RELATIVE_DIFF(out_data[i], out_ref[i]); + EXPECT_EQ((relative_diff <= FP16_MAX_DIFF) || + (abs_diff <= FP16_MAX_DIFF), + true); + if ((relative_diff > FP16_MAX_DIFF) && + (abs_diff > FP16_MAX_DIFF)) { + LOG(ERROR) << "error idx:" << i << ", in_data[" << i + << "]: " << input_v[i] << ", out_data[" << i + << "]: " << out_data[i] << ", out_ref[" << i + << "]: " << out_ref[i] + << ", abs_diff: " << abs_diff + << ", relative_diff: " << relative_diff + << ", FP16_MAX_DIFF: " << FP16_MAX_DIFF; + } + } +#ifdef BILINEAR_FP16_LOOP_TEST + } // mode + } // corners + } // out_w + } // out_h + } // w + } // h + } // c + } // n +#else +// nothing to do. +#endif +} + +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(bilinear_interp, kOpenCL, kFP16, kImageDefault, ImageDefault); diff --git a/lite/kernels/opencl/grid_sampler_image_compute.cc b/lite/kernels/opencl/grid_sampler_image_compute.cc index 4ac36112e3fec11b8d1cbde268eb6fc5e0bcf518..e174286ca1fefa3c56bca04b433015ac769cfcbf 100644 --- a/lite/kernels/opencl/grid_sampler_image_compute.cc +++ b/lite/kernels/opencl/grid_sampler_image_compute.cc @@ -35,7 +35,7 @@ class GridSamplerImageCompute : public KernelLiteAs(); context.cl_context()->AddKernel( kernel_func_name_, "image/grid_sampler_kernel.cl", build_options_); - VLOG(1) << "kernel_func_name_:" << kernel_func_name_; + VLOG(4) << "kernel_func_name_:" << kernel_func_name_; } void Run() override { @@ -64,14 +64,14 @@ class GridSamplerImageCompute : public KernelLitedata(); - VLOG(4) << "x_image: " << x_img; + // VLOG(4) << "x_image: " << x_img; auto* grid_img = x->data(); - VLOG(4) << "grid_img: " << grid_img; + // VLOG(4) << "grid_img: " << grid_img; auto* out_img = out->mutable_data( out_image_shape["width"], out_image_shape["height"]); - VLOG(4) << "out_image" << out_img; + // VLOG(4) << "out_image" << out_img; VLOG(4) << "out_image_shape[w,h]:" << out_image_shape["width"] << " " << out_image_shape["height"]; @@ -87,7 +87,8 @@ class GridSamplerImageCompute : public KernelLite{ static_cast(out_image_shape["width"]), static_cast(out_image_shape["height"])})); - + VLOG(4) << "default_work_size: " << default_work_size[0] << ", " + << default_work_size[1] << ", " << default_work_size[2]; cl_int status = kernel.setArg(arg_idx++, *x_img); CL_CHECK_FATAL(status); status = kernel.setArg(arg_idx++, *grid_img); @@ -101,8 +102,8 @@ class GridSamplerImageCompute : public KernelLite(default_work_size[0]), - static_cast(default_work_size[2]), - static_cast(default_work_size[3] / 4)}; + static_cast(default_work_size[1]), + static_cast(default_work_size[2] / 4)}; status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( kernel, diff --git a/lite/kernels/opencl/grid_sampler_image_compute_test.cc b/lite/kernels/opencl/grid_sampler_image_compute_test.cc index bba05cac1e0facd761e966686b26e880f2e99964..afdebc0e64f3503a95bd14f00207920ed6363cf5 100644 --- a/lite/kernels/opencl/grid_sampler_image_compute_test.cc +++ b/lite/kernels/opencl/grid_sampler_image_compute_test.cc @@ -168,7 +168,7 @@ TEST(grid_samler_image2d, compute) { input_v.data(), x_image_data.data(), in_dim); auto* x_image = x.mutable_data( x_image_shape[0], x_image_shape[1], x_image_data.data()); - LOG(INFO) << "x_image:" << x_image; + // LOG(INFO) << "x_image:" << x_image; DDim grid_image_shape = default_converter->InitImageDimInfoWith(grid_dim); @@ -180,7 +180,7 @@ TEST(grid_samler_image2d, compute) { grid_v.data(), grid_image_data.data(), grid_dim); auto* grid_image = grid.mutable_data( grid_image_shape[0], grid_image_shape[1], grid_image_data.data()); - LOG(INFO) << "grid_image:" << grid_image; + // LOG(INFO) << "grid_image:" << grid_image; DDim out_image_shape = default_converter->InitImageDimInfoWith(out_dim); @@ -188,7 +188,7 @@ TEST(grid_samler_image2d, compute) { << out_image_shape[1]; auto* out_image = out.mutable_data( out_image_shape[0], out_image_shape[1]); - LOG(INFO) << "out_image:" << out_image; + // LOG(INFO) << "out_image:" << out_image; kernel->Launch(); auto* wait_list = context->As().cl_wait_list(); diff --git a/lite/kernels/opencl/reshape_image_compute.cc b/lite/kernels/opencl/reshape_image_compute.cc index e84fcfb45814ac4e9a7cd269c0737bcf2fbc63e1..557259e29fd73964d62b150f25ea89b1f5b16908 100644 --- a/lite/kernels/opencl/reshape_image_compute.cc +++ b/lite/kernels/opencl/reshape_image_compute.cc @@ -206,3 +206,38 @@ REGISTER_LITE_KERNEL(reshape2, PRECISION(kFP16), DATALAYOUT(kImageDefault))}) .Finalize(); + +REGISTER_LITE_KERNEL(flatten, + kOpenCL, + kFP16, + kImageDefault, + paddle::lite::kernels::opencl::ReshapeComputeFloatImage, + image2d) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .BindInput("Shape", {LiteType::GetTensorTy(TARGET(kOpenCL))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .Finalize(); + +REGISTER_LITE_KERNEL(flatten2, + kOpenCL, + kFP16, + kImageDefault, + paddle::lite::kernels::opencl::ReshapeComputeFloatImage, + image2d) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .BindInput("Shape", {LiteType::GetTensorTy(TARGET(kOpenCL))}) + .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kOpenCL))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .Finalize(); diff --git a/lite/operators/pad2d_op.cc b/lite/operators/pad2d_op.cc index 09deed89072512fa0e00bd0be080e8ff8f8a6cec..ff522b94b95091b6df6d4d2f71e18907c5118619 100644 --- a/lite/operators/pad2d_op.cc +++ b/lite/operators/pad2d_op.cc @@ -46,7 +46,20 @@ bool Pad2dOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { scope->FindVar(op_desc.Output("Out").front())->GetMutable(); param_.mode = op_desc.GetAttr("mode"); param_.pad_value = op_desc.GetAttr("pad_value"); - param_.paddings = op_desc.GetAttr>("paddings"); + if (op_desc.HasAttr("variable_padding") && + op_desc.GetAttr("variable_paddings")) { + auto Paddings = + scope->FindVar(op_desc.Input("Paddings").front())->GetMutable(); + auto ptr = Paddings->data(); + if (Paddings->dims().size() < 4) { + printf("Paddings size must be four: %d \n", + static_cast(Paddings->dims().size())); + return false; + } + param_.paddings = {ptr[0], ptr[1], ptr[2], ptr[3]}; + } else { + param_.paddings = op_desc.GetAttr>("paddings"); + } param_.data_format = op_desc.GetAttr("data_format"); return true; } diff --git a/lite/utils/cv/image_resize.cc b/lite/utils/cv/image_resize.cc index e9ff8dbbfc1508e53f80250dc5a3071ef1e63661..39c50e78dd76a47e9e0789b91e615f20297d9f70 100644 --- a/lite/utils/cv/image_resize.cc +++ b/lite/utils/cv/image_resize.cc @@ -51,6 +51,7 @@ void compute_xy(int srcw, int srch, int dstw, int dsth, + int num, double scale_x, double scale_y, int* xofs, @@ -77,8 +78,8 @@ void resize(const uint8_t* src, memcpy(dst, src, sizeof(uint8_t) * size); return; } - double scale_x = static_cast(srcw / dstw); - double scale_y = static_cast(srch / dsth); + double scale_x = static_cast(srcw) / dstw; + double scale_y = static_cast(srch) / dsth; int* buf = new int[dstw * 2 + dsth * 2]; @@ -87,9 +88,6 @@ void resize(const uint8_t* src, int16_t* ialpha = reinterpret_cast(buf + dstw + dsth); int16_t* ibeta = reinterpret_cast(buf + 2 * dstw + dsth); - compute_xy( - srcw, srch, dstw, dsth, scale_x, scale_y, xofs, yofs, ialpha, ibeta); - int w_out = dstw; int w_in = srcw; int num = 1; @@ -111,6 +109,9 @@ void resize(const uint8_t* src, num = 4; } + compute_xy( + srcw, srch, dstw, dsth, num, scale_x, scale_y, xofs, yofs, ialpha, ibeta); + int* xofs1 = nullptr; int* yofs1 = nullptr; int16_t* ialpha1 = nullptr; @@ -124,6 +125,7 @@ void resize(const uint8_t* src, srch / 2, w, tmp, + num, scale_x, scale_y, xofs1, @@ -134,6 +136,7 @@ void resize(const uint8_t* src, int cnt = w_out >> 3; int remain = w_out % 8; int32x4_t _v2 = vdupq_n_s32(2); + int prev_sy1 = -1; #pragma omp parallel for for (int dy = 0; dy < dsth; dy++) { int16_t* rowsbuf0 = new int16_t[w_out]; @@ -144,27 +147,20 @@ void resize(const uint8_t* src, yofs = yofs1; ialpha = ialpha1; } - if (sy < 0) { + if (sy == prev_sy1) { memset(rowsbuf0, 0, sizeof(uint16_t) * w_out); const uint8_t* S1 = src + srcw * (sy + 1); const int16_t* ialphap = ialpha; int16_t* rows1p = rowsbuf1; for (int dx = 0; dx < dstw; dx++) { - int sx = xofs[dx] * num; // num = 4 + int sx = xofs[dx]; int16_t a0 = ialphap[0]; int16_t a1 = ialphap[1]; const uint8_t* S1pl = S1 + sx; const uint8_t* S1pr = S1 + sx + num; - if (sx < 0) { - S1pl = S1; - } for (int i = 0; i < num; i++) { - if (sx < 0) { - *rows1p++ = ((*S1pl++) * a1) >> 4; - } else { - *rows1p++ = ((*S1pl++) * a0 + (*S1pr++) * a1) >> 4; - } + *rows1p++ = ((*S1pl++) * a0 + (*S1pr++) * a1) >> 4; } ialphap += 2; } @@ -176,7 +172,7 @@ void resize(const uint8_t* src, int16_t* rows0p = rowsbuf0; int16_t* rows1p = rowsbuf1; for (int dx = 0; dx < dstw; dx++) { - int sx = xofs[dx] * num; // num = 4 + int sx = xofs[dx]; int16_t a0 = ialphap[0]; int16_t a1 = ialphap[1]; @@ -184,32 +180,21 @@ void resize(const uint8_t* src, const uint8_t* S0pr = S0 + sx + num; const uint8_t* S1pl = S1 + sx; const uint8_t* S1pr = S1 + sx + num; - if (sx < 0) { - S0pl = S0; - S1pl = S1; - } for (int i = 0; i < num; i++) { - if (sx < 0) { - *rows0p = ((*S0pl++) * a1) >> 4; - *rows1p = ((*S1pl++) * a1) >> 4; - rows0p++; - rows1p++; - } else { - *rows0p++ = ((*S0pl++) * a0 + (*S0pr++) * a1) >> 4; - *rows1p++ = ((*S1pl++) * a0 + (*S1pr++) * a1) >> 4; - } + *rows0p++ = ((*S0pl++) * a0 + (*S0pr++) * a1) >> 4; + *rows1p++ = ((*S1pl++) * a0 + (*S1pr++) * a1) >> 4; } ialphap += 2; } } - int ind = dy * 2; - int16_t b0 = ibeta[ind]; - int16_t b1 = ibeta[ind + 1]; - int16x8_t _b0 = vdupq_n_s16(b0); - int16x8_t _b1 = vdupq_n_s16(b1); + prev_sy1 = sy + 1; + int16_t b0 = ibeta[0]; + int16_t b1 = ibeta[1]; uint8_t* dp_ptr = dst + dy * w_out; int16_t* rows0p = rowsbuf0; int16_t* rows1p = rowsbuf1; + int16x8_t _b0 = vdupq_n_s16(b0); + int16x8_t _b1 = vdupq_n_s16(b1); int re_cnt = cnt; if (re_cnt > 0) { #ifdef __aarch64__ @@ -295,6 +280,7 @@ void resize(const uint8_t* src, (int16_t)((b1 * (int16_t)(*rows1p++)) >> 16) + 2) >> 2); } + ibeta += 2; } delete[] buf; } @@ -303,6 +289,7 @@ void compute_xy(int srcw, int srch, int dstw, int dsth, + int num, double scale_x, double scale_y, int* xofs, @@ -334,7 +321,7 @@ void compute_xy(int srcw, fx = 1.f; } - xofs[dx] = sx; + xofs[dx] = sx * num; float a0 = (1.f - fx) * resize_coef_scale; float a1 = fx * resize_coef_scale;