未验证 提交 0108c64e 编写于 作者: Q Qi Li 提交者: GitHub

[X86] add box_coder and density_prior_box kernel and fix compile on MAC, test=develop (#4353)

* [X86] add box_coder kernel and fix compile on MAC, test=develop

* [DOC] remove debug info, test=develop

* [X86] add new op of density_prior_box, test=develop
上级 678a7c85
......@@ -96,7 +96,7 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
config.subgraph_model_cache_dir());
#endif
#if (defined LITE_WITH_X86) && (defined PADDLE_WITH_MKLML) && \
!(defined LITE_ON_MODEL_OPTIMIZE_TOOL)
!(defined LITE_ON_MODEL_OPTIMIZE_TOOL) && !defined(__APPLE__)
int num_threads = config.x86_math_library_num_threads();
int real_num_threads = num_threads > 1 ? num_threads : 1;
paddle::lite::x86::MKL_Set_Num_Threads(real_num_threads);
......
......@@ -61,3 +61,5 @@ math_library(search_fc DEPS blas dynload_mklml)
# cc_test(beam_search_test SRCS beam_search_test.cc DEPS beam_search)
# cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
# cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
# Math helper libraries used by the X86 box_coder and density_prior_box kernels;
# both depend only on math_function.
math_library(box_coder DEPS math_function)
math_library(prior_box DEPS math_function)
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "lite/backends/x86/math/box_coder.h"
#include <string>
namespace paddle {
namespace lite {
namespace x86 {
namespace math {
// Encodes each of the `row` target boxes against each of the `col` prior
// boxes in center-size form.
//
// Boxes are stored as `len` consecutive floats laid out as
// [xmin, ymin, xmax, ymax]. For target i and prior j the four results are
// written to output[(i * col + j) * len + 0..3]:
//   [0] (target_cx - prior_cx) / prior_w
//   [1] (target_cy - prior_cy) / prior_h
//   [2] log(|target_w / prior_w|)
//   [3] log(|target_h / prior_h|)
// Afterwards every element k is divided by prior_box_var_data[j * len + k]
// when that pointer is non-null, otherwise by variance[k] when `variance`
// is non-empty; with neither, the raw encoding is kept.
// When `normalized` is false, 1 is added to every box width and height.
void encode_center_size(const int64_t row,
                        const int64_t col,
                        const int64_t len,
                        const float* target_box_data,
                        const float* prior_box_data,
                        const float* prior_box_var_data,
                        const bool normalized,
                        const std::vector<float> variance,
                        float* output) {
  // Extra extent added to widths/heights of un-normalized boxes.
  const float pixel_pad = normalized ? 0.f : 1.f;

#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2)
#endif
  for (int64_t n = 0; n < row; ++n) {
    for (int64_t m = 0; m < col; ++m) {
      const float* prior = prior_box_data + m * len;
      const float* target = target_box_data + n * len;
      float* dst = output + (n * col * len + m * len);

      const float prior_w = prior[2] - prior[0] + pixel_pad;
      const float prior_h = prior[3] - prior[1] + pixel_pad;
      const float prior_cx = prior[0] + prior_w / 2;
      const float prior_cy = prior[1] + prior_h / 2;

      const float target_cx = (target[2] + target[0]) / 2;
      const float target_cy = (target[3] + target[1]) / 2;
      const float target_w = target[2] - target[0] + pixel_pad;
      const float target_h = target[3] - target[1] + pixel_pad;

      dst[0] = (target_cx - prior_cx) / prior_w;
      dst[1] = (target_cy - prior_cy) / prior_h;
      dst[2] = std::log(std::fabs(target_w / prior_w));
      dst[3] = std::log(std::fabs(target_h / prior_h));
    }
  }

  // Variance scaling: per-prior values take precedence over the shared vector.
  const bool per_prior_variance = (prior_box_var_data != nullptr);
  if (!per_prior_variance && variance.empty()) {
    return;
  }

#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(3)
#endif
  for (int64_t n = 0; n < row; ++n) {
    for (int64_t m = 0; m < col; ++m) {
      for (int64_t k = 0; k < len; ++k) {
        const float divisor = per_prior_variance
                                  ? prior_box_var_data[m * len + k]
                                  : variance[k];
        output[n * col * len + m * len + k] /= divisor;
      }
    }
  }
}
// Decodes center-size deltas back into corner-form boxes.
//
// `target_box_data` and `output` are both indexed as (i * col + j) * len and
// hold four values per box. The prior box (and, when used, its per-prior
// variance) is taken from row j of the prior data when `axis` is 0 and from
// row i otherwise.
//
// `var_size` selects the scale applied to the deltas:
//   2 -> four values read from prior_box_var_data at the prior's offset,
//   1 -> the first four entries of `variance`,
//   anything else -> 1.0 for all four components.
// When `normalized` is false, prior widths/heights get +1 and the decoded
// xmax/ymax get -1.
void decode_center_size(const int axis,
                        const int var_size,
                        const int64_t row,
                        const int64_t col,
                        const int64_t len,
                        const float* target_box_data,
                        const float* prior_box_data,
                        const float* prior_box_var_data,
                        const bool normalized,
                        const std::vector<float> variance,
                        float* output) {
  const float pixel_pad = normalized ? 0.f : 1.f;
  // Identity scale used when no variance source is selected.
  const float unit_scale[4] = {1.f, 1.f, 1.f, 1.f};

#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2)
#endif
  for (int64_t i = 0; i < row; ++i) {
    for (int64_t j = 0; j < col; ++j) {
      const int64_t box_at = i * col * len + j * len;
      const int64_t prior_at = (axis == 0 ? j : i) * len;

      const float* prior = prior_box_data + prior_at;
      const float prior_w = prior[2] - prior[0] + pixel_pad;
      const float prior_h = prior[3] - prior[1] + pixel_pad;
      const float prior_cx = prior[0] + prior_w / 2;
      const float prior_cy = prior[1] + prior_h / 2;

      const float* scale = unit_scale;
      if (var_size == 2) {
        scale = prior_box_var_data + prior_at;
      } else if (var_size == 1) {
        scale = variance.data();
      }

      // All deltas are consumed before anything is written back.
      const float* delta = target_box_data + box_at;
      const float center_x = scale[0] * delta[0] * prior_w + prior_cx;
      const float center_y = scale[1] * delta[1] * prior_h + prior_cy;
      const float width = std::exp(scale[2] * delta[2]) * prior_w;
      const float height = std::exp(scale[3] * delta[3]) * prior_h;

      float* dst = output + box_at;
      dst[0] = center_x - width / 2;
      dst[1] = center_y - height / 2;
      dst[2] = center_x + width / 2 - pixel_pad;
      dst[3] = center_y + height / 2 - pixel_pad;
    }
  }
}
} // namespace math
} // namespace x86
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "lite/backends/x86/math/math_function.h"
namespace paddle {
namespace lite {
namespace x86 {
namespace math {
// Encodes `row` target boxes against `col` prior boxes in center-size form.
//
// Each box occupies `len` consecutive floats, read as
// [xmin, ymin, xmax, ymax]. The result for target i and prior j is written to
// output[(i * col + j) * len + 0..3] as the center offsets divided by the
// prior width/height followed by log(|target_w / prior_w|) and
// log(|target_h / prior_h|).
//
// If `prior_box_var_data` is non-null every output element k is then divided
// by prior_box_var_data[j * len + k]; otherwise, if `variance` is non-empty,
// by variance[k]. When `normalized` is false, 1 is added to all box widths
// and heights.
void encode_center_size(const int64_t row,
const int64_t col,
const int64_t len,
const float* target_box_data,
const float* prior_box_data,
const float* prior_box_var_data,
const bool normalized,
const std::vector<float> variance,
float* output);
// Decodes center-size deltas in `target_box_data` into corner-form boxes.
//
// Input deltas and `output` are both indexed as (i * col + j) * len with four
// values per box. The prior box is read from row j of `prior_box_data` when
// `axis` is 0 and from row i otherwise.
//
// `var_size` selects the scale applied to the deltas: 2 reads four values from
// `prior_box_var_data` at the prior's offset, 1 uses the first four entries of
// `variance`, and any other value uses 1.0. When `normalized` is false, prior
// widths/heights get +1 and the decoded xmax/ymax get -1.
void decode_center_size(const int axis,
const int var_size,
const int64_t row,
const int64_t col,
const int64_t len,
const float* target_box_data,
const float* prior_box_data,
const float* prior_box_var_data,
const bool normalized,
const std::vector<float> variance,
float* output);
} // namespace math
} // namespace x86
} // namespace lite
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "lite/backends/x86/math/prior_box.h"
#include <algorithm>
#include <cmath>
#include <string>
namespace paddle {
namespace lite {
namespace x86 {
namespace math {
// Generates density prior boxes and their variances for every cell of a
// feature map.
//
// Parameters:
//   img_width, img_height          image size used to normalise coordinates.
//   feature_width, feature_height  size of the feature-map grid.
//   input_data, image_data         not read by this function.
//   clip                           when true, every box coordinate is clamped
//                                  to [0, 1] after generation.
//   variances                      at least four values; the first four are
//                                  copied to every prior in `vars_data`.
//   fixed_sizes, densities         paired per index s: box size and the
//                                  density x density grid of centres placed
//                                  inside each cell.
//   fixed_ratios                   aspect ratios; width is scaled by
//                                  sqrt(ratio) and height by 1 / sqrt(ratio).
//   step_width, step_height        cell size in image units.
//   offset                         fractional offset of the cell centre.
//   num_priors                     number of boxes per cell; fixes the layout
//                                  of both outputs.
//   boxes_data, vars_data          outputs laid out as
//                                  [feature_height, feature_width,
//                                   num_priors, 4].
//
// Boxes are written as [xmin, ymin, xmax, ymax] divided by the image size;
// xmin/ymin are floored at 0 and xmax/ymax capped at 1 regardless of `clip`.
void density_prior_box(const int64_t img_width,
                       const int64_t img_height,
                       const int64_t feature_width,
                       const int64_t feature_height,
                       const float* input_data,
                       const float* image_data,
                       const bool clip,
                       const std::vector<float> variances,
                       const std::vector<float> fixed_sizes,
                       const std::vector<float> fixed_ratios,
                       const std::vector<int> densities,
                       const float step_width,
                       const float step_height,
                       const float offset,
                       const int num_priors,
                       float* boxes_data,
                       float* vars_data) {
  const int step_average = static_cast<int>((step_width + step_height) * 0.5);

  // Pre-sized and filled by index in a serial loop. Appending with push_back
  // from inside an OpenMP parallel-for is a data race on the vector and would
  // also scramble the ratio order; the loop is far too short to be worth
  // parallelising anyway.
  std::vector<float> sqrt_fixed_ratios(fixed_ratios.size());
  for (size_t r = 0; r < fixed_ratios.size(); ++r) {
    sqrt_fixed_ratios[r] = std::sqrt(fixed_ratios[r]);
  }

  // fixed_sizes and densities are consumed pairwise; never index past the
  // shorter of the two.
  const size_t num_sizes = std::min(fixed_sizes.size(), densities.size());

#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2)
#endif
  for (int64_t h = 0; h < feature_height; ++h) {
    for (int64_t w = 0; w < feature_width; ++w) {
      const float center_x = (w + offset) * step_width;
      const float center_y = (h + offset) * step_height;
      // Write cursor for this cell; each (h, w) owns a disjoint slice.
      int64_t box_idx = (h * feature_width + w) * num_priors * 4;
      // Generate density prior boxes with fixed sizes.
      for (size_t s = 0; s < num_sizes; ++s) {
        const float fixed_size = fixed_sizes[s];
        const int density = densities[s];
        if (density <= 0) {
          // Produces no boxes; also avoids dividing by zero below.
          continue;
        }
        const int shift = step_average / density;
        // Top-left centre of the density grid inside this cell.
        const float density_center_x =
            center_x - step_average / 2. + shift / 2.;
        const float density_center_y =
            center_y - step_average / 2. + shift / 2.;
        // Generate density prior boxes with fixed ratios.
        for (size_t r = 0; r < fixed_ratios.size(); ++r) {
          const float box_width_ratio = fixed_size * sqrt_fixed_ratios[r];
          const float box_height_ratio = fixed_size / sqrt_fixed_ratios[r];
          for (int di = 0; di < density; ++di) {
            for (int dj = 0; dj < density; ++dj) {
              const float center_x_temp = density_center_x + dj * shift;
              const float center_y_temp = density_center_y + di * shift;
              boxes_data[box_idx++] = std::max(
                  (center_x_temp - box_width_ratio / 2.) / img_width, 0.);
              boxes_data[box_idx++] = std::max(
                  (center_y_temp - box_height_ratio / 2.) / img_height, 0.);
              boxes_data[box_idx++] = std::min(
                  (center_x_temp + box_width_ratio / 2.) / img_width, 1.);
              boxes_data[box_idx++] = std::min(
                  (center_y_temp + box_height_ratio / 2.) / img_height, 1.);
            }
          }
        }
      }
    }
  }

  //! clip the prior's coordinate such that it is within [0, 1]
  if (clip) {
    const int64_t channel_size =
        feature_height * feature_width * num_priors * 4;
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
    for (int64_t d = 0; d < channel_size; ++d) {
      boxes_data[d] = std::min(std::max(boxes_data[d], 0.f), 1.f);
    }
  }

  //! set the variance.
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(3)
#endif
  for (int64_t h = 0; h < feature_height; ++h) {
    for (int64_t w = 0; w < feature_width; ++w) {
      for (int64_t i = 0; i < num_priors; ++i) {
        int64_t idx = ((h * feature_width + w) * num_priors + i) * 4;
        vars_data[idx++] = variances[0];
        vars_data[idx++] = variances[1];
        vars_data[idx++] = variances[2];
        vars_data[idx++] = variances[3];
      }
    }
  }
}
} // namespace math
} // namespace x86
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "lite/backends/x86/math/math_function.h"
namespace paddle {
namespace lite {
namespace x86 {
namespace math {
// Generates density prior boxes and per-box variances for each cell of a
// feature_height x feature_width grid.
//
// For every cell, each (fixed_sizes[s], densities[s]) pair and each entry of
// `fixed_ratios` produces a densities[s] x densities[s] grid of boxes.
// Coordinates are divided by img_width / img_height; xmin/ymin are floored at
// 0 and xmax/ymax capped at 1, and when `clip` is true all values are clamped
// to [0, 1] once more.
//
// `boxes_data` and `vars_data` are both indexed as
// [feature_height, feature_width, num_priors, 4]; `vars_data` receives
// variances[0..3] for every prior. `input_data` and `image_data` are accepted
// but not read by the implementation.
void density_prior_box(const int64_t img_width,
const int64_t img_height,
const int64_t feature_width,
const int64_t feature_height,
const float* input_data,
const float* image_data,
const bool clip,
const std::vector<float> variances,
const std::vector<float> fixed_sizes,
const std::vector<float> fixed_ratios,
const std::vector<int> densities,
const float step_width,
const float step_height,
const float offset,
const int num_priors,
float* boxes_data,
float* vars_data);
} // namespace math
} // namespace x86
} // namespace lite
} // namespace paddle
......@@ -61,5 +61,4 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_quant_dequant_fuse_pass,
paddle::lite::mir::QuantDequantFusePass)
.BindTargets({TARGET(kAny)})
.BindKernel("calib");
.BindTargets({TARGET(kAny)});
......@@ -68,6 +68,8 @@ add_kernel(sequence_topk_avg_pooling_compute_x86 X86 basic SRCS sequence_topk_av
add_kernel(search_fc_compute_x86 X86 basic SRCS search_fc_compute.cc DEPS ${lite_kernel_deps} search_fc)
add_kernel(matmul_compute_x86 X86 basic SRCS matmul_compute.cc DEPS ${lite_kernel_deps} blas)
# X86 box_coder and density_prior_box kernels; each depends on its matching
# math library (box_coder / prior_box) in addition to the common kernel deps.
add_kernel(box_coder_compute_x86 X86 basic SRCS box_coder_compute.cc DEPS ${lite_kernel_deps} box_coder)
add_kernel(density_prior_box_compute_x86 X86 basic SRCS density_prior_box_compute.cc DEPS ${lite_kernel_deps} prior_box)
lite_cc_test(test_conv2d_compute_x86 SRCS conv_compute_test.cc DEPS conv_compute_x86)
lite_cc_test(test_mul_compute_x86 SRCS mul_compute_test.cc DEPS mul_compute_x86)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/box_coder_compute.h"
#include <string>
#include <vector>
#include "lite/backends/x86/math/box_coder.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// Runs the box_coder op on X86: resizes the output to N x M x 4 and dispatches
// to the encode or decode math routine according to `code_type`. Any other
// `code_type` is reported through LOG(FATAL).
void BoxCoderCompute::Run() {
  auto& param = *param_.get_mutable<operators::BoxCoderParam>();

  // Tensors: priors are M x 4 ([xmin, ymin, xmax, ymax]); targets are N x 4
  // when encoding and N x M x 4 when decoding; the variance tensor is optional.
  auto* priors = param.prior_box;
  auto* targets = param.target_box;
  auto* prior_vars = param.prior_box_var;
  auto* result = param.proposals;

  // Attributes.
  std::string code_type = param.code_type;
  std::vector<float> variance = param.variance;
  bool normalized = param.box_normalized;
  const int axis = param.axis;

  const bool is_encode = (code_type == "encode_center_size");
  const bool is_decode = (code_type == "decode_center_size");

  auto row = targets->dims()[0];  // N
  auto col = priors->dims()[0];   // M
  if (is_decode) {
    // When decoding, M comes from the target tensor itself.
    col = targets->dims()[1];
  }
  auto len = priors->dims()[1];  // 4

  result->Resize({row, col, len});
  auto* output = result->mutable_data<float>();

  const float* target_box_data = targets->data<float>();
  const float* prior_box_data = priors->data<float>();
  const float* prior_box_var_data =
      prior_vars ? prior_vars->data<float>() : nullptr;

  if (is_encode) {
    lite::x86::math::encode_center_size(row,
                                        col,
                                        len,
                                        target_box_data,
                                        prior_box_data,
                                        prior_box_var_data,
                                        normalized,
                                        variance,
                                        output);
    return;
  }

  if (is_decode) {
    // 2: per-prior variance tensor, 1: variance attribute, 0: no scaling.
    const int var_size = prior_vars ? 2 : (variance.empty() ? 0 : 1);
    lite::x86::math::decode_center_size(axis,
                                        var_size,
                                        row,
                                        col,
                                        len,
                                        target_box_data,
                                        prior_box_data,
                                        prior_box_var_data,
                                        normalized,
                                        variance,
                                        output);
    return;
  }

  LOG(FATAL) << "box_coder don't support this code_type: " << code_type;
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Registers BoxCoderCompute as the `box_coder` kernel for the kX86 target with
// kFloat precision and kNCHW layout (alias `def`). All bound inputs
// (PriorBox, PriorBoxVar, TargetBox) and the OutputBox output are X86 tensors.
REGISTER_LITE_KERNEL(box_coder,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::BoxCoderCompute,
def)
.BindInput("PriorBox", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("PriorBoxVar", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("TargetBox", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("OutputBox", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// X86 float kernel for the box_coder operator. Run() reads its configuration
// from operators::BoxCoderParam.
class BoxCoderCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
// Parameter struct type consumed by this kernel.
using param_t = operators::BoxCoderParam;
// Executes the kernel once using the currently bound parameters.
void Run() override;
virtual ~BoxCoderCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/density_prior_box_compute.h"
#include <string>
#include <vector>
#include "lite/backends/x86/math/prior_box.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// Runs the density_prior_box op on X86.
//
// Reads the feature-map and image sizes from dims()[2] (height) and dims()[3]
// (width) of the two inputs, derives the step sizes when either configured
// step is zero, counts the priors generated per cell, resizes both outputs to
// [feature_height, feature_width, num_priors, 4] and delegates the generation
// to lite::x86::math::density_prior_box.
void DensityPriorBoxCompute::Run() {
  auto& param = *param_.get_mutable<operators::DensityPriorBoxParam>();

  // required inputs
  auto* input = param.input;  // 4D tensor NCHW
  auto* image = param.image;  // 4D tensor NCHW
  // outputs
  auto* boxes = param.boxes;     // [H, W, num_priors, 4]
  auto* vars = param.variances;  // [H, W, num_priors, 4]

  // Attributes are only read here, so bind them by reference instead of
  // copying each vector.
  const bool clip = param.clip;
  const std::vector<float>& variances = param.variances_;
  const std::vector<float>& fixed_sizes = param.fixed_sizes;
  const std::vector<float>& fixed_ratios = param.fixed_ratios;
  const std::vector<int>& densities = param.density_sizes;
  const float step_w = param.step_w;
  const float step_h = param.step_h;
  const float offset = param.offset;

  auto img_width = image->dims()[3];
  auto img_height = image->dims()[2];
  auto feature_width = input->dims()[3];
  auto feature_height = input->dims()[2];

  // A zero step in either direction means "derive both from the
  // image / feature-map size ratio".
  float step_width = step_w;
  float step_height = step_h;
  if (step_w == 0 || step_h == 0) {
    step_width = static_cast<float>(img_width) / feature_width;
    step_height = static_cast<float>(img_height) / feature_height;
  }

  // Each density d contributes d * d boxes per fixed ratio. Plain integer
  // arithmetic replaces the former floating-point pow() and its implicit
  // double -> int narrowing; the loop is only a few elements long, so it is
  // run serially rather than as an OpenMP reduction.
  const int num_ratios = static_cast<int>(fixed_ratios.size());
  int num_priors = 0;
  for (const int density : densities) {
    num_priors += num_ratios * density * density;
  }

  boxes->Resize({feature_height, feature_width, num_priors, 4});
  vars->Resize({feature_height, feature_width, num_priors, 4});
  auto* boxes_data = boxes->mutable_data<float>();
  auto* vars_data = vars->mutable_data<float>();
  const float* input_data = input->data<float>();
  const float* image_data = image->data<float>();

  lite::x86::math::density_prior_box(img_width,
                                     img_height,
                                     feature_width,
                                     feature_height,
                                     input_data,
                                     image_data,
                                     clip,
                                     variances,
                                     fixed_sizes,
                                     fixed_ratios,
                                     densities,
                                     step_width,
                                     step_height,
                                     offset,
                                     num_priors,
                                     boxes_data,
                                     vars_data);
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Registers DensityPriorBoxCompute as the `density_prior_box` kernel for the
// kX86 target with kFloat precision and kNCHW layout (alias `def`). The Input
// and Image inputs and the Boxes and Variances outputs are X86 tensors.
REGISTER_LITE_KERNEL(density_prior_box,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::DensityPriorBoxCompute,
def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Image", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Boxes", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Variances", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// X86 float kernel for the density_prior_box operator. Run() reads its
// configuration from operators::DensityPriorBoxParam.
class DensityPriorBoxCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
// Parameter struct type consumed by this kernel.
using param_t = operators::DensityPriorBoxParam;
// Executes the kernel once using the currently bound parameters.
void Run() override;
virtual ~DensityPriorBoxCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -195,6 +195,7 @@ void test_box_coder(Place place) {
TEST(BoxCoder, precision) {
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
test_box_coder(place);
#endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
......
......@@ -740,6 +740,7 @@ TEST(PriorBox, precision) {
TEST(DensityPriorBox, precision) {
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
test_density_prior_box(place);
#endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册