From 765a20b67fd6755dce93b5fff90c59bee82d7775 Mon Sep 17 00:00:00 2001 From: NazgulLee Date: Thu, 19 Dec 2019 19:37:49 +0800 Subject: [PATCH] add grid_sampler op, test=develop (#2632) --- mobile/src/operators/grid_sampler_op.cpp | 36 +++++++ mobile/src/operators/grid_sampler_op.h | 35 +++++++ .../cl/cl_kernel/grid_sampler_kernel.cl | 99 +++++++++++++++++++ .../kernel/cl/grid_sampler_kernel.cpp | 66 +++++++++++++ .../operators/kernel/grid_sampler_kernel.h | 28 ++++++ mobile/src/operators/op_param.h | 9 +- 6 files changed, 272 insertions(+), 1 deletion(-) create mode 100644 mobile/src/operators/grid_sampler_op.cpp create mode 100644 mobile/src/operators/grid_sampler_op.h create mode 100644 mobile/src/operators/kernel/cl/cl_kernel/grid_sampler_kernel.cl create mode 100644 mobile/src/operators/kernel/cl/grid_sampler_kernel.cpp create mode 100644 mobile/src/operators/kernel/grid_sampler_kernel.h diff --git a/mobile/src/operators/grid_sampler_op.cpp b/mobile/src/operators/grid_sampler_op.cpp new file mode 100644 index 0000000000..90809f1d4c --- /dev/null +++ b/mobile/src/operators/grid_sampler_op.cpp @@ -0,0 +1,36 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef GRID_SAMPLER_OP + +#include "operators/grid_sampler_op.h" + +namespace paddle_mobile { +namespace operators { + +template +void GridSamplerOp::InferShape() const { + auto x_dim = this->param_.InputX()->dims(); + this->param_.Output()->Resize(x_dim); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CL +REGISTER_OPERATOR_CL(grid_sampler, ops::GridSamplerOp); +#endif + +#endif diff --git a/mobile/src/operators/grid_sampler_op.h b/mobile/src/operators/grid_sampler_op.h new file mode 100644 index 0000000000..9d142b9d47 --- /dev/null +++ b/mobile/src/operators/grid_sampler_op.h @@ -0,0 +1,35 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef GRID_SAMPLER_OP + +#pragma once + +#include + +#include "framework/operator.h" +#include "operators/kernel/grid_sampler_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +#ifdef GRID_SAMPLER_OP +DECLARE_OPERATOR(GridSampler, GridSamplerParam, GridSamplerKernel); +#endif + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/mobile/src/operators/kernel/cl/cl_kernel/grid_sampler_kernel.cl b/mobile/src/operators/kernel/cl/cl_kernel/grid_sampler_kernel.cl new file mode 100644 index 0000000000..e366316e43 --- /dev/null +++ b/mobile/src/operators/kernel/cl/cl_kernel/grid_sampler_kernel.cl @@ -0,0 +1,99 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "cl_common.h" + +__kernel void grid_sampler(__private const int out_height, + __private const int out_width, + __read_only image2d_t input, + __read_only image2d_t grid, + __write_only image2d_t output) { + const int out_c = get_global_id(0); + const int out_w = get_global_id(1); + const int out_nh = get_global_id(2) * 4; + const int out_n = out_nh / out_height; + const int out_h = out_nh % out_height; + const sampler_t sampler = + CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + int x_grid = out_h / 4 * 2; + int y_grid = out_n * out_width + out_w; + float4 g1 = read_imagef(grid, sampler, int2(x_grid, y_grid)); + float4 g2 = read_imagef(grid, sampler, int2(x_grid + 1, y_grid)); + + float x = (g1.x + 1) * (out_width - 1) / 2; + float y = (g2.x + 1) * (out_height - 1) / 2; + float x0 = floor(x); + float y0 = floor(y); + int x_p = out_c * out_width + x0; + int y_p = out_n * out_height + y0; + int x_out = out_c * out_width + out_w; + int y_out = out_n * out_height + out_h; + float4 input0 = read_imagef(input, sampler, int2(x_p, y_p)); + float4 input1 = read_imagef(input, sampler, int2(x_p + 1, y_p)); + float4 input2 = read_imagef(input, sampler, int2(x_p, y_p + 1)); + float4 input3 = read_imagef(input, sampler, int2(x_p + 1, y_p + 1)); + float4 out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) + + input1 * (x - x0) * (y0 + 1 - y) + + input2 * (x0 + 1 - x) * (y - y0) + + input3 * (x - x0) * (y - y0); + write_imageh(output, int2(x_out, y_out), convert_half4(out_val)); + + x = (g1.y + 1) * (out_width - 1) / 2; + y = (g2.y + 1) * (out_height - 1) / 2; + x0 = floor(x); + y0 = floor(y); + x_p = out_c * out_width + x0; + y_p = out_n * out_height + y0; + input0 = read_imagef(input, sampler, int2(x_p, y_p)); + input1 = read_imagef(input, sampler, int2(x_p + 1, y_p)); + input2 = read_imagef(input, sampler, int2(x_p, y_p + 1)); + input3 = read_imagef(input, sampler, int2(x_p + 1, y_p + 1)); + out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) + + input1 * (x - x0) * (y0 + 1 - y) + + input2 * (x0 + 1 - x) * (y - y0) + + input3 * (x - x0) * (y - y0); + write_imageh(output, int2(x_out, y_out + 1), convert_half4(out_val)); + + x = (g1.z + 1) * (out_width - 1) / 2; + y = (g2.z + 1) * (out_height - 1) / 2; + x0 = floor(x); + y0 = floor(y); + x_p = out_c * out_width + x0; + y_p = out_n * out_height + y0; + input0 = read_imagef(input, sampler, int2(x_p, y_p)); + input1 = read_imagef(input, sampler, int2(x_p + 1, y_p)); + input2 = read_imagef(input, sampler, int2(x_p, y_p + 1)); + input3 = read_imagef(input, sampler, int2(x_p + 1, y_p + 1)); + out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) + + input1 * (x - x0) * (y0 + 1 - y) + + input2 * (x0 + 1 - x) * (y - y0) + + input3 * (x - x0) * (y - y0); + write_imageh(output, int2(x_out, y_out + 2), convert_half4(out_val)); + + x = (g1.w + 1) * (out_width - 1) / 2; + y = (g2.w + 1) * (out_height - 1) / 2; + x0 = floor(x); + y0 = floor(y); + x_p = out_c * out_width + x0; + y_p = out_n * out_height + y0; + input0 = read_imagef(input, sampler, int2(x_p, y_p)); + input1 = read_imagef(input, sampler, int2(x_p + 1, y_p)); + input2 = read_imagef(input, sampler, int2(x_p, y_p + 1)); + input3 = read_imagef(input, sampler, int2(x_p + 1, y_p + 1)); + out_val = input0 * (x0 + 1 - x) * (y0 + 1 - y) + + input1 * (x - x0) * (y0 + 1 - y) + + input2 * (x0 + 1 - x) * (y - y0) + + input3 * (x - x0) * (y - y0); + write_imageh(output, int2(x_out, y_out + 3), convert_half4(out_val)); +} diff --git a/mobile/src/operators/kernel/cl/grid_sampler_kernel.cpp b/mobile/src/operators/kernel/cl/grid_sampler_kernel.cpp new file mode 100644 index 0000000000..3a20ebd94e --- /dev/null +++ b/mobile/src/operators/kernel/cl/grid_sampler_kernel.cpp @@ -0,0 +1,66 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#ifdef GRID_SAMPLER_OP + +#include "operators/kernel/grid_sampler_kernel.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool GridSamplerKernel::Init(GridSamplerParam* param) { + this->cl_helper_.AddKernel("grid_sampler", "grid_sampler_kernel.cl"); + return true; +} + +template <> +void GridSamplerKernel::Compute( + const GridSamplerParam& param) { + auto kernel = this->cl_helper_.KernelAt(0); + auto default_work_size = this->cl_helper_.DefaultWorkSize(*(param.Output())); + cl_int status; + auto output = param.Output(); + auto input = param.InputX(); + auto grid = param.Grid(); + auto output_image = output->GetCLImage(); + auto input_image = input->GetCLImage(); + auto grid_image = grid->GetCLImage(); + const int out_H = output->dims()[2]; + const int out_W = output->dims()[3]; + + status = clSetKernelArg(kernel, 0, sizeof(cl_int), &out_H); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 1, sizeof(cl_int), &out_W); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 2, sizeof(cl_mem), &input_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &grid_image); + CL_CHECK_ERRORS(status); + status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &output_image); + CL_CHECK_ERRORS(status); + + const size_t work_size[3] = {default_work_size[0], default_work_size[1], + default_work_size[2] / 4}; + + status = clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, + NULL, work_size, NULL, 0, NULL, NULL); + + CL_CHECK_ERRORS(status); +} + +template class GridSamplerKernel; + +} // namespace operators +} // namespace paddle_mobile +#endif diff --git a/mobile/src/operators/kernel/grid_sampler_kernel.h b/mobile/src/operators/kernel/grid_sampler_kernel.h new file mode 100644 index 0000000000..bbadb6b54a --- /dev/null +++ b/mobile/src/operators/kernel/grid_sampler_kernel.h @@ -0,0 +1,28 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "framework/operator.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +#ifdef GRID_SAMPLER_OP +DECLARE_KERNEL(GridSampler, GridSamplerParam); +#endif // GRID_SAMPLER_OP + +} // namespace operators +} // namespace paddle_mobile diff --git a/mobile/src/operators/op_param.h b/mobile/src/operators/op_param.h index 01a37f5062..1224ef0693 100644 --- a/mobile/src/operators/op_param.h +++ b/mobile/src/operators/op_param.h @@ -337,6 +337,11 @@ class OpParam { return GetVarValue("Filter", inputs, scope); } + template + static T *GridFrom(const VariableNameMap &inputs, const Scope &scope) { + return GetVarValue("Grid", inputs, scope); + } + template static const T GetAttr(const string &key, const AttributeMap &map) { return ((Attribute)map.at(key)).Get(); @@ -3687,7 +3692,6 @@ class PixelShuffleParam : public OpParam { }; #endif - #ifdef GRID_SAMPLER_OP template class GridSamplerParam : public OpParam { @@ -3700,15 +3704,18 @@ class GridSamplerParam : public OpParam { Scope *scope) : OpParam(inputs, outputs, attrs, scope) { input_x_ = InputXFrom(inputs, *scope); + grid_ = GridFrom(inputs, *scope); output_ = OutputFrom(outputs, *scope); } const GType *InputX() const { return input_x_; } + const GType *Grid() const { return grid_; } GType *Output() const { return output_; } private: GType *input_x_; + GType *grid_; GType *output_; }; #endif -- GitLab