diff --git a/lite/api/cxx_api_impl.cc b/lite/api/cxx_api_impl.cc
index a41c1d0a30508536a77d6013ae401a4beedbd9a5..0b5b9ad94c47a3d97492cd5b91618b184c9ef122 100644
--- a/lite/api/cxx_api_impl.cc
+++ b/lite/api/cxx_api_impl.cc
@@ -96,7 +96,7 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
                                       config.subgraph_model_cache_dir());
 #endif
 #if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \
-    !(defined LITE_ON_MODEL_OPTIMIZE_TOOL)
+    !(defined LITE_ON_MODEL_OPTIMIZE_TOOL) && !defined(__APPLE__)
   int num_threads = config.x86_math_library_num_threads();
   int real_num_threads = num_threads > 1 ? num_threads : 1;
   paddle::lite::x86::MKL_Set_Num_Threads(real_num_threads);
diff --git a/lite/backends/x86/math/CMakeLists.txt b/lite/backends/x86/math/CMakeLists.txt
index a89107632341cf063ac3166aa9890ff383e3383f..b5262efa4e8ca3fbfa3076fb9a5eb6fe1993ccb2 100644
--- a/lite/backends/x86/math/CMakeLists.txt
+++ b/lite/backends/x86/math/CMakeLists.txt
@@ -61,3 +61,5 @@ math_library(search_fc DEPS blas dynload_mklml)
 # cc_test(beam_search_test SRCS beam_search_test.cc DEPS beam_search)
 # cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
 # cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
+math_library(box_coder DEPS math_function)
+math_library(prior_box DEPS math_function)
diff --git a/lite/backends/x86/math/box_coder.cc b/lite/backends/x86/math/box_coder.cc
new file mode 100644
index 0000000000000000000000000000000000000000..efe3c14fdad1ab529262731316c048e4238cd223
--- /dev/null
+++ b/lite/backends/x86/math/box_coder.cc
@@ -0,0 +1,166 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
*/ + +#include "lite/backends/x86/math/box_coder.h" +#include + +namespace paddle { +namespace lite { +namespace x86 { +namespace math { + +void encode_center_size(const int64_t row, // N + const int64_t col, // M + const int64_t len, // 4 + const float* target_box_data, + const float* prior_box_data, + const float* prior_box_var_data, + const bool normalized, + const std::vector variance, + float* output) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for collapse(2) +#endif + for (int64_t i = 0; i < row; ++i) { + for (int64_t j = 0; j < col; ++j) { + size_t offset = i * col * len + j * len; + float prior_box_width = prior_box_data[j * len + 2] - + prior_box_data[j * len] + (normalized == false); + float prior_box_height = prior_box_data[j * len + 3] - + prior_box_data[j * len + 1] + + (normalized == false); + float prior_box_center_x = prior_box_data[j * len] + prior_box_width / 2; + float prior_box_center_y = + prior_box_data[j * len + 1] + prior_box_height / 2; + + float target_box_center_x = + (target_box_data[i * len + 2] + target_box_data[i * len]) / 2; + float target_box_center_y = + (target_box_data[i * len + 3] + target_box_data[i * len + 1]) / 2; + float target_box_width = target_box_data[i * len + 2] - + target_box_data[i * len] + (normalized == false); + float target_box_height = target_box_data[i * len + 3] - + target_box_data[i * len + 1] + + (normalized == false); + + output[offset] = + (target_box_center_x - prior_box_center_x) / prior_box_width; + output[offset + 1] = + (target_box_center_y - prior_box_center_y) / prior_box_height; + output[offset + 2] = + std::log(std::fabs(target_box_width / prior_box_width)); + output[offset + 3] = + std::log(std::fabs(target_box_height / prior_box_height)); + } + } + + if (prior_box_var_data) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for collapse(3) +#endif + for (int64_t i = 0; i < row; ++i) { + for (int64_t j = 0; j < col; ++j) { + for (int64_t k = 0; k < len; ++k) { + size_t offset = i * col * len + j * len; + int prior_var_offset = j * len; + output[offset + k] /= prior_box_var_data[prior_var_offset + k]; + } + } + } + } else if (!(variance.empty())) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for collapse(3) +#endif + for (int64_t i = 0; i < row; ++i) { + for (int64_t j = 0; j < col; ++j) { + for (int64_t k = 0; k < len; ++k) { + size_t offset = i * col * len + j * len; + output[offset + k] /= variance[k]; + } + } + } + } +} + +void decode_center_size(const int axis, + const int var_size, + const int64_t row, + const int64_t col, + const int64_t len, + const float* target_box_data, + const float* prior_box_data, + const float* prior_box_var_data, + const bool normalized, + const std::vector variance, + float* output) { +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for collapse(2) +#endif + for (int64_t i = 0; i < row; ++i) { + for (int64_t j = 0; j < col; ++j) { + float var_data[4] = {1., 1., 1., 1.}; + float* var_ptr = var_data; + size_t offset = i * col * len + j * len; + int prior_box_offset = axis == 0 ? 
+
+      float prior_box_width = prior_box_data[prior_box_offset + 2] -
+                              prior_box_data[prior_box_offset] +
+                              (normalized == false);
+      float prior_box_height = prior_box_data[prior_box_offset + 3] -
+                               prior_box_data[prior_box_offset + 1] +
+                               (normalized == false);
+      float prior_box_center_x =
+          prior_box_data[prior_box_offset] + prior_box_width / 2;
+      float prior_box_center_y =
+          prior_box_data[prior_box_offset + 1] + prior_box_height / 2;
+
+      float target_box_center_x = 0, target_box_center_y = 0;
+      float target_box_width = 0, target_box_height = 0;
+      int prior_var_offset = axis == 0 ? j * len : i * len;
+      if (var_size == 2) {
+        std::memcpy(
+            var_ptr, prior_box_var_data + prior_var_offset, 4 * sizeof(float));
+      } else if (var_size == 1) {
+        var_ptr = const_cast<float*>(variance.data());
+      }
+      float box_var_x = *var_ptr;
+      float box_var_y = *(var_ptr + 1);
+      float box_var_w = *(var_ptr + 2);
+      float box_var_h = *(var_ptr + 3);
+
+      target_box_center_x =
+          box_var_x * target_box_data[offset] * prior_box_width +
+          prior_box_center_x;
+      target_box_center_y =
+          box_var_y * target_box_data[offset + 1] * prior_box_height +
+          prior_box_center_y;
+      target_box_width =
+          std::exp(box_var_w * target_box_data[offset + 2]) * prior_box_width;
+      target_box_height =
+          std::exp(box_var_h * target_box_data[offset + 3]) * prior_box_height;
+
+      output[offset] = target_box_center_x - target_box_width / 2;
+      output[offset + 1] = target_box_center_y - target_box_height / 2;
+      output[offset + 2] =
+          target_box_center_x + target_box_width / 2 - (normalized == false);
+      output[offset + 3] =
+          target_box_center_y + target_box_height / 2 - (normalized == false);
+    }
+  }
+}
+
+}  // namespace math
+}  // namespace x86
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/x86/math/box_coder.h b/lite/backends/x86/math/box_coder.h
new file mode 100644
index 0000000000000000000000000000000000000000..fc31f888ab7ed281533e187ca8b51344f150662a
--- /dev/null
+++ b/lite/backends/x86/math/box_coder.h
@@ -0,0 +1,50 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <vector>
+#include "lite/backends/x86/math/math_function.h"
+
+namespace paddle {
+namespace lite {
+namespace x86 {
+namespace math {
+
+void encode_center_size(const int64_t row,
+                        const int64_t col,
+                        const int64_t len,
+                        const float* target_box_data,
+                        const float* prior_box_data,
+                        const float* prior_box_var_data,
+                        const bool normalized,
+                        const std::vector<float> variance,
+                        float* output);
+
+void decode_center_size(const int axis,
+                        const int var_size,
+                        const int64_t row,
+                        const int64_t col,
+                        const int64_t len,
+                        const float* target_box_data,
+                        const float* prior_box_data,
+                        const float* prior_box_var_data,
+                        const bool normalized,
+                        const std::vector<float> variance,
+                        float* output);
+
+}  // namespace math
+}  // namespace x86
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/x86/math/prior_box.cc b/lite/backends/x86/math/prior_box.cc
new file mode 100644
index 0000000000000000000000000000000000000000..159838895ad8145e4db81f5f3701ec8ddb2611a4
--- /dev/null
+++ b/lite/backends/x86/math/prior_box.cc
@@ -0,0 +1,118 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "lite/backends/x86/math/prior_box.h"
+#include <algorithm>
+#include <cmath>
+
+namespace paddle {
+namespace lite {
+namespace x86 {
+namespace math {
+
+void density_prior_box(const int64_t img_width,
+                       const int64_t img_height,
+                       const int64_t feature_width,
+                       const int64_t feature_height,
+                       const float* input_data,
+                       const float* image_data,
+                       const bool clip,
+                       const std::vector<float> variances,
+                       const std::vector<float> fixed_sizes,
+                       const std::vector<float> fixed_ratios,
+                       const std::vector<int> densities,
+                       const float step_width,
+                       const float step_height,
+                       const float offset,
+                       const int num_priors,
+                       float* boxes_data,
+                       float* vars_data) {
+  int step_average = static_cast<int>((step_width + step_height) * 0.5);
+
+  // Write by index so the (optionally OpenMP-parallel) loop below is race-free.
+  std::vector<float> sqrt_fixed_ratios(fixed_ratios.size());
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+  for (size_t i = 0; i < fixed_ratios.size(); i++) {
+    sqrt_fixed_ratios[i] = sqrt(fixed_ratios[i]);
+  }
+
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for collapse(2)
+#endif
+  for (int64_t h = 0; h < feature_height; ++h) {
+    for (int64_t w = 0; w < feature_width; ++w) {
+      float center_x = (w + offset) * step_width;
+      float center_y = (h + offset) * step_height;
+      int64_t offset = (h * feature_width + w) * num_priors * 4;
+      // Generate density prior boxes with fixed sizes.
+      for (size_t s = 0; s < fixed_sizes.size(); ++s) {
+        auto fixed_size = fixed_sizes[s];
+        int density = densities[s];
+        int shift = step_average / density;
+        // Generate density prior boxes with fixed ratios.
+        for (size_t r = 0; r < fixed_ratios.size(); ++r) {
+          float box_width_ratio = fixed_size * sqrt_fixed_ratios[r];
+          float box_height_ratio = fixed_size / sqrt_fixed_ratios[r];
+          float density_center_x = center_x - step_average / 2. + shift / 2.;
+          float density_center_y = center_y - step_average / 2. + shift / 2.;
+          for (int di = 0; di < density; ++di) {
+            for (int dj = 0; dj < density; ++dj) {
+              float center_x_temp = density_center_x + dj * shift;
+              float center_y_temp = density_center_y + di * shift;
+              boxes_data[offset++] = std::max(
+                  (center_x_temp - box_width_ratio / 2.) / img_width, 0.);
+              boxes_data[offset++] = std::max(
+                  (center_y_temp - box_height_ratio / 2.) / img_height, 0.);
+              boxes_data[offset++] = std::min(
+                  (center_x_temp + box_width_ratio / 2.) / img_width, 1.);
+              boxes_data[offset++] = std::min(
+                  (center_y_temp + box_height_ratio / 2.) / img_height, 1.);
+            }
+          }
+        }
+      }
+    }
+  }
+  //! clip the prior's coordinate such that it is within [0, 1]
+  if (clip) {
+    int channel_size = feature_height * feature_width * num_priors * 4;
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+    for (int d = 0; d < channel_size; ++d) {
+      boxes_data[d] = std::min(std::max(boxes_data[d], 0.f), 1.f);
+    }
+  }
+//! set the variance.
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for collapse(3)
+#endif
+  for (int h = 0; h < feature_height; ++h) {
+    for (int w = 0; w < feature_width; ++w) {
+      for (int i = 0; i < num_priors; ++i) {
+        int idx = ((h * feature_width + w) * num_priors + i) * 4;
+        vars_data[idx++] = variances[0];
+        vars_data[idx++] = variances[1];
+        vars_data[idx++] = variances[2];
+        vars_data[idx++] = variances[3];
+      }
+    }
+  }
+}
+
+}  // namespace math
+}  // namespace x86
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/backends/x86/math/prior_box.h b/lite/backends/x86/math/prior_box.h
new file mode 100644
index 0000000000000000000000000000000000000000..6b090551a014a8019e38f5fdcede38b86bfab720
--- /dev/null
+++ b/lite/backends/x86/math/prior_box.h
@@ -0,0 +1,46 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <vector>
+#include "lite/backends/x86/math/math_function.h"
+
+namespace paddle {
+namespace lite {
+namespace x86 {
+namespace math {
+
+void density_prior_box(const int64_t img_width,
+                       const int64_t img_height,
+                       const int64_t feature_width,
+                       const int64_t feature_height,
+                       const float* input_data,
+                       const float* image_data,
+                       const bool clip,
+                       const std::vector<float> variances,
+                       const std::vector<float> fixed_sizes,
+                       const std::vector<float> fixed_ratios,
+                       const std::vector<int> densities,
+                       const float step_width,
+                       const float step_height,
+                       const float offset,
+                       const int num_priors,
+                       float* boxes_data,
+                       float* vars_data);
+
+}  // namespace math
+}  // namespace x86
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
index da42d6d0c79a2a7975eacca7095fedababac6d89..4840a625c7551e96fa5f3ae03585bedf9a85c303 100644
--- a/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
+++ b/lite/core/mir/fusion/quant_dequant_fuse_pass.cc
@@ -61,5 +61,4 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 
 REGISTER_MIR_PASS(lite_quant_dequant_fuse_pass,
                   paddle::lite::mir::QuantDequantFusePass)
-    .BindTargets({TARGET(kAny)})
-    .BindKernel("calib");
+    .BindTargets({TARGET(kAny)});
diff --git a/lite/kernels/x86/CMakeLists.txt b/lite/kernels/x86/CMakeLists.txt
index 521fbb6b24dccb27ce79369ddc631097434105e5..c98f789911fde831a843a5845953f0b863d118f1 100644
--- a/lite/kernels/x86/CMakeLists.txt
+++ b/lite/kernels/x86/CMakeLists.txt
@@ -68,6 +68,8 @@ add_kernel(sequence_topk_avg_pooling_compute_x86 X86 basic SRCS sequence_topk_av
 add_kernel(search_fc_compute_x86 X86 basic SRCS search_fc_compute.cc DEPS ${lite_kernel_deps} search_fc)
 add_kernel(matmul_compute_x86 X86 basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps} blas)
 
+add_kernel(box_coder_compute_x86 X86 basic SRCS box_coder_compute.cc DEPS ${lite_kernel_deps} box_coder)
+add_kernel(density_prior_box_compute_x86 X86 basic SRCS density_prior_box_compute.cc DEPS ${lite_kernel_deps} prior_box)
 
 lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86)
 lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86)
diff --git a/lite/kernels/x86/box_coder_compute.cc b/lite/kernels/x86/box_coder_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..db58bf01cbc3f0ca5ea1fa20fbc205a4076eafa8
--- /dev/null
+++ b/lite/kernels/x86/box_coder_compute.cc
@@ -0,0 +1,104 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "lite/kernels/x86/box_coder_compute.h" +#include +#include +#include "lite/backends/x86/math/box_coder.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +void BoxCoderCompute::Run() { + auto& param = *param_.get_mutable(); + // required inputs + auto* prior_box = param.prior_box; // M x 4 => M x [xmin, ymin, xmax, ymax] + auto* target_box = param.target_box; // encode_center_size => N x 4; + // decode_center_size => N x M x 4 + // optional input + auto* prior_box_var = param.prior_box_var; // M x 4 or 4 + // output + auto* output_box = param.proposals; // N x M x 4 + // required attributes + std::string code_type = param.code_type; + bool normalized = param.box_normalized; + // optional attributes + std::vector variance = param.variance; + const int axis = param.axis; + + auto row = target_box->dims()[0]; // N + auto col = prior_box->dims()[0]; // M + if (code_type == "decode_center_size") { // same as target_box + col = target_box->dims()[1]; + } + auto len = prior_box->dims()[1]; // 4 + output_box->Resize({row, col, len}); // N x M x 4 + auto* output = output_box->mutable_data(); + + const float* target_box_data = target_box->data(); + const float* prior_box_data = prior_box->data(); + const float* prior_box_var_data = + prior_box_var ? prior_box_var->data() : nullptr; + + if (code_type == "encode_center_size") { + lite::x86::math::encode_center_size(row, + col, + len, + target_box_data, + prior_box_data, + prior_box_var_data, + normalized, + variance, + output); + } else if (code_type == "decode_center_size") { + int var_size = 0; + if (prior_box_var) { + var_size = 2; + } else if (!(variance.empty())) { + var_size = 1; + } + lite::x86::math::decode_center_size(axis, + var_size, + row, + col, + len, + target_box_data, + prior_box_data, + prior_box_var_data, + normalized, + variance, + output); + } else { + LOG(FATAL) << "box_coder don't support this code_type: " << code_type; + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(box_coder, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::BoxCoderCompute, + def) + .BindInput("PriorBox", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("PriorBoxVar", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindInput("TargetBox", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("OutputBox", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/box_coder_compute.h b/lite/kernels/x86/box_coder_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..34c655bf4bacf537415758651d8c0d8182b047f4 --- /dev/null +++ b/lite/kernels/x86/box_coder_compute.h @@ -0,0 +1,36 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+class BoxCoderCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::BoxCoderParam;
+
+  void Run() override;
+
+  virtual ~BoxCoderCompute() = default;
+};
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/kernels/x86/density_prior_box_compute.cc b/lite/kernels/x86/density_prior_box_compute.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1f76e20bbf68c1dc6ff10fb15e2c125629dbe83c
--- /dev/null
+++ b/lite/kernels/x86/density_prior_box_compute.cc
@@ -0,0 +1,109 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/x86/density_prior_box_compute.h"
+#include <cmath>
+#include <vector>
+#include "lite/backends/x86/math/prior_box.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+void DensityPriorBoxCompute::Run() {
+  auto& param = *param_.get_mutable<param_t>();
+  // required inputs
+  auto* input = param.input;  // 4D tensor NCHW
+  auto* image = param.image;  // 4D tensor NCHW
+  // outputs
+  auto* boxes = param.boxes;     // [H, W, num_priors, 4]
+  auto* vars = param.variances;  // [H, W, num_priors, 4]
+  // required attributes
+  bool clip = param.clip;
+  std::vector<float> variances = param.variances_;
+  std::vector<float> fixed_sizes = param.fixed_sizes;
+  std::vector<float> fixed_ratios = param.fixed_ratios;
+  std::vector<int> densities = param.density_sizes;
+  // optional attributes
+  float step_w = param.step_w;
+  float step_h = param.step_h;
+  float offset = param.offset;
+
+  auto img_width = image->dims()[3];
+  auto img_height = image->dims()[2];
+
+  auto feature_width = input->dims()[3];
+  auto feature_height = input->dims()[2];
+
+  float step_width, step_height;
+  if (step_w == 0 || step_h == 0) {
+    step_width = static_cast<float>(img_width) / feature_width;
+    step_height = static_cast<float>(img_height) / feature_height;
+  } else {
+    step_width = step_w;
+    step_height = step_h;
+  }
+  int num_priors = 0;
+
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for reduction(+ : num_priors)
+#endif
+  for (size_t i = 0; i < densities.size(); ++i) {
+    num_priors += (fixed_ratios.size()) * (pow(densities[i], 2));
+  }
+
+  boxes->Resize({feature_height, feature_width, num_priors, 4});
+  vars->Resize({feature_height, feature_width, num_priors, 4});
+  auto* boxes_data = boxes->mutable_data<float>();
+  auto* vars_data = vars->mutable_data<float>();
+
+  const float* input_data = input->data<float>();
+  const float* image_data = image->data<float>();
+
+  lite::x86::math::density_prior_box(img_width,
+                                     img_height,
+                                     feature_width,
+                                     feature_height,
+                                     input_data,
+                                     image_data,
+                                     clip,
+                                     variances,
+                                     fixed_sizes,
+                                     fixed_ratios,
+                                     densities,
+                                     step_width,
+                                     step_height,
+                                     offset,
+                                     num_priors,
+                                     boxes_data,
+                                     vars_data);
+}
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_KERNEL(density_prior_box,
+                     kX86,
+                     kFloat,
+                     kNCHW,
+                     paddle::lite::kernels::x86::DensityPriorBoxCompute,
+                     def)
+    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("Image", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Boxes", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindOutput("Variances", {LiteType::GetTensorTy(TARGET(kX86))})
+    .Finalize();
diff --git a/lite/kernels/x86/density_prior_box_compute.h b/lite/kernels/x86/density_prior_box_compute.h
new file mode 100644
index 0000000000000000000000000000000000000000..715f0aa99a1e54d79b5bbde6562f4860066f5889
--- /dev/null
+++ b/lite/kernels/x86/density_prior_box_compute.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "lite/core/kernel.h"
+#include "lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace x86 {
+
+class DensityPriorBoxCompute
+    : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
+ public:
+  using param_t = operators::DensityPriorBoxParam;
+
+  void Run() override;
+
+  virtual ~DensityPriorBoxCompute() = default;
+};
+
+}  // namespace x86
+}  // namespace kernels
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/tests/kernels/box_coder_compute_test.cc b/lite/tests/kernels/box_coder_compute_test.cc
index 9a833db31db7a6a53a4d29ed208b67e5dc77af12..f59b9dd34f761294c8350df0346f48e52130d2c0 100644
--- a/lite/tests/kernels/box_coder_compute_test.cc
+++ b/lite/tests/kernels/box_coder_compute_test.cc
@@ -195,6 +195,7 @@ void test_box_coder(Place place) {
 TEST(BoxCoder, precision) {
 #ifdef LITE_WITH_X86
   Place place(TARGET(kX86));
+  test_box_coder(place);
 #endif
 #ifdef LITE_WITH_ARM
   Place place(TARGET(kARM));
diff --git a/lite/tests/kernels/prior_box_compute_test.cc b/lite/tests/kernels/prior_box_compute_test.cc
index 73fd612c3a03c0a15ddaf3ce6c08ff0ed1a5a95b..121ed8eefe64596fafe0f32861a1c00a5f652995 100644
--- a/lite/tests/kernels/prior_box_compute_test.cc
+++ b/lite/tests/kernels/prior_box_compute_test.cc
@@ -740,6 +740,7 @@ TEST(PriorBox, precision) {
 TEST(DensityPriorBox, precision) {
 #ifdef LITE_WITH_X86
   Place place(TARGET(kX86));
+  test_density_prior_box(place);
 #endif
 #ifdef LITE_WITH_ARM
   Place place(TARGET(kARM));