diff --git a/lite/backends/x86/math/CMakeLists.txt b/lite/backends/x86/math/CMakeLists.txt
index b5262efa4e8ca3fbfa3076fb9a5eb6fe1993ccb2..09bcfc67e413d6e361b3c42448ca841c1b3aa847 100644
--- a/lite/backends/x86/math/CMakeLists.txt
+++ b/lite/backends/x86/math/CMakeLists.txt
@@ -63,3 +63,4 @@ math_library(search_fc DEPS blas dynload_mklml)
 # cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
 math_library(box_coder DEPS math_function)
 math_library(prior_box DEPS math_function)
+math_library(interpolate DEPS math_function)
diff --git a/lite/backends/x86/math/interpolate.cc b/lite/backends/x86/math/interpolate.cc
new file mode 100644
index 0000000000000000000000000000000000000000..14f8a2d055ad727c6e36aba7f5f25b21666c659b
--- /dev/null
+++ b/lite/backends/x86/math/interpolate.cc
@@ -0,0 +1,281 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "lite/backends/x86/math/interpolate.h"
+#include <string>
+#include <vector>
+#include "lite/backends/x86/math/math_function.h"
+
+namespace paddle {
+namespace lite {
+namespace x86 {
+namespace math {
+
+// Bilinearly resizes n * c planes of in_h x in_w floats to out_h x out_w.
+// ratio_h / ratio_w scale an output index to a source coordinate. When
+// align_mode == 0 and align_corners is false the source coordinate is
+// ratio * (idx + 0.5) - 0.5 clamped at 0; otherwise it is ratio * idx.
+void bilinear_interp(const float* input_data,
+                     float* output_data,
+                     const float ratio_h,
+                     const float ratio_w,
+                     const int in_h,
+                     const int in_w,
+                     const int n,
+                     const int c,
+                     const int out_h,
+                     const int out_w,
+                     const bool align_corners,
+                     const bool align_mode) {
+  bool align_flag = (align_mode == 0 && !align_corners);
+
+  // Per-output-row tables: the two source rows and their blend weights.
+  // They are filled through operator[], so they have to be sized with
+  // resize(); reserve() would leave size() == 0 and make the writes UB.
+  std::vector<int> vy_n, vy_s;
+  std::vector<float> vd_n, vd_s;
+  vy_n.resize(out_h);
+  vy_s.resize(out_h);
+  vd_n.resize(out_h);
+  vd_s.resize(out_h);
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+  for (int k = 0; k < out_h; k++) {
+    int y_n = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
+                         : static_cast<int>(ratio_h * k);
+    y_n = (y_n > 0) ? y_n : 0;
+    int y_s = (y_n + 1) < (in_h - 1) ? (y_n + 1) : (in_h - 1);
+    float idx_src_y = ratio_h * (k + 0.5) - 0.5;
+    idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
+    float d_n = align_flag ? idx_src_y - y_n : ratio_h * k - y_n;
+    float d_s = 1.f - d_n;
+    {
+      vy_n[k] = y_n;
+      vy_s[k] = y_s;
+      vd_n[k] = d_n;
+      vd_s[k] = d_s;
+    }
+  }
+
+  // Per-output-column tables, built the same way as the row tables.
+  std::vector<int> vx_w, vx_e;
+  std::vector<float> vd_w, vd_e;
+  vx_w.resize(out_w);
+  vx_e.resize(out_w);
+  vd_w.resize(out_w);
+  vd_e.resize(out_w);
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+  for (int l = 0; l < out_w; l++) {
+    int x_w = align_flag ? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
+                         : static_cast<int>(ratio_w * l);
+    x_w = (x_w > 0) ? x_w : 0;
+    int x_e = (x_w + 1) < (in_w - 1) ? (x_w + 1) : (in_w - 1);
+    float idx_src_x = ratio_w * (l + 0.5) - 0.5;
+    idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
+    float d_w = align_flag ? idx_src_x - x_w : ratio_w * l - x_w;
+    float d_e = 1.f - d_w;
+    {
+      vx_w[l] = x_w;
+      vx_e[l] = x_e;
+      vd_w[l] = d_w;
+      vd_e[l] = d_e;
+    }
+  }
+
+  int total_count = n * c;
+
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for collapse(3)
+#endif
+  for (int i = 0; i < total_count; i++) {
+    for (int h = 0; h < out_h; h++) {
+      for (int w = 0; w < out_w; w++) {
+        // bilinear interpolation
+        const float* input_data_ptr = input_data + i * in_h * in_w;
+        float* output_data_ptr =
+            output_data + i * out_h * out_w + h * out_w + w;
+        *output_data_ptr =
+            input_data_ptr[vy_n[h] * in_w + vx_w[w]] * vd_s[h] * vd_e[w] +
+            input_data_ptr[vy_s[h] * in_w + vx_w[w]] * vd_n[h] * vd_e[w] +
+            input_data_ptr[vy_n[h] * in_w + vx_e[w]] * vd_s[h] * vd_w[w] +
+            input_data_ptr[vy_s[h] * in_w + vx_e[w]] * vd_n[h] * vd_w[w];
+      }
+    }
+  }
+}
+
+// Nearest-neighbour resize of n * c planes from in_h x in_w to
+// out_h x out_w. The source index is ratio * idx truncated to int; with
+// align_corners it is rounded instead (0.5 is added before truncation).
+void nearest_interp(const float* input_data,
+                    float* output_data,
+                    const float ratio_h,
+                    const float ratio_w,
+                    const int n,
+                    const int c,
+                    const int in_h,
+                    const int in_w,
+                    const int out_h,
+                    const int out_w,
+                    const bool align_corners) {
+  int total_count = n * c;
+  if (align_corners) {
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for collapse(3)
+#endif
+    for (int i = 0; i < total_count; ++i) {
+      for (int h = 0; h < out_h; ++h) {
+        for (int w = 0; w < out_w; ++w) {
+          const float* input_data_ptr = input_data + i * in_h * in_w;
+          float* output_data_ptr =
+              output_data + i * out_h * out_w + h * out_w + w;
+          int near_y = static_cast<int>(ratio_h * h + 0.5);
+          int near_x = static_cast<int>(ratio_w * w + 0.5);
+          *output_data_ptr = input_data_ptr[near_y * in_w + near_x];
+        }
+      }
+    }
+  } else {
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for collapse(3)
+#endif
+    for (int i = 0; i < total_count; ++i) {
+      for (int h = 0; h < out_h; ++h) {
+        for (int w = 0; w < out_w; ++w) {
+          const float* input_data_ptr = input_data + i * in_h * in_w;
+          float* output_data_ptr =
+              output_data + i * out_h * out_w + h * out_w + w;
+          int near_y = static_cast<int>(ratio_h * h);
+          int near_x = static_cast<int>(ratio_w * w);
+          *output_data_ptr = input_data_ptr[near_y * in_w + near_x];
+        }
+      }
+    }
+  }
+}
+
+// Reads the first int32 element of every tensor in the list and returns
+// the values in list order.
+inline std::vector<int> get_new_shape(
+    const std::vector<const lite::Tensor*>& list_new_shape_tensor) {
+  std::vector<int> vec_new_shape;
+  vec_new_shape.reserve(list_new_shape_tensor.size());
+  for (const auto* tensor : list_new_shape_tensor) {
+    vec_new_shape.push_back(static_cast<int32_t>(*tensor->data<int32_t>()));
+  }
+  return vec_new_shape;
+}
+
+// Copies all dims().production() elements of the tensor into a vector.
+template <typename T>
+inline std::vector<T> get_new_data_from_tensor(const Tensor* new_data_tensor) {
+  const auto* new_data = new_data_tensor->data<T>();
+  return std::vector<T>(new_data,
+                        new_data + new_data_tensor->dims().production());
+}
+
+// Resolves the output size, resizes `output` to {n, c, out_h, out_w} and
+// dispatches to bilinear_interp ("Bilinear") or nearest_interp
+// ("Nearest"); any other interpolate_type is a fatal error. Size
+// precedence: list_new_size_tensor, else out_size, else scale (where
+// scale_tensor overrides the scale attribute), else the out_h / out_w
+// arguments. `input` is read as NCHW.
+void interpolate(lite::Tensor* input,
+                 lite::Tensor* out_size,
+                 std::vector<const lite::Tensor*> list_new_size_tensor,
+                 lite::Tensor* scale_tensor,
+                 lite::Tensor* output,
+                 float scale,
+                 int out_h,
+                 int out_w,
+                 const int align_mode,
+                 const bool align_corners,
+                 const std::string interpolate_type) {
+  // format NCHW
+  int n = input->dims()[0];
+  int c = input->dims()[1];
+  int in_h = input->dims()[2];
+  int in_w = input->dims()[3];
+  if (list_new_size_tensor.size() > 0) {
+    // have size tensor
+    auto new_size = get_new_shape(list_new_size_tensor);
+    out_h = new_size[0];
+    out_w = new_size[1];
+  } else {
+    if (scale_tensor != nullptr) {
+      auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
+      scale = scale_data[0];
+    }
+    if (scale > 0) {
+      out_h = static_cast<int>(in_h * scale);
+      out_w = static_cast<int>(in_w * scale);
+    }
+    if (out_size != nullptr) {
+      auto out_size_data = get_new_data_from_tensor<int>(out_size);
+      out_h = out_size_data[0];
+      out_w = out_size_data[1];
+    }
+  }
+  output->Resize({n, c, out_h, out_w});
+
+  float ratio_h = 0.f;
+  float ratio_w = 0.f;
+  if (out_h > 1) {
+    ratio_h = (align_corners) ? static_cast<float>(in_h - 1) / (out_h - 1)
+                              : static_cast<float>(in_h) / out_h;
+  }
+  if (out_w > 1) {
+    ratio_w = (align_corners) ? static_cast<float>(in_w - 1) / (out_w - 1)
+                              : static_cast<float>(in_w) / out_w;
+  }
+
+  const float* input_data = input->data<float>();
+  float* output_data = output->mutable_data<float>();
+  if ("Bilinear" == interpolate_type) {
+    bilinear_interp(input_data,
+                    output_data,
+                    ratio_h,
+                    ratio_w,
+                    in_h,
+                    in_w,
+                    n,
+                    c,
+                    out_h,
+                    out_w,
+                    align_corners,
+                    align_mode);
+  } else if ("Nearest" == interpolate_type) {
+    nearest_interp(input_data,
+                   output_data,
+                   ratio_h,
+                   ratio_w,
+                   n,
+                   c,
+                   in_h,
+                   in_w,
+                   out_h,
+                   out_w,
+                   align_corners);
+  } else {
+    LOG(FATAL) << "Not supported interpolate_type: " << interpolate_type;
+  }
+}
+
+}  // namespace math
+}  // namespace x86
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/x86/math/interpolate.h b/lite/backends/x86/math/interpolate.h
new file mode 100644
index 0000000000000000000000000000000000000000..d92fea958bb6f96e41d7110e98a344532c3be354
--- /dev/null
+++ b/lite/backends/x86/math/interpolate.h
@@ -0,0 +1,70 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+#include "lite/core/tensor.h"
+
+namespace paddle {
+namespace lite {
+namespace x86 {
+namespace math {
+
+// Bilinear resize of n * c float planes from in_h x in_w to out_h x out_w.
+void bilinear_interp(const float* input_data,
+                     float* output_data,
+                     const float ratio_h,
+                     const float ratio_w,
+                     const int in_h,
+                     const int in_w,
+                     const int n,
+                     const int c,
+                     const int out_h,
+                     const int out_w,
+                     const bool align_corners,
+                     const bool align_mode);
+
+// Nearest-neighbour resize of n * c float planes from in_h x in_w to
+// out_h x out_w.
+void nearest_interp(const float* input_data,
+                    float* output_data,
+                    const float ratio_h,
+                    const float ratio_w,
+                    const int n,
+                    const int c,
+                    const int in_h,
+                    const int in_w,
+                    const int out_h,
+                    const int out_w,
+                    const bool align_corners);
+
+// Resolves the output size, resizes `output` and runs the "Bilinear" or
+// "Nearest" routine selected by interpolate_type.
+void interpolate(lite::Tensor* input,
+                 lite::Tensor* out_size,
+                 std::vector<const lite::Tensor*> list_new_size_tensor,
+                 lite::Tensor* scale_tensor,
+                 lite::Tensor* output,
+                 float scale,
+                 int out_h,
+                 int out_w,
+                 const int align_mode,
+                 const bool align_corners,
+                 const std::string interpolate_type);
+
+}  // namespace math
+}  // namespace x86
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt
index c98f789911fde831a843a5845953f0b863d118f1..2836890178ba9c506e1a4962e82d7d696af8fe7b 100644
--- a/lite/kernels/x86/CMakeLists.txt
+++ b/lite/kernels/x86/CMakeLists.txt
@@ -70,6 +70,7 @@ add_kernel(search_fc_compute_x86 X86 basic SRCS search_fc_compute.cc DEPS ${lite
 add_kernel(matmul_compute_x86 X86 basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps} blas)
 add_kernel(box_coder_compute_x86 X86 basic SRCS box_coder_compute.cc DEPS ${lite_kernel_deps} box_coder)
 add_kernel(density_prior_box_compute_x86 X86 basic SRCS density_prior_box_compute.cc DEPS ${lite_kernel_deps} prior_box)
+add_kernel(interpolate_compute_x86 X86 basic SRCS interpolate_compute.cc DEPS ${lite_kernel_deps} interpolate)
 
 lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86)
 lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86)
diff --git a/lite/kernels/x86/interpolate_compute.cc b/lite/kernels/x86/interpolate_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bf799114d49aba639a2504f6f626b93f8fe305e6
--- /dev/null
+++ b/lite/kernels/x86/interpolate_compute.cc
@@ -0,0 +1,122 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/x86/interpolate_compute.h"
+#include <string>
+#include <vector>
+#include "lite/backends/x86/math/interpolate.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+// Forwards the kernel parameters to math::interpolate as "Bilinear".
+void BilinearInterpCompute::Run() {
+  auto& param = Param<operators::InterpolateParam>();
+  // required input
+  lite::Tensor* X = param.X;
+  // optional inputs
+  lite::Tensor* OutSize = param.OutSize;
+  auto SizeTensor = param.SizeTensor;
+  auto Scale = param.Scale;
+  // output
+  lite::Tensor* Out = param.Out;
+  // optional attributes
+  float scale = param.scale;
+  int out_w = param.out_w;
+  int out_h = param.out_h;
+  int align_mode = param.align_mode;
+  // required attributes
+  bool align_corners = param.align_corners;
+  std::string interp_method = "Bilinear";
+  lite::x86::math::interpolate(X,
+                               OutSize,
+                               SizeTensor,
+                               Scale,
+                               Out,
+                               scale,
+                               out_h,
+                               out_w,
+                               align_mode,
+                               align_corners,
+                               interp_method);
+}
+
+// Forwards the kernel parameters to math::interpolate as "Nearest".
+void NearestInterpCompute::Run() {
+  auto& param = Param<operators::InterpolateParam>();
+  // required input
+  lite::Tensor* X = param.X;
+  // optional inputs
+  lite::Tensor* OutSize = param.OutSize;
+  auto SizeTensor = param.SizeTensor;
+  auto Scale = param.Scale;
+  // output
+  lite::Tensor* Out = param.Out;
+  // optional attributes
+  float scale = param.scale;
+  int out_w = param.out_w;
+  int out_h = param.out_h;
+  int align_mode = param.align_mode;
+  // required attributes
+  bool align_corners = param.align_corners;
+  std::string interp_method = "Nearest";
+  lite::x86::math::interpolate(X,
+                               OutSize,
+                               SizeTensor,
+                               Scale,
+                               Out,
+                               scale,
+                               out_h,
+                               out_w,
+                               align_mode,
+                               align_corners,
+                               interp_method);
+}
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(bilinear_interp,
+                     kX86,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::x86::BilinearInterpCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("OutSize",
+               {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
+    .BindInput("SizeTensor",
+               {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
+    .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
+    .Finalize();
+
+REGISTER_LITE_KERNEL(nearest_interp,
+                     kX86,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::x86::NearestInterpCompute,
+                     def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("OutSize",
+               {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
+    .BindInput("SizeTensor",
+               {LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
+    .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
+    .Finalize();
diff --git a/lite/kernels/x86/interpolate_compute.h b/lite/kernels/x86/interpolate_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..8efb9f97c9f9e28712f44bbe5ca80a6053d884e0
--- /dev/null
+++ b/lite/kernels/x86/interpolate_compute.h
@@ -0,0 +1,45 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+// X86 float kernel registered for the bilinear_interp op.
+class BilinearInterpCompute
+    : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  void Run() override;
+
+  virtual ~BilinearInterpCompute() = default;
+};
+
+// X86 float kernel registered for the nearest_interp op.
+class NearestInterpCompute
+    : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  void Run() override;
+
+  virtual ~NearestInterpCompute() = default;
+};
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/tests/kernels/interp_compute_test.cc b/lite/tests/kernels/interp_compute_test.cc
index f512808632f3d99153c1ca93c94e3edc679b9c96..a76eba928dae674520746c6a6c783ae7cf769ccf 100644
--- a/lite/tests/kernels/interp_compute_test.cc
+++ b/lite/tests/kernels/interp_compute_test.cc
@@ -453,6 +453,8 @@ TEST(Interp, precision) {
   abs_error = 1e-2;  // precision_mode default is force_fp16
 #elif defined(LITE_WITH_ARM)
   place = TARGET(kARM);
+#elif defined(LITE_WITH_X86)
+  place = TARGET(kX86);
 #else
   return;
 #endif