From 50b9e85f3d95fbe51b63b0ac842c6c495ca47363 Mon Sep 17 00:00:00 2001 From: HappyAngel Date: Wed, 4 Mar 2020 14:28:15 +0800 Subject: [PATCH] [opencl] add grid_sampler op (#3075) * add grid sampler * reset act * fix conflict and readme, test=develop * fix ios v8 run error * fix grid_sampler compute error. test=develop * fix exp act run error, test=develop * fix format, test=develop --- lite/api/paddle_place.h | 3 +- lite/backends/arm/math/conv_block_utils.h | 2 +- .../cl_kernel/image/activation_kernel.cl | 37 ++- .../cl_kernel/image/grid_sampler_kernel.cl | 168 ++++++++++++ lite/kernels/opencl/CMakeLists.txt | 4 + .../opencl/activation_image_compute.cc | 49 +++- .../opencl/activation_image_compute_test.cc | 38 ++- .../opencl/grid_sampler_image_compute.cc | 151 ++++++++++ .../opencl/grid_sampler_image_compute_test.cc | 258 ++++++++++++++++++ lite/utils/cv/image2tensor.cc | 40 +-- lite/utils/cv/image_convert.cc | 80 +++--- lite/utils/cv/image_resize.cc | 12 +- 12 files changed, 762 insertions(+), 80 deletions(-) create mode 100644 lite/backends/opencl/cl_kernel/image/grid_sampler_kernel.cl create mode 100644 lite/kernels/opencl/grid_sampler_image_compute.cc create mode 100644 lite/kernels/opencl/grid_sampler_image_compute_test.cc diff --git a/lite/api/paddle_place.h b/lite/api/paddle_place.h index 7da52adc7f..d06ee87966 100644 --- a/lite/api/paddle_place.h +++ b/lite/api/paddle_place.h @@ -96,7 +96,8 @@ enum class ActivationType : int { kLeakyRelu = 4, kSigmoid = 5, kTanh = 6, - kSwish = 7 + kSwish = 7, + kExp = 8 }; static size_t PrecisionTypeLength(PrecisionType type) { diff --git a/lite/backends/arm/math/conv_block_utils.h b/lite/backends/arm/math/conv_block_utils.h index a7ee4093dd..c4fb51021e 100644 --- a/lite/backends/arm/math/conv_block_utils.h +++ b/lite/backends/arm/math/conv_block_utils.h @@ -2203,7 +2203,7 @@ inline void act_switch_c8_fp32(const float* din_ptr, [cnt] "+r"(cnt_loop) : : "cc", - "meemory", + "memory", "q0", "q1", "q2", diff --git a/lite/backends/opencl/cl_kernel/image/activation_kernel.cl b/lite/backends/opencl/cl_kernel/image/activation_kernel.cl index a6ebaab97d..cc71df1c30 100644 --- a/lite/backends/opencl/cl_kernel/image/activation_kernel.cl +++ b/lite/backends/opencl/cl_kernel/image/activation_kernel.cl @@ -97,7 +97,7 @@ __kernel void leaky_relu(__read_only image2d_t input, WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), in); } -__kernel void tanhAct(__read_only image2d_t input, +__kernel void tanh_act(__read_only image2d_t input, __write_only image2d_t output, __private const float threshold, __private const float scale) { @@ -113,3 +113,38 @@ __kernel void tanhAct(__read_only image2d_t input, CL_DTYPE4 out= (exp(in) - exp(-in))/ (exp(in) + exp(-in)); WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), out); } + +__kernel void exp_act(__read_only image2d_t input, + __write_only image2d_t output, + __private const float threshold, + __private const float scale) { + + const int x = get_global_id(0); // image_width + const int y = get_global_id(1); // image_height + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + + CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x, y)); + CL_DTYPE4 out = exp(in); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), out); +} + +__kernel void swish(__read_only image2d_t input, + __write_only image2d_t output, + __private const float threshold, + __private const float scale) { + + const int x = get_global_id(0); // image_width + const int y = get_global_id(1); // image_height + + const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | + CLK_ADDRESS_CLAMP | + CLK_FILTER_NEAREST; + + CL_DTYPE4 in = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x, y)); + CL_DTYPE4 out = in / (1 + exp(-scale * in)); + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(x, y), out); +} + diff --git a/lite/backends/opencl/cl_kernel/image/grid_sampler_kernel.cl b/lite/backends/opencl/cl_kernel/image/grid_sampler_kernel.cl new file mode 100644 index 0000000000..360d8c753e --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/grid_sampler_kernel.cl @@ -0,0 +1,168 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +__kernel void grid_sampler(__read_only image2d_t input, + __read_only image2d_t grid, + __write_only image2d_t output, + __private const int out_height, + __private const int out_width) { + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2) * 4; + const int out_n = out_nh / out_height; + const int out_h = out_nh % out_height; + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + int2 coords1, coords2, outpoints; + coords1.x = out_h / 4 * 2; + coords1.y = out_n * out_width + out_w; + coords2.x = coords1.x + 1; + coords2.y = coords1.y; + outpoints.x = out_c * out_width + out_w; + outpoints.x = out_n * out_height + out_h; + + CL_DTYPE4 g1 = READ_IMG_TYPE(CL_DTYPE_CHAR, grid, sampler, coords1); + CL_DTYPE4 g2 = READ_IMG_TYPE(CL_DTYPE_CHAR, grid, sampler, coords2); + + // x + float x = (g1.x + 1) * (out_width - 1) * 0.5; + float y = (g2.x + 1) * (out_height - 1) * 0.5; + int x0 = floor(x); + int y0 = floor(y); + int x_p = out_c * out_width + x0; + int y_p = out_n * out_height + y0; + + float xs = x - x0; + float xe = x0 + 1 - x; + float ys = y - y0; + float ye = y0 + 1 - y; + + CL_DTYPE4 input0 = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x_p, y_p)); + CL_DTYPE4 input1 = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x_p + 1, y_p)); + CL_DTYPE4 input2 = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x_p, y_p + 1)); + CL_DTYPE4 input3 = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x_p + 1, y_p + 1)); + + if (x0 < 0 || x0 > out_width - 1 || y0 < 0 || y0 > out_height - 1){ + input0 = (CL_DTYPE4)(0.0); + } + if (x0 + 1 < 0 || x0 + 1 > out_width - 1 || y0 < 0 || y0 > out_height - 1){ + input1 = (CL_DTYPE4)(0.0); + } + if (x0 < 0 || x0 > out_width - 1 || y0 + 1 < 0 || y0 + 1 > out_height - 1){ + input2 = (CL_DTYPE4)(0.0); + } + if (x0 + 1 < 0 || x0 + 1 > out_width - 1 || y0 + 1 < 0 || y0 + 1 > out_height - 1){ + input3 = (CL_DTYPE4)(0.0); + } + CL_DTYPE4 out_val = input0 * xe * ye + input1 * xs * ye + input2 * xe * ys + input3 * xs * ys; + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, outpoints, out_val); + + // y + x = (g1.y + 1) * (out_width - 1) / 2; + y = (g2.y + 1) * (out_height - 1) / 2; + x0 = floor(x); + y0 = floor(y); + x_p = out_c * out_width + x0; + y_p = out_n * out_height + y0; + + xs = x - x0; + xe = x0 + 1 - x; + ys = y - y0; + ye = y0 + 1 - y; + + input0 = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x_p, y_p)); + input1 = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x_p + 1, y_p)); + input2 = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x_p, y_p + 1)); + input3 = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x_p + 1, y_p + 1)); + + if (x0 < 0 || x0 > out_width - 1 || y0 < 0 || y0 > out_height - 1){ + input0 = (CL_DTYPE4)(0.0); + } + if (x0 + 1 < 0 || x0 + 1 > out_width - 1 || y0 < 0 || y0 > out_height - 1){ + input1 = (CL_DTYPE4)(0.0); + } + if (x0 < 0 || x0 > out_width - 1 || y0 + 1 < 0 || y0 + 1 > out_height - 1){ + input2 = (CL_DTYPE4)(0.0); + } + if (x0 + 1 < 0 || x0 + 1 > out_width - 1 || y0 + 1 < 0 || y0 + 1 > out_height - 1){ + input3 = (CL_DTYPE4)(0.0); + } + + out_val = input0 * xe * ye + input1 * xs * ye + input2 * xe * ys + input3 * xs * ys; + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(outpoints.x, outpoints.y + 1), out_val); + + // z + x = (g1.z + 1) * (out_width - 1) / 2; + y = (g2.z + 1) * (out_height - 1) / 2; + x0 = floor(x); + y0 = floor(y); + x_p = out_c * out_width + x0; + y_p = out_n * out_height + y0; + + xs = x - x0; + xe = x0 + 1 - x; + ys = y - y0; + ye = y0 + 1 - y; + + input0 = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x_p, y_p)); + input1 = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x_p + 1, y_p)); + input2 = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x_p, y_p + 1)); + input3 = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x_p + 1, y_p + 1)); + + if (x0 < 0 || x0 > out_width - 1 || y0 < 0 || y0 > out_height - 1){ + input0 = (CL_DTYPE4)(0.0); + } + if (x0 + 1 < 0 || x0 + 1 > out_width - 1 || y0 < 0 || y0 > out_height - 1){ + input1 = (CL_DTYPE4)(0.0); + } + if (x0 < 0 || x0 > out_width - 1 || y0 + 1 < 0 || y0 + 1 > out_height - 1){ + input2 = (CL_DTYPE4)(0.0); + } + if (x0 + 1 < 0 || x0 + 1 > out_width - 1 || y0 + 1 < 0 || y0 + 1 > out_height - 1){ + input3 = (CL_DTYPE4)(0.0); + } + out_val = input0 * xe * ye + input1 * xs * ye + input2 * xe * ys + input3 * xs * ys; + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(outpoints.x, outpoints.y + 2), out_val); + + // w + x = (g1.w + 1) * (out_width - 1) / 2; + y = (g2.w + 1) * (out_height - 1) / 2; + x0 = floor(x); + y0 = floor(y); + x_p = out_c * out_width + x0; + y_p = out_n * out_height + y0; + + xs = x - x0; + xe = x0 + 1 - x; + ys = y - y0; + ye = y0 + 1 - y; + + input0 = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x_p, y_p)); + input1 = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x_p + 1, y_p)); + input2 = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x_p, y_p + 1)); + input3 = READ_IMG_TYPE(CL_DTYPE_CHAR, input, sampler, (int2)(x_p + 1, y_p + 1)); + + if (x0 < 0 || x0 > out_width - 1 || y0 < 0 || y0 > out_height - 1){ + input0 = (CL_DTYPE4)(0.0); + } + if (x0 + 1 < 0 || x0 + 1 > out_width - 1 || y0 < 0 || y0 > out_height - 1){ + input1 = (CL_DTYPE4)(0.0); + } + if (x0 < 0 || x0 > out_width - 1 || y0 + 1 < 0 || y0 + 1 > out_height - 1){ + input2 = (CL_DTYPE4)(0.0); + } + if (x0 + 1 < 0 || x0 + 1 > out_width - 1 || y0 + 1 < 0 || y0 + 1 > out_height - 1){ + input3 = (CL_DTYPE4)(0.0); + } + out_val = input0 * xe * ye + input1 * xs * ye + input2 * xe * ys + input3 * xs * ys; + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(outpoints.x, outpoints.y + 3), out_val); +} diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt index c1a2afdabb..716ab35050 100644 --- a/lite/kernels/opencl/CMakeLists.txt +++ b/lite/kernels/opencl/CMakeLists.txt @@ -22,6 +22,7 @@ add_kernel(layout_opencl OPENCL basic SRCS layout_image_compute.cc DEPS ${cl_ker add_kernel(concat_opencl OPENCL basic SRCS concat_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(nearest_interp_opencl OPENCL basic SRCS nearest_interp_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(scale_opencl OPENCL basic SRCS scale_image_compute.cc DEPS ${cl_kernel_deps}) +add_kernel(grid_sampler_opencl OPENCL basic SRCS grid_sampler_image_compute.cc DEPS ${cl_kernel_deps}) # extra # wait to add ... @@ -76,6 +77,9 @@ lite_cc_test(test_elementwise_add_image_opencl SRCS elementwise_add_image_comput DEPS elementwise_add_opencl fusion_elementwise_add_activation_opencl op_registry program context ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) +lite_cc_test(test_grid_sampler_image_opencl SRCS grid_sampler_image_compute_test.cc + DEPS grid_sampler_opencl op_registry program context + ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl) ###################### # buffer kernel # diff --git a/lite/kernels/opencl/activation_image_compute.cc b/lite/kernels/opencl/activation_image_compute.cc index 022e71002f..5a7100c37e 100644 --- a/lite/kernels/opencl/activation_image_compute.cc +++ b/lite/kernels/opencl/activation_image_compute.cc @@ -56,7 +56,14 @@ class ActivationComputeImageDefault kernel_func_name_ = "sigmoid"; break; case 6: - kernel_func_name_ = "tanhAct"; + kernel_func_name_ = "tanh_act"; + break; + case 7: + kernel_func_name_ = "swish"; + scale_ = act_param_->Swish_beta; + break; + case 8: + kernel_func_name_ = "exp_act"; break; default: printf("This act type: %d doesn't support \n", act_type); @@ -80,7 +87,6 @@ class ActivationComputeImageDefault STL::stringstream kernel_key; kernel_key << kernel_func_name_ << build_options_; auto kernel = context.cl_context()->GetKernel(kernel_key.str()); - int arg_idx = 0; cl_int status = kernel.setArg(arg_idx, *x_img); CL_CHECK_FATAL(status); @@ -147,9 +153,45 @@ REGISTER_LITE_KERNEL( DATALAYOUT(kImageDefault))}) .Finalize(); +// swish +REGISTER_LITE_KERNEL( + swish, + kOpenCL, + kFP16, + kImageDefault, + paddle::lite::kernels::opencl::ActivationComputeImageDefault, + ImageDefault) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .Finalize(); + +// exp +REGISTER_LITE_KERNEL( + exp_act, + kOpenCL, + kFP16, + kImageDefault, + paddle::lite::kernels::opencl::ActivationComputeImageDefault, + ImageDefault) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .Finalize(); + // tanh REGISTER_LITE_KERNEL( - tanhAct, + tanh_act, kOpenCL, kFP16, kImageDefault, @@ -164,6 +206,7 @@ REGISTER_LITE_KERNEL( PRECISION(kFP16), DATALAYOUT(kImageDefault))}) .Finalize(); + // Relu REGISTER_LITE_KERNEL( relu, diff --git a/lite/kernels/opencl/activation_image_compute_test.cc b/lite/kernels/opencl/activation_image_compute_test.cc index b21c433648..40751a44b2 100644 --- a/lite/kernels/opencl/activation_image_compute_test.cc +++ b/lite/kernels/opencl/activation_image_compute_test.cc @@ -51,13 +51,19 @@ void act_compute_ref(const dtype *x_data, out_data[i] = (expf(x_data[i]) - expf(-x_data[i])) / (expf(x_data[i]) + expf(-x_data[i])); break; + case 7: // swish + out_data[i] = x_data[i] / (1 + expf(-x_data[i] * scale)); + break; + case 8: // exp + out_data[i] = expf(x_data[i]); + break; default: break; } } } -// #define ACT_FP16_LOOP_TEST +// #define ACT_FP16_LOOP_TEST // #define ACT_FP16_PRINT_RESULT TEST(act_image2d_fp16, compute) { LOG(INFO) << "main steps of test: host -> layout(buf2img) -> relu(img) -> " @@ -69,7 +75,7 @@ TEST(act_image2d_fp16, compute) { for (auto c : {1, 3, 8, 23, 32}) { for (int h = 12; h <= 100; h += 13) { for (int w = 12; w <= 100; w += 25) { - for (auto act_type : {1, 2, 4, 5, 6}) { + for (auto act_type : {1, 2, 4, 5, 6, 7, 8}) { for (auto scale : {0.5, 0.8}) { for (auto threshold : {6.0}) { #else @@ -103,7 +109,13 @@ TEST(act_image2d_fp16, compute) { func_name = "sigmoid"; break; case 6: // tanh - func_name = "tanhAct"; + func_name = "tanh_act"; + break; + case 7: // tanh + func_name = "swish"; + break; + case 8: // tanh + func_name = "exp_act"; break; } LOG(INFO) << "func_name: " << func_name; @@ -153,6 +165,7 @@ TEST(act_image2d_fp16, compute) { (paddle::lite_api::ActivationType)act_type; actParam.Relu_clipped_coef = threshold; actParam.Leaky_relu_alpha = scale; + actParam.Swish_beta = scale; const DDim x_dim = DDim(std::vector{n, c, h, w}); @@ -175,9 +188,11 @@ TEST(act_image2d_fp16, compute) { x_data, 0, sizeof(float) * x_dim.production())); auto *mapped_y = static_cast(TargetWrapperCL::Map( y_data, 0, sizeof(float) * x_dim.production())); + std::default_random_engine engine; + std::uniform_real_distribution dist(-1, 1); for (int i = 0; i < x_dim.production(); ++i) { - mapped_x[i] = static_cast(i) - x_dim.production() / 2; - mapped_y[i] = static_cast(0); + mapped_x[i] = dist(engine); + mapped_y[i] = 0.0f; } auto *act_in_data = act_in.mutable_data( act_image2d_shape["width"], act_image2d_shape["height"]); @@ -290,11 +305,18 @@ TEST(act_image2d_fp16, compute) { // layout USE_LITE_KERNEL(layout, kOpenCL, kAny, kImageDefault, NCHW_to_ImageDefault); USE_LITE_KERNEL(layout, kOpenCL, kAny, kNCHW, ImageDefault_to_NCHW); -// leakyRelu + +// exp +USE_LITE_KERNEL(exp_act, kOpenCL, kFP16, kImageDefault, ImageDefault); + +// swish +USE_LITE_KERNEL(swish, kOpenCL, kFP16, kImageDefault, ImageDefault); + +// leaky_relu USE_LITE_KERNEL(leaky_relu, kOpenCL, kFP16, kImageDefault, ImageDefault); -// tanh -USE_LITE_KERNEL(tanhAct, kOpenCL, kFP16, kImageDefault, ImageDefault); +// tanh act +USE_LITE_KERNEL(tanh_act, kOpenCL, kFP16, kImageDefault, ImageDefault); // relu image2d fp16 USE_LITE_KERNEL(relu, kOpenCL, kFP16, kImageDefault, ImageDefault); diff --git a/lite/kernels/opencl/grid_sampler_image_compute.cc b/lite/kernels/opencl/grid_sampler_image_compute.cc new file mode 100644 index 0000000000..a827350c28 --- /dev/null +++ b/lite/kernels/opencl/grid_sampler_image_compute.cc @@ -0,0 +1,151 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "lite/backends/opencl/cl_half.h" +#include "lite/backends/opencl/cl_image_converter.h" +#include "lite/backends/opencl/cl_include.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/kernels/opencl/image_helper.h" +#include "lite/operators/op_params.h" +#include "lite/utils/logging.h" +#include "lite/utils/replace_stl/stream.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace opencl { +class GridSamplerImageCompute : public KernelLite { + public: + using param_t = operators::GridSamplerParam; + + std::string doc() const override { + return "GridSampler using cl::Image2D(ImageDefault/RGBA), kFP32"; + } + + void PrepareForRun() override { + grid_param_ = param_.get_mutable(); + + auto& context = ctx_->As(); + context.cl_context()->AddKernel( + kernel_func_name_, "image/grid_sampler_kernel.cl", build_options_); + } + + void Run() override { + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + + auto* x = grid_param_->x; + auto* out = grid_param_->out; + auto* grid = grid_param_->grid; + auto out_dims = out->dims(); + auto in_dims = x->dims(); + + VLOG(4) << "x->target():" << TargetToStr(x->target()); + VLOG(4) << "out->target():" << TargetToStr(out->target()); + VLOG(4) << "x->dims():" << in_dims; + VLOG(4) << "out->dims():" << out_dims; + + auto out_image_shape = InitImageDimInfoWith(out_dims); + auto* x_img = x->data(); + VLOG(4) << "x_image: " << x_img; + + auto* grid_img = x->data(); + VLOG(4) << "grid_img: " << grid_img; + + auto* out_img = out->mutable_data( + out_image_shape["width"], out_image_shape["height"]); + VLOG(4) << "out_image" << out_img; + VLOG(4) << "out_image_shape[w,h]:" << out_image_shape["width"] << " " + << out_image_shape["height"]; + + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_; + auto kernel = context.cl_context()->GetKernel(kernel_key.str()); + + int arg_idx = 0; + int out_height = out_dims[2]; + int out_width = out_dims[3]; + auto default_work_size = + DefaultWorkSize(out_dims, + DDim(std::vector{ + static_cast(out_image_shape["width"]), + static_cast(out_image_shape["height"])})); + + cl_int status = kernel.setArg(arg_idx++, *x_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, *grid_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, *out_img); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, out_height); + CL_CHECK_FATAL(status); + status = kernel.setArg(arg_idx++, out_width); + CL_CHECK_FATAL(status); + + auto global_work_size = + cl::NDRange{static_cast(default_work_size[0]), + static_cast(default_work_size[2]), + static_cast(default_work_size[3] / 4)}; + + status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel( + kernel, + cl::NullRange, + global_work_size, + cl::NullRange, + nullptr, + event_.get()); + CL_CHECK_FATAL(status); + context.cl_wait_list()->emplace(out_img, event_); + + VLOG(4) << "global_work_size:[2D]:" << global_work_size[0] << " " + << global_work_size[1] << " " << global_work_size[2]; + } + + protected: + param_t* grid_param_{nullptr}; + std::string kernel_func_name_{"grid_sampler"}; + std::string build_options_{"-DCL_DTYPE_half"}; + std::shared_ptr event_{new cl::Event}; +}; + +} // namespace opencl +} // namespace kernels +} // namespace lite +} // namespace paddle + +namespace ocl = paddle::lite::kernels::opencl; +REGISTER_LITE_KERNEL(grid_sampler, + kOpenCL, + kFP16, + kImageDefault, + ocl::GridSamplerImageCompute, + ImageDefault) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .BindInput("Grid", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .BindOutput("Output", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .Finalize(); diff --git a/lite/kernels/opencl/grid_sampler_image_compute_test.cc b/lite/kernels/opencl/grid_sampler_image_compute_test.cc new file mode 100644 index 0000000000..bba05cac1e --- /dev/null +++ b/lite/kernels/opencl/grid_sampler_image_compute_test.cc @@ -0,0 +1,258 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "lite/backends/opencl/target_wrapper.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" +#include "lite/kernels/opencl/test_helper.h" + +#define FP16_MAX_DIFF (5e-1) + +namespace paddle { +namespace lite { + +void gird_sampler_ref(const float* din, + const DDim& in_dims, + const float* grid, + float* output) { + int num = in_dims[0]; + int channel = in_dims[1]; + int height = in_dims[2]; + int width = in_dims[3]; + int spatial_size = height * width; + + auto inbound = [](int x, int y, float x_max, float y_max) { + if (x < 0 || x > x_max || y < 0 || y > y_max) { + return false; + } + return true; + }; + + for (int n = 0; n < num; ++n) { + const float* x_n = din + n * channel * height * width; + float* out_n = output + n * channel * height * width; + const float* grid_n = grid + n * height * width * 2; + for (int c = 0; c < channel; ++c) { + const float* x_c = x_n + c * spatial_size; + float* out_c = out_n + c * spatial_size; + for (int s = 0; s < spatial_size; ++s) { + float x = grid_n[s * 2]; + float y = grid_n[s * 2 + 1]; + float xwf = (x + 1.f) * 0.5 * (width - 1); + float ynf = (y + 1.f) * 0.5 * (height - 1); + int xw = floor(xwf); + int xe = xw + 1; + int yn = floor(ynf); + int ys = yn + 1; + + float dw = xwf - xw; + float de = xe - xwf; + float dn = ynf - yn; + float ds = ys - ynf; + + float wn = inbound(xw, + yn, + static_cast(width - 1), + static_cast(height - 1)) + ? x_c[yn * width + xw] + : 0.f; + float en = inbound(xe, + yn, + static_cast(width - 1), + static_cast(height - 1)) + ? x_c[yn * width + xe] + : 0.f; + float ws = inbound(xw, + ys, + static_cast(width - 1), + static_cast(height - 1)) + ? x_c[ys * width + xw] + : 0.f; + float es = inbound(xe, + ys, + static_cast(width - 1), + static_cast(height - 1)) + ? x_c[ys * width + xe] + : 0.f; + + out_c[s] = wn * de * ds + en * dw * ds + ws * de * dn + es * dw * dn; + } + } + } +} +// #define GRID_FP16_LOOP_TEST +// #define GRID_FP16_PRINT_RESULT +TEST(grid_samler_image2d, compute) { +#ifdef GRID_FP16_LOOP_TEST + for (int n = 1; n <= 100; n += 33) { + for (auto c : {1, 3, 8, 23, 32}) { + for (int h = 12; h <= 100; h += 13) { + for (int w = 12; w <= 100; w += 25) { +#else + const int n = 1; + const int c = 2; + const int h = 3; + const int w = 4; +#endif // GRID_FP16_LOOP_TEST + + LOG(INFO) << "======== input shape[n,c,h,w]:" << n << " " << c << " " + << h << " " << w << " ========"; + + auto kernels = + KernelRegistry::Global().Create("grid_sampler", + TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault)); + ASSERT_FALSE(kernels.empty()); + auto kernel = std::move(kernels.front()); + LOG(INFO) << "get kernel:" << kernel->doc(); + + lite::Tensor x, grid, out; + operators::GridSamplerParam param; + param.x = &x; + param.grid = &grid; + param.out = &out; + + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + + kernel->SetParam(param); + std::unique_ptr grid_context(new KernelContext); + context->As().CopySharedTo( + &(grid_context->As())); + kernel->SetContext(std::move(grid_context)); + + const DDim in_dim = DDim(std::vector{n, c, h, w}); + const DDim grid_dim = DDim(std::vector{n, h, w, 2}); + const DDim out_dim = DDim(std::vector{n, c, h, w}); + x.Resize(in_dim); + grid.Resize(grid_dim); + out.Resize(out_dim); + + std::default_random_engine engine; + std::uniform_real_distribution dist(-1, 1); + int sum = n * c * h * w; + int sum2 = n * h * w * 2; + std::vector input_v(sum); + std::vector grid_v(sum2); + for (auto& i : input_v) { + i = dist(engine); + } + for (auto& i : grid_v) { + i = dist(engine); + } + + LOG(INFO) << "prepare input"; + CLImageConverterDefault* default_converter = + new CLImageConverterDefault(); + DDim x_image_shape = default_converter->InitImageDimInfoWith(in_dim); + LOG(INFO) << "x_image_shape = " << x_image_shape[0] << " " + << x_image_shape[1]; + std::vector x_image_data(x_image_shape.production() * + 4); // 4 : RGBA + default_converter->NCHWToImage( + input_v.data(), x_image_data.data(), in_dim); + auto* x_image = x.mutable_data( + x_image_shape[0], x_image_shape[1], x_image_data.data()); + LOG(INFO) << "x_image:" << x_image; + + DDim grid_image_shape = + default_converter->InitImageDimInfoWith(grid_dim); + LOG(INFO) << "grid_image_shape = " << grid_image_shape[0] << " " + << grid_image_shape[1]; + std::vector grid_image_data(grid_image_shape.production() * + 4); // 4 : RGBA + default_converter->NCHWToImage( + grid_v.data(), grid_image_data.data(), grid_dim); + auto* grid_image = grid.mutable_data( + grid_image_shape[0], grid_image_shape[1], grid_image_data.data()); + LOG(INFO) << "grid_image:" << grid_image; + + DDim out_image_shape = + default_converter->InitImageDimInfoWith(out_dim); + LOG(INFO) << "out_image_shape = " << out_image_shape[0] << " " + << out_image_shape[1]; + auto* out_image = out.mutable_data( + out_image_shape[0], out_image_shape[1]); + LOG(INFO) << "out_image:" << out_image; + kernel->Launch(); + + auto* wait_list = context->As().cl_wait_list(); + auto* out_ptr = param.out->data(); + auto it = wait_list->find(out_ptr); + if (it != wait_list->end()) { + VLOG(4) << "--- Find the sync event for the target cl tensor. ---"; + auto& event = *(it->second); + event.wait(); + } else { + LOG(FATAL) + << "Could not find the sync event for the target cl tensor."; + } + + std::unique_ptr out_ref(new float[out_dim.production()]); + gird_sampler_ref( + input_v.data(), in_dim, grid_v.data(), out_ref.get()); + + const size_t cl_image2d_row_pitch{0}; + const size_t cl_image2d_slice_pitch{0}; + half_t* out_image_data = new half_t[out_image_shape.production() * 4]; + TargetWrapperCL::ImgcpySync(out_image_data, + out_image, + out_image_shape[0], + out_image_shape[1], + cl_image2d_row_pitch, + cl_image2d_slice_pitch, + IoDirection::DtoH); + float* out_data = new float[out_image_shape.production() * 4]; + default_converter->ImageToNCHW( + out_image_data, out_data, out_image_shape, out_dim); +// result +#ifdef GRID_FP16_PRINT_RESULT + LOG(INFO) << "---- print kernel result (input -> output) ----"; + for (int eidx = 0; eidx < in_dim.production(); ++eidx) { + std::cout << input_v[eidx] << " -> " << out_data[eidx] << std::endl; + } +#endif // GRID_FP16_PRINT_RESULT + for (int i = 0; i < out_dim.production(); i++) { + auto abs_diff = abs(out_data[i] - out_ref[i]); + auto relative_diff = COMPUTE_RELATIVE_DIFF(out_data[i], out_ref[i]); + EXPECT_EQ( + (relative_diff <= FP16_MAX_DIFF) || (abs_diff <= FP16_MAX_DIFF), + true); + if ((relative_diff > FP16_MAX_DIFF) && (abs_diff > FP16_MAX_DIFF)) { + LOG(ERROR) << "error idx:" << i << " out_data[" << i + << "]:" << out_data[i] << " " + "out_ref[" + << i << "]:" << out_ref[i] << " abs_diff:" << abs_diff + << " relative_diff:" << relative_diff + << " FP16_MAX_DIFF:" << FP16_MAX_DIFF; + } + } +#ifdef GRID_FP16_LOOP_TEST + } // w + } // h + } // c + } // n +#else +// nothing to do. +#endif +} + +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(grid_sampler, kOpenCL, kFP16, kImageDefault, ImageDefault); diff --git a/lite/utils/cv/image2tensor.cc b/lite/utils/cv/image2tensor.cc index 3a09039a0f..70f0da3f05 100644 --- a/lite/utils/cv/image2tensor.cc +++ b/lite/utils/cv/image2tensor.cc @@ -142,15 +142,15 @@ void gray_to_tensor(const uint8_t* src, "ucvtf v14.4s, v8.4s \n" "ucvtf v15.4s, v9.4s \n" // sub -mean - "fsub v12.4s, v12.4s, %w[vmean].4s \n" - "fsub v13.4s, v13.4s, %w[vmean].4s \n" - "fsub v14.4s, v14.4s, %w[vmean].4s \n" - "fsub v15.4s, v15.4s, %w[vmean].4s \n" + "fsub v12.4s, v12.4s, %[vmean].4s \n" + "fsub v13.4s, v13.4s, %[vmean].4s \n" + "fsub v14.4s, v14.4s, %[vmean].4s \n" + "fsub v15.4s, v15.4s, %[vmean].4s \n" // mul * scale - "fmul v6.4s, v12.4s, %w[vscale].4s \n" - "fmul v7.4s, v13.4s, %w[vscale].4s \n" - "fmul v8.4s, v14.4s, %w[vscale].4s \n" - "fmul v9.4s, v15.4s, %w[vscale].4s \n" + "fmul v6.4s, v12.4s, %[vscale].4s \n" + "fmul v7.4s, v13.4s, %[vscale].4s \n" + "fmul v8.4s, v14.4s, %[vscale].4s \n" + "fmul v9.4s, v15.4s, %[vscale].4s \n" // store "st1 {v6.4s}, [%[outr0]], #16 \n" "subs %w[cnt], %w[cnt], #1 \n" @@ -301,19 +301,19 @@ void bgr_to_tensor_chw(const uint8_t* src, "ucvtf v16.4s, v10.4s \n" "ucvtf v17.4s, v11.4s \n" // sub -mean - "fsub v12.4s, v12.4s, %w[vbmean].4s \n" - "fsub v13.4s, v13.4s, %w[vbmean].4s \n" - "fsub v14.4s, v14.4s, %w[vgmean].4s \n" - "fsub v15.4s, v15.4s, %w[vgmean].4s \n" - "fsub v16.4s, v16.4s, %w[vrmean].4s \n" - "fsub v17.4s, v17.4s, %w[vrmean].4s \n" + "fsub v12.4s, v12.4s, %[vbmean].4s \n" + "fsub v13.4s, v13.4s, %[vbmean].4s \n" + "fsub v14.4s, v14.4s, %[vgmean].4s \n" + "fsub v15.4s, v15.4s, %[vgmean].4s \n" + "fsub v16.4s, v16.4s, %[vrmean].4s \n" + "fsub v17.4s, v17.4s, %[vrmean].4s \n" // mul * scale - "fmul v6.4s, v12.4s, %w[vbscale].4s \n" - "fmul v7.4s, v13.4s, %w[vbscale].4s \n" - "fmul v8.4s, v14.4s, %w[vgscale].4s \n" - "fmul v9.4s, v15.4s, %w[vgscale].4s \n" - "fmul v10.4s, v16.4s, %w[vrscale].4s \n" - "fmul v11.4s, v17.4s, %w[vrscale].4s \n" + "fmul v6.4s, v12.4s, %[vbscale].4s \n" + "fmul v7.4s, v13.4s, %[vbscale].4s \n" + "fmul v8.4s, v14.4s, %[vgscale].4s \n" + "fmul v9.4s, v15.4s, %[vgscale].4s \n" + "fmul v10.4s, v16.4s, %[vrscale].4s \n" + "fmul v11.4s, v17.4s, %[vrscale].4s \n" // store "st1 {v6.4s}, [%[outr0]], #16 \n" "st1 {v8.4s}, [%[outr1]], #16 \n" diff --git a/lite/utils/cv/image_convert.cc b/lite/utils/cv/image_convert.cc index a17e5bf6e8..5953b871f4 100644 --- a/lite/utils/cv/image_convert.cc +++ b/lite/utils/cv/image_convert.cc @@ -829,15 +829,9 @@ void hwc3_to_hwc1(const uint8_t* src, uint8_t* dst, int srcw, int srch) { uint8x8_t vb = vdup_n_u8(b); uint8x8_t vg = vdup_n_u8(g); uint8x8_t vr = vdup_n_u8(r); -#ifdef __aarch64__ - uint8x16_t vb1 = vdupq_n_u8(b); - uint8x16_t vg1 = vdupq_n_u8(g); - uint8x16_t vr1 = vdupq_n_u8(r); -#else uint8_t vb_array[8] = {b, b, b, b, b, b, b, b}; uint8_t vg_array[8] = {g, g, g, g, g, g, g, g}; uint8_t vr_array[8] = {r, r, r, r, r, r, r, r}; -#endif int cnt_pro = srcw >> 3; int remain_pro = srcw % 8; int win = srcw * 3; @@ -866,6 +860,9 @@ void hwc3_to_hwc1(const uint8_t* src, uint8_t* dst, int srcw, int srch) { "prfm pldl1keep, [%[inptr2], #128] \n" "prfm pldl1keep, [%[inptr3]] \n" "prfm pldl1keep, [%[inptr3], #128] \n" + "ld1 {v21.8b}, [%[vb]] \n" + "ld1 {v22.8b}, [%[vg]] \n" + "ld1 {v23.8b}, [%[vr]] \n" "1: \n" "ld3 {v0.8b - v2.8b}, [%[inptr0]], #24 \n" // d8 = y0y3y6y9.. d9 = // y1y4y7... @@ -876,20 +873,20 @@ void hwc3_to_hwc1(const uint8_t* src, uint8_t* dst, int srcw, int srch) { "ld3 {v9.8b - v11.8b}, [%[inptr3]], #24 \n" // d8 = y0y3y6y9.. d9 = // y1y4y7... // mul b - "umull v12.8h, v0.8b, %w[vb].8b \n" // v0 * vb - "umull v13.8h, v3.8b, %w[vb].8b \n" // v0 * vb - "umull v14.8h, v6.8b, %w[vb].8b \n" // v0 * vb - "umull v15.8h, v9.8b, %w[vb].8b \n" // v0 * vb + "umull v12.8h, v0.8b, v21.8b \n" // v0 * vb + "umull v13.8h, v3.8b, v21.8b \n" // v0 * vb + "umull v14.8h, v6.8b, v21.8b \n" // v0 * vb + "umull v15.8h, v9.8b, v21.8b \n" // v0 * vb // mul g - "umull v16.8h, v1.8b, %w[vg].8b \n" // v0 * vb - "umull v17.8h, v4.8b, %w[vg].8b \n" // v0 * vb - "umull v18.8h, v7.8b, %w[vg].8b \n" // v0 * vb - "umull v19.8h, v10.8b, %w[vg].8b \n" // v0 * vb + "umull v16.8h, v1.8b, v22.8b \n" // v0 * vb + "umull v17.8h, v4.8b, v22.8b \n" // v0 * vb + "umull v18.8h, v7.8b, v22.8b \n" // v0 * vb + "umull v19.8h, v10.8b, v22.8b \n" // v0 * vb // mul r - "umlal v12.8h, v2.8b, %w[vr].8b \n" // v0 * vb - "umlal v13.8h, v5.8b, %w[vr].8b \n" // v0 * vb - "umlal v14.8h, v8.8b, %w[vr].8b \n" // v0 * vb - "umlal v15.8h, v11.8b, %w[vr].8b \n" // v0 * vb + "umlal v12.8h, v2.8b, v23.8b \n" // v0 * vb + "umlal v13.8h, v5.8b, v23.8b \n" // v0 * vb + "umlal v14.8h, v8.8b, v23.8b \n" // v0 * vb + "umlal v15.8h, v11.8b, v23.8b \n" // v0 * vb // 16->32 "uaddl v0.4s, v16.4h, v12.4h \n" "uaddl2 v1.4s, v16.8h, v12.8h \n" @@ -928,7 +925,7 @@ void hwc3_to_hwc1(const uint8_t* src, uint8_t* dst, int srcw, int srch) { [outr2] "+r"(outr2), [outr3] "+r"(outr3), [cnt] "+r"(cnt) - : [vb] "w"(vb1), [vg] "w"(vg1), [vr] "w"(vr1) + : [vb] "r"(vb_array), [vg] "r"(vg_array), [vr] "r"(vr_array) : "cc", "memory", "v0", @@ -951,7 +948,10 @@ void hwc3_to_hwc1(const uint8_t* src, uint8_t* dst, int srcw, int srch) { "v17", "v18", "v19", - "v20"); + "v20", + "v21", + "v22", + "v23"); #else asm volatile( "pld [%[inptr0]] @ preload a, 64byte\n" @@ -1106,15 +1106,9 @@ void hwc4_to_hwc1(const uint8_t* src, uint8_t* dst, int srcw, int srch) { uint8x8_t vb = vdup_n_u8(b); uint8x8_t vg = vdup_n_u8(g); uint8x8_t vr = vdup_n_u8(r); -#ifdef __aarch64__ - uint8x16_t vb1 = vdupq_n_u8(b); - uint8x16_t vg1 = vdupq_n_u8(g); - uint8x16_t vr1 = vdupq_n_u8(r); -#else uint8_t vb_array[8] = {b, b, b, b, b, b, b, b}; uint8_t vg_array[8] = {g, g, g, g, g, g, g, g}; uint8_t vr_array[8] = {r, r, r, r, r, r, r, r}; -#endif int cnt_pro = srcw >> 3; int remain_pro = srcw % 8; int win = srcw * 4; @@ -1143,6 +1137,9 @@ void hwc4_to_hwc1(const uint8_t* src, uint8_t* dst, int srcw, int srch) { "prfm pldl1keep, [%[inptr2], #128] \n" "prfm pldl1keep, [%[inptr3]] \n" "prfm pldl1keep, [%[inptr3], #128] \n" + "ld1 {v21.8b}, [%[vb]] \n" + "ld1 {v22.8b}, [%[vg]] \n" + "ld1 {v23.8b}, [%[vr]] \n" "1: \n" "ld4 {v0.8b - v3.8b}, [%[inptr0]], #32 \n" // d8 = y0y3y6y9.. d9 = // y1y4y7... @@ -1153,20 +1150,20 @@ void hwc4_to_hwc1(const uint8_t* src, uint8_t* dst, int srcw, int srch) { "ld4 {v12.8b - v15.8b}, [%[inptr3]], #32 \n" // d8 = y0y3y6y9.. d9 = // y1y4y7... // mul b - "umull v13.8h, v0.8b, %w[vb].8b \n" // v0 * vb - "umull v14.8h, v4.8b, %w[vb].8b \n" // v0 * vb - "umull v15.8h, v8.8b, %w[vb].8b \n" // v0 * vb - "umull v16.8h, v12.8b, %w[vb].8b \n" // v0 * vb + "umull v13.8h, v0.8b, v21.8b \n" // v0 * vb + "umull v14.8h, v4.8b, v21.8b \n" // v0 * vb + "umull v15.8h, v8.8b, v21.8b \n" // v0 * vb + "umull v16.8h, v12.8b, v21.8b \n" // v0 * vb // mul g - "umull v17.8h, v1.8b, %w[vg].8b \n" // v0 * vb - "umull v18.8h, v5.8b, %w[vg].8b \n" // v0 * vb - "umull v19.8h, v9.8b, %w[vg].8b \n" // v0 * vb - "umull v20.8h, v13.8b, %w[vg].8b \n" // v0 * vb + "umull v17.8h, v1.8b, v22.8b \n" // v0 * vb + "umull v18.8h, v5.8b, v22.8b \n" // v0 * vb + "umull v19.8h, v9.8b, v22.8b \n" // v0 * vb + "umull v20.8h, v13.8b, v22.8b \n" // v0 * vb // mul r - "umlal v13.8h, v2.8b, %w[vr].8b \n" // v0 * vb - "umlal v14.8h, v6.8b, %w[vr].8b \n" // v0 * vb - "umlal v15.8h, v10.8b, %w[vr].8b \n" // v0 * vb - "umlal v16.8h, v14.8b, %w[vr].8b \n" // v0 * vb + "umlal v13.8h, v2.8b, v23.8b \n" // v0 * vb + "umlal v14.8h, v6.8b, v23.8b \n" // v0 * vb + "umlal v15.8h, v10.8b, v23.8b \n" // v0 * vb + "umlal v16.8h, v14.8b, v23.8b \n" // v0 * vb // 16->32 "uaddl v0.4s, v17.4h, v13.4h \n" "uaddl2 v1.4s, v17.8h, v13.8h \n" @@ -1205,7 +1202,7 @@ void hwc4_to_hwc1(const uint8_t* src, uint8_t* dst, int srcw, int srch) { [outr2] "+r"(outr2), [outr3] "+r"(outr3), [cnt] "+r"(cnt) - : [vb] "w"(vb1), [vg] "w"(vg1), [vr] "w"(vr1) + : [vb] "r"(vb_array), [vg] "r"(vg_array), [vr] "r"(vr_array) : "cc", "memory", "v0", @@ -1228,7 +1225,10 @@ void hwc4_to_hwc1(const uint8_t* src, uint8_t* dst, int srcw, int srch) { "v17", "v18", "v19", - "v20"); + "v20", + "v21", + "v22", + "v23"); #else asm volatile( "pld [%[inptr0]] @ preload a, 64byte\n" diff --git a/lite/utils/cv/image_resize.cc b/lite/utils/cv/image_resize.cc index cd02a2cf4b..e9ff8dbbfc 100644 --- a/lite/utils/cv/image_resize.cc +++ b/lite/utils/cv/image_resize.cc @@ -217,12 +217,12 @@ void resize(const uint8_t* src, "1: \n" "ld1 {v0.8h}, [%[rows0p]], #16 \n" "ld1 {v1.8h}, [%[rows1p]], #16 \n" - "orr v6.16b, %w[_v2].16b, %w[_v2].16b \n" - "orr v7.16b, %w[_v2].16b, %w[_v2].16b \n" - "smull v2.4s, v0.4h, %w[_b0].4h \n" - "smull2 v4.4s, v0.8h, %w[_b0].8h \n" - "smull v3.4s, v1.4h, %w[_b1].4h \n" - "smull2 v5.4s, v1.8h, %w[_b1].8h \n" + "orr v6.16b, %[_v2].16b, %[_v2].16b \n" + "orr v7.16b, %[_v2].16b, %[_v2].16b \n" + "smull v2.4s, v0.4h, %[_b0].4h \n" + "smull2 v4.4s, v0.8h, %[_b0].8h \n" + "smull v3.4s, v1.4h, %[_b1].4h \n" + "smull2 v5.4s, v1.8h, %[_b1].8h \n" "ssra v6.4s, v2.4s, #16 \n" "ssra v7.4s, v4.4s, #16 \n" -- GitLab