From 9875843e009b6304c47a6da131121936d4c0cdc8 Mon Sep 17 00:00:00 2001 From: yiicy Date: Mon, 23 Dec 2019 11:48:21 +0800 Subject: [PATCH] [ARM] add grid_sampler op and ut, test=develop (#2598) --- lite/core/arena/framework.h | 12 +- lite/kernels/arm/CMakeLists.txt | 1 + lite/kernels/arm/grid_sampler_compute.cc | 202 ++++++++++++++++++ lite/kernels/arm/grid_sampler_compute.h | 40 ++++ lite/operators/CMakeLists.txt | 1 + lite/operators/grid_sampler_op.cc | 64 ++++++ lite/operators/grid_sampler_op.h | 47 ++++ lite/operators/op_params.h | 6 + lite/tests/kernels/CMakeLists.txt | 1 + .../kernels/grid_sampler_compute_test.cc | 172 +++++++++++++++ 10 files changed, 545 insertions(+), 1 deletion(-) create mode 100644 lite/kernels/arm/grid_sampler_compute.cc create mode 100644 lite/kernels/arm/grid_sampler_compute.h create mode 100644 lite/operators/grid_sampler_op.cc create mode 100644 lite/operators/grid_sampler_op.h create mode 100644 lite/tests/kernels/grid_sampler_compute_test.cc diff --git a/lite/core/arena/framework.h b/lite/core/arena/framework.h index 05af21bbdb..bd8b3c48fb 100644 --- a/lite/core/arena/framework.h +++ b/lite/core/arena/framework.h @@ -213,7 +213,17 @@ class Arena { } auto duration = std::chrono::duration_cast( std::chrono::high_resolution_clock::now() - timer); - LOG(INFO) << "average duration: " << duration.count() << " ms"; + + timer = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < times; i++) { + tester_->RunBaseline(tester_->baseline_scope()); + } + auto duration_basic = std::chrono::duration_cast( + std::chrono::high_resolution_clock::now() - timer); + LOG(INFO) << "average lite duration: " << duration.count() << " ms"; + LOG(INFO) << "average basic duration: " << duration_basic.count() << " ms"; + LOG(INFO) << "speed up ratio: lite_speed / basic_speed: " + << static_cast(duration_basic.count()) / duration.count(); } private: diff --git a/lite/kernels/arm/CMakeLists.txt b/lite/kernels/arm/CMakeLists.txt index f543c000f8..ce8b8365a8 100644 --- a/lite/kernels/arm/CMakeLists.txt +++ b/lite/kernels/arm/CMakeLists.txt @@ -49,6 +49,7 @@ add_kernel(range_compute_arm ARM basic SRCS range_compute.cc DEPS ${lite_kernel_ add_kernel(dropout_compute_arm ARM basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(layout_compute_arm ARM basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(instance_norm_compute_arm ARM basic SRCS instance_norm_compute.cc DEPS ${lite_kernel_deps} math_arm) +add_kernel(grid_sampler_compute_arm ARM basic SRCS grid_sampler_compute.cc DEPS ${lite_kernel_deps} math_arm) ## 2.other basic kernels: basic kernels that not used in basic models add_kernel(negative_compute_arm ARM extra SRCS negative_compute.cc DEPS ${lite_kernel_deps} math_arm) diff --git a/lite/kernels/arm/grid_sampler_compute.cc b/lite/kernels/arm/grid_sampler_compute.cc new file mode 100644 index 0000000000..d0fc2545a5 --- /dev/null +++ b/lite/kernels/arm/grid_sampler_compute.cc @@ -0,0 +1,202 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/arm/grid_sampler_compute.h" +#include "lite/backends/arm/math/funcs.h" +#include "lite/core/op_registry.h" +#include "lite/core/type_system.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +void GridSamplerCompute::PrepareForRun() {} + +void GridSamplerCompute::Run() { + auto& param = this->Param(); + auto n = param.x->dims()[0]; + auto c = param.x->dims()[1]; + auto h = param.x->dims()[2]; + auto w = param.x->dims()[3]; + const float* in = param.x->data(); + const float* grid = param.grid->data(); + float* out = param.out->mutable_data(); + auto& ctx = this->ctx_->template As(); + const size_t coor_size = n * h * w; + const size_t workspace_size = coor_size * 12 * sizeof(float); + + ctx.ExtendWorkspace(workspace_size); + int32_t* coor_p = ctx.workspace_data(); + float* dis_p = reinterpret_cast(coor_p) + coor_size * 4; + uint32_t* bound_p = reinterpret_cast(dis_p) + coor_size * 4; + + float x_max = static_cast(w - 1); + float y_max = static_cast(h - 1); + float32x4_t vxmax = vdupq_n_f32(x_max); + float32x4_t vymax = vdupq_n_f32(y_max); + float32x4_t vone = vdupq_n_f32(1.f); + float32x4_t vzero = vdupq_n_f32(0.f); + + // compute coor, dis, bound + int i = coor_size; + for (; i > 3; i -= 4) { + float32x4x2_t xy = vld2q_f32(grid); + float32x4_t grid_x = vmulq_n_f32(vaddq_f32(xy.val[0], vone), 0.5 * x_max); + float32x4_t grid_y = vmulq_n_f32(vaddq_f32(xy.val[1], vone), 0.5 * y_max); + grid += 8; + + // compute xw, we, yn, ys + int32x4x4_t vcoor; + vcoor.val[0] = vcvtq_s32_f32(grid_x); + vcoor.val[2] = vcvtq_s32_f32(grid_y); + float32x4_t vxwf = vcvtq_f32_s32(vcoor.val[0]); + float32x4_t vynf = vcvtq_f32_s32(vcoor.val[2]); + float32x4_t vxef = vaddq_f32(vxwf, vone); + float32x4_t vysf = vaddq_f32(vynf, vone); + vcoor.val[1] = vcvtq_s32_f32(vxef); + vcoor.val[3] = vcvtq_s32_f32(vysf); + vst4q_s32(coor_p, vcoor); + coor_p += 16; + + // compute dw, dn ,de, ds + float32x4x4_t vdis; + vdis.val[0] = vsubq_f32(grid_x, vxwf); + vdis.val[2] = vsubq_f32(grid_y, vynf); + vdis.val[1] = vsubq_f32(vxef, grid_x); + vdis.val[3] = vsubq_f32(vysf, grid_y); + vst4q_f32(dis_p, vdis); + dis_p += 16; + + // compute bound + uint32x4x4_t vbound; + uint32x4_t logic_xw = + vorrq_u32(vcltq_f32(vxwf, vzero), vcgtq_f32(vxwf, vxmax)); + uint32x4_t logic_xe = + vorrq_u32(vcltq_f32(vxef, vzero), vcgtq_f32(vxef, vxmax)); + uint32x4_t logic_yn = + vorrq_u32(vcltq_f32(vynf, vzero), vcgtq_f32(vynf, vymax)); + uint32x4_t logic_ys = + vorrq_u32(vcltq_f32(vysf, vzero), vcgtq_f32(vysf, vymax)); + vbound.val[0] = vmvnq_u32(vorrq_u32(logic_xw, logic_yn)); + vbound.val[1] = vmvnq_u32(vorrq_u32(logic_xe, logic_yn)); + vbound.val[2] = vmvnq_u32(vorrq_u32(logic_xw, logic_ys)); + vbound.val[3] = vmvnq_u32(vorrq_u32(logic_xe, logic_ys)); + vst4q_u32(bound_p, vbound); + bound_p += 16; + } + + for (; i > 0; i--) { + float x = grid[0]; + float y = grid[1]; + float grid_x = (x + 1) * 0.5 * x_max; + float grid_y = (y + 1) * 0.5 * y_max; + grid += 2; + + // compute xw, xe, yn, ys + int32_t xw = static_cast(floor(grid_x)); + int32_t xe = xw + 1; + int32_t yn = static_cast(floor(grid_y)); + int32_t ys = yn + 1; + *coor_p++ = xw; + *coor_p++ = xe; + *coor_p++ = yn; + *coor_p++ = ys; + + // compute dw, de, dn, ds + float dw = grid_x - xw; + float de = xe - grid_x; + float dn = grid_y - yn; + float ds = ys - grid_y; + *dis_p++ = dw; + *dis_p++ = de; + *dis_p++ = dn; + *dis_p++ = ds; + + // compute bound + bool logic_xw = (xw < 0.f || xw > x_max); + bool logic_xe = (xe < 0.f || xe > x_max); + bool logic_yn = (yn < 0.f || yn > y_max); + bool logic_ys = (ys < 0.f || ys > y_max); + *bound_p++ = ((logic_xw || logic_yn) ? 0 : 0xffffffff); + *bound_p++ = ((logic_xe || logic_yn) ? 0 : 0xffffffff); + *bound_p++ = ((logic_xw || logic_ys) ? 0 : 0xffffffff); + *bound_p++ = ((logic_xe || logic_ys) ? 0 : 0xffffffff); + } + + size_t cube_size = c * h * w; + size_t spatial_size = h * w; + // compute output + for (int i = 0; i < n; ++i) { + const float* in_n = in + i * cube_size; + float* out_n = out + i * cube_size; + int32_t* coor_n = ctx.workspace_data() + i * spatial_size * 4; + float* dis_n = reinterpret_cast(coor_n) + coor_size * 4; + uint32_t* bound_n = reinterpret_cast(dis_n) + coor_size * 4; +#pragma omp parallel for + for (int j = 0; j < c; ++j) { + int32_t* coor_ptr = coor_n; + float* dis_ptr = dis_n; + uint32_t* bound_ptr = bound_n; + const float* in_c = in_n + j * spatial_size; + float* out_c = out_n + j * spatial_size; + for (int k = 0; k < spatial_size; k++) { + int32x4_t vcoor = vld1q_s32(coor_ptr); + float32x4_t vdis = vld1q_f32(dis_ptr); + int32_t xw = vgetq_lane_s32(vcoor, 0); + int32_t xe = vgetq_lane_s32(vcoor, 1); + int32_t yn = vgetq_lane_s32(vcoor, 2); + int32_t ys = vgetq_lane_s32(vcoor, 3); + + uint32x4_t vbound = vld1q_u32(bound_ptr); + float dw = vgetq_lane_f32(vdis, 0); + float de = vgetq_lane_f32(vdis, 1); + float dn = vgetq_lane_f32(vdis, 2); + float ds = vgetq_lane_f32(vdis, 3); + + uint32_t wnbound = vgetq_lane_u32(vbound, 0); + uint32_t enbound = vgetq_lane_u32(vbound, 1); + uint32_t wsbound = vgetq_lane_u32(vbound, 2); + uint32_t esbound = vgetq_lane_u32(vbound, 3); + + float in_wn = wnbound ? in_c[yn * w + xw] : 0.f; + float in_en = enbound ? in_c[yn * w + xe] : 0.f; + float in_ws = wsbound ? in_c[ys * w + xw] : 0.f; + float in_es = esbound ? in_c[ys * w + xe] : 0.f; + + coor_ptr += 4; + dis_ptr += 4; + bound_ptr += 4; + *out_c++ = + ds * (in_wn * de + in_en * dw) + dn * (in_ws * de + in_es * dw); + } + } + } +} + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(grid_sampler, + kARM, + kFloat, + kNCHW, + paddle::lite::kernels::arm::GridSamplerCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindInput("Grid", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))}) + .Finalize(); diff --git a/lite/kernels/arm/grid_sampler_compute.h b/lite/kernels/arm/grid_sampler_compute.h new file mode 100644 index 0000000000..5cc78a3fd0 --- /dev/null +++ b/lite/kernels/arm/grid_sampler_compute.h @@ -0,0 +1,40 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace arm { + +class GridSamplerCompute : public KernelLite { + public: + using param_t = operators::GridSamplerParam; + + void PrepareForRun() override; + + void Run() override; + + virtual ~GridSamplerCompute() = default; + + private: +}; + +} // namespace arm +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 34b364ae39..06752b4e19 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -49,6 +49,7 @@ add_operator(dropout_op basic SRCS dropout_op.cc DEPS ${op_DEPS}) add_operator(layout_op basic SRCS layout_op.cc DEPS ${op_DEPS}) add_operator(instance_norm_op basic SRCS instance_norm_op.cc DEPS ${op_DEPS}) add_operator(subgraph_op basic SRCS subgraph_op.cc DEPS ${op_DEPS}) +add_operator(grid_sampler_op basic SRCS grid_sampler_op.cc DEPS ${op_DEPS}) # 2.basic ops not used in basic models add_operator(negative_op extra SRCS negative_op.cc DEPS ${op_DEPS}) diff --git a/lite/operators/grid_sampler_op.cc b/lite/operators/grid_sampler_op.cc new file mode 100644 index 0000000000..2b13d17da7 --- /dev/null +++ b/lite/operators/grid_sampler_op.cc @@ -0,0 +1,64 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/grid_sampler_op.h" +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool GridSamplerOp::CheckShape() const { + CHECK_OR_FALSE(param_.x); + CHECK_OR_FALSE(param_.out); + CHECK_OR_FALSE(param_.grid); + auto x_dims = param_.x->dims(); + auto grid_dims = param_.grid->dims(); + + CHECK_EQ(x_dims.size(), 4UL) << "Input must have 4 dimensions."; + CHECK_EQ(grid_dims.size(), 4UL) << "Grid must have 4 dimensions."; + CHECK_EQ(grid_dims[0], x_dims[0]) + << "Input(X) dims[0] and Input(Grid) dims[0] should be equal."; + CHECK_EQ(grid_dims[1], x_dims[2]) + << "Input(X) dims[2] and Input(Grid) dims[1] should be equal."; + CHECK_EQ(grid_dims[2], x_dims[3]) + << "Input(X) dims[3] and Input(Grid) dims[2] should be equal."; + + return true; +} + +bool GridSamplerOp::InferShape() const { + auto x_dims = param_.x->dims(); + param_.out->Resize(x_dims); + return true; +} + +bool GridSamplerOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) { + param_.x = scope->FindVar(op_desc.Input("X").front())->GetMutable(); + param_.grid = + scope->FindVar(op_desc.Input("Grid").front())->GetMutable(); + param_.out = + scope->FindVar(op_desc.Output("Output").front())->GetMutable(); + return true; +} + +} /* namespace operators */ +} /* namespace lite */ +} /* namespace paddle */ + +REGISTER_LITE_OP(grid_sampler, paddle::lite::operators::GridSamplerOp); diff --git a/lite/operators/grid_sampler_op.h b/lite/operators/grid_sampler_op.h new file mode 100644 index 0000000000..035e1b8345 --- /dev/null +++ b/lite/operators/grid_sampler_op.h @@ -0,0 +1,47 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include "lite/core/op_lite.h" +#include "lite/core/scope.h" +#include "lite/utils/all.h" + +namespace paddle { +namespace lite { +namespace operators { + +class GridSamplerOp : public OpLite { + public: + GridSamplerOp() {} + + explicit GridSamplerOp(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShape() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "grid_sampler"; } + + private: + mutable GridSamplerParam param_; +}; + +} /* namespace operators */ +} /* namespace lite */ +} /* namespace paddle */ diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index bd2ba937ea..f1a2eb1e52 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -1118,6 +1118,12 @@ struct InstanceNormParam { lite::Tensor* saved_variance{}; float epsilon; }; +/// --------------------- grid sampler operators -------------------- +struct GridSamplerParam { + lite::Tensor* x{}; + lite::Tensor* out{}; + lite::Tensor* grid{}; +}; } // namespace operators } // namespace lite diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt index dbe85f17f7..0d7890b43c 100644 --- a/lite/tests/kernels/CMakeLists.txt +++ b/lite/tests/kernels/CMakeLists.txt @@ -15,6 +15,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_grid_sampler_compute SRCS grid_sampler_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_sequence_softmax_compute SRCS sequence_softmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_im2sequence_compute SRCS im2sequence_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_compare_compute SRCS compare_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) diff --git a/lite/tests/kernels/grid_sampler_compute_test.cc b/lite/tests/kernels/grid_sampler_compute_test.cc new file mode 100644 index 0000000000..ec28ff9fa3 --- /dev/null +++ b/lite/tests/kernels/grid_sampler_compute_test.cc @@ -0,0 +1,172 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/core/arena/framework.h" +#include "lite/tests/utils/fill_data.h" + +namespace paddle { +namespace lite { + +class GridSamplerComputeTest : public arena::TestCase { + protected: + // common attributes for this op. + std::string input_ = "x"; + std::string output_ = "y"; + std::string grid_ = "grid"; + + DDim dims_{{4, 5, 19, 19}}; + + public: + GridSamplerComputeTest(const Place& place, + const std::string& alias, + DDim dims) + : TestCase(place, alias), dims_(dims) {} + + void RunBaseline(Scope* scope) override { + auto x = scope->FindTensor(input_); + auto grid = scope->FindTensor(grid_); + auto out = scope->NewTensor(output_); + CHECK(out); + out->Resize(dims_); + + const float* x_data = x->data(); + const float* grid_data = grid->data(); + float* out_data = out->mutable_data(); + + int num = x->dims()[0]; + int channel = x->dims()[1]; + int height = x->dims()[2]; + int width = x->dims()[3]; + int spatial_size = height * width; + + auto inbound = [](int x, int y, float x_max, float y_max) { + if (x < 0 || x > x_max || y < 0 || y > y_max) { + return false; + } + return true; + }; + + for (int n = 0; n < num; ++n) { + const float* x_n = x_data + n * channel * height * width; + float* out_n = out_data + n * channel * height * width; + const float* grid_n = grid_data + n * height * width * 2; + for (int c = 0; c < channel; ++c) { + const float* x_c = x_n + c * spatial_size; + float* out_c = out_n + c * spatial_size; + for (int s = 0; s < spatial_size; ++s) { + float x = grid_n[s * 2]; + float y = grid_n[s * 2 + 1]; + float xwf = (x + 1.f) * 0.5 * (width - 1); + float ynf = (y + 1.f) * 0.5 * (height - 1); + int xw = floor(xwf); + int xe = xw + 1; + int yn = floor(ynf); + int ys = yn + 1; + + float dw = xwf - xw; + float de = xe - xwf; + float dn = ynf - yn; + float ds = ys - ynf; + + float wn = inbound(xw, + yn, + static_cast(width - 1), + static_cast(height - 1)) + ? x_c[yn * width + xw] + : 0.f; + float en = inbound(xe, + yn, + static_cast(width - 1), + static_cast(height - 1)) + ? x_c[yn * width + xe] + : 0.f; + float ws = inbound(xw, + ys, + static_cast(width - 1), + static_cast(height - 1)) + ? x_c[ys * width + xw] + : 0.f; + float es = inbound(xe, + ys, + static_cast(width - 1), + static_cast(height - 1)) + ? x_c[ys * width + xe] + : 0.f; + + out_c[s] = wn * de * ds + en * dw * ds + ws * de * dn + es * dw * dn; + } + } + } + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType("grid_sampler"); + op_desc->SetInput("X", {input_}); + op_desc->SetInput("Grid", {grid_}); + op_desc->SetOutput("Output", {output_}); + } + + void PrepareData() override { + std::vector din(dims_.production()); + fill_data_rand(din.data(), -1.f, 1.f, dims_.production()); + + DDim gird_dims{{dims_[0], dims_[2], dims_[3], 2}}; + std::vector grid(gird_dims.production()); + fill_data_rand(grid.data(), -1.f, 1.f, gird_dims.production()); + + SetCommonTensor(input_, dims_, din.data()); + SetCommonTensor(grid_, gird_dims, grid.data()); + } +}; + +void test_grid_sampler(Place place) { + for (auto& n : {1, 13}) { + for (auto& c : {1, 3, 8}) { + for (auto& h : {1, 3, 8, 64}) { + for (auto& w : {2, 4, 9, 63}) { + DDim dim_in({n, c, h, w}); + std::unique_ptr tester( + new GridSamplerComputeTest(place, "def", dim_in)); +#ifdef LITE_WITH_ARM + auto& ctx = tester->context()->As(); + ctx.SetRunMode(lite_api::LITE_POWER_HIGH, 1); +#endif + arena::Arena arena(std::move(tester), place, 6e-5); + LOG(INFO) << "run n: " << n << ", c: " << c << ", h: " << h + << ", w: " << w; + if (!arena.TestPrecision()) { + LOG(ERROR) << "No Pass!!"; + return; + } + // if you want to test this op performance, uncomment the following + // line + // arena.TestPerformance(); + } + } + } + } +} + +TEST(GridSampler, precision) { +#ifdef LITE_WITH_ARM + Place place(TARGET(kARM)); + test_grid_sampler(place); +#endif +} + +} // namespace lite +} // namespace paddle -- GitLab