提交 6bc6bcb6 编写于 作者: W Wangzheee

[OP][HOST][KERNEL]Add matrix_nms operator, and add host kernel for matrix_nms. test=develop

上级 a17d7be3
# 支持算子 # 支持算子
当前Paddle-Lite共计支持算子204个,其中基础算子78个,附加算子126个。 当前Paddle-Lite共计支持算子204个,其中基础算子78个,附加算子127个。
### 基础算子 ### 基础算子
...@@ -90,7 +90,7 @@ ...@@ -90,7 +90,7 @@
### 附加算子 ### 附加算子
附加算子共计126个,需要在编译时打开`--build_extra=ON`开关才会编译,具体请参考[参数详情](../source_compile/library) 附加算子共计127个,需要在编译时打开`--build_extra=ON`开关才会编译,具体请参考[参数详情](../source_compile/library)
| OP Name | Host | X86 | CUDA | ARM | OpenCL | FPGA | 华为NPU | 百度XPU | 瑞芯微NPU | 联发科APU | | OP Name | Host | X86 | CUDA | ARM | OpenCL | FPGA | 华为NPU | 百度XPU | 瑞芯微NPU | 联发科APU |
|-:|-|-|-|-|-|-|-|-|-|-| |-:|-|-|-|-|-|-|-|-|-|-|
...@@ -220,3 +220,4 @@ ...@@ -220,3 +220,4 @@
| __xpu__resnet_cbam |   |   |   |   |   |   |   | Y |   |   | | __xpu__resnet_cbam |   |   |   |   |   |   |   | Y |   |   |
| __xpu__resnet50 |   |   |   |   |   |   |   | Y |   |   | | __xpu__resnet50 |   |   |   |   |   |   |   | Y |   |   |
| __xpu__sfa_head |   |   |   |   |   |   |   | Y |   |   | | __xpu__sfa_head |   |   |   |   |   |   |   | Y |   |   |
| matrix_nms | Y |   |   |   |   |   |   |   |   |   |
\ No newline at end of file
...@@ -28,6 +28,7 @@ add_kernel(activation_grad_compute_host Host train SRCS activation_grad_compute. ...@@ -28,6 +28,7 @@ add_kernel(activation_grad_compute_host Host train SRCS activation_grad_compute.
add_kernel(pixel_shuffle_compute_host Host extra SRCS pixel_shuffle_compute.cc DEPS ${lite_kernel_deps}) add_kernel(pixel_shuffle_compute_host Host extra SRCS pixel_shuffle_compute.cc DEPS ${lite_kernel_deps})
add_kernel(one_hot_compute_host Host extra SRCS one_hot_compute.cc DEPS ${lite_kernel_deps}) add_kernel(one_hot_compute_host Host extra SRCS one_hot_compute.cc DEPS ${lite_kernel_deps})
add_kernel(uniform_random_compute_host Host extra SRCS uniform_random_compute.cc DEPS ${lite_kernel_deps}) add_kernel(uniform_random_compute_host Host extra SRCS uniform_random_compute.cc DEPS ${lite_kernel_deps})
add_kernel(matrix_nms_compute_host Host extra SRCS matrix_nms_compute.cc DEPS ${lite_kernel_deps})
if(LITE_BUILD_EXTRA AND LITE_WITH_x86) if(LITE_BUILD_EXTRA AND LITE_WITH_x86)
lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc DEPS where_index_compute_host) lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc DEPS where_index_compute_host)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/host/matrix_nms_compute.h"
#include <map>
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
template <class T>
static T BBoxArea(const T* box, const bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
template <class T>
static T JaccardOverlap(const T* box1, const T* box2, const bool normalized) {
if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
box2[3] < box1[1]) {
return static_cast<T>(0.);
} else {
const T inter_xmin = std::max(box1[0], box2[0]);
const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]);
T norm = normalized ? static_cast<T>(0.) : static_cast<T>(1.);
T inter_w = inter_xmax - inter_xmin + norm;
T inter_h = inter_ymax - inter_ymin + norm;
const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized);
return inter_area / (bbox1_area + bbox2_area - inter_area);
}
}
template <class T>
T PolyIoU(const T* box1,
const T* box2,
const size_t box_size,
const bool normalized) {
LOG(FATAL) << "PolyIoU not implement.";
return *box1;
}
template <typename T, bool gaussian>
struct decay_score;
template <typename T>
struct decay_score<T, true> {
T operator()(T iou, T max_iou, T sigma) {
return std::exp((max_iou * max_iou - iou * iou) * sigma);
}
};
template <typename T>
struct decay_score<T, false> {
T operator()(T iou, T max_iou, T sigma) {
return (1. - iou) / (1. - max_iou);
}
};
template <typename T, bool gaussian>
void NMSMatrix(const Tensor& bbox,
const Tensor& scores,
const T score_threshold,
const T post_threshold,
const float sigma,
const int64_t top_k,
const bool normalized,
std::vector<int>* selected_indices,
std::vector<T>* decayed_scores) {
int64_t num_boxes = bbox.dims()[0];
int64_t box_size = bbox.dims()[1];
auto score_ptr = scores.data<T>();
auto bbox_ptr = bbox.data<T>();
std::vector<int32_t> perm(num_boxes);
std::iota(perm.begin(), perm.end(), 0);
auto end = std::remove_if(
perm.begin(), perm.end(), [&score_ptr, score_threshold](int32_t idx) {
return score_ptr[idx] <= score_threshold;
});
auto sort_fn = [&score_ptr](int32_t lhs, int32_t rhs) {
return score_ptr[lhs] > score_ptr[rhs];
};
int64_t num_pre = std::distance(perm.begin(), end);
if (num_pre <= 0) {
return;
}
if (top_k > -1 && num_pre > top_k) {
num_pre = top_k;
}
std::partial_sort(perm.begin(), perm.begin() + num_pre, end, sort_fn);
std::vector<T> iou_matrix((num_pre * (num_pre - 1)) >> 1);
std::vector<T> iou_max(num_pre);
iou_max[0] = 0.;
for (int64_t i = 1; i < num_pre; i++) {
T max_iou = 0.;
auto idx_a = perm[i];
for (int64_t j = 0; j < i; j++) {
auto idx_b = perm[j];
auto iou = JaccardOverlap<T>(
bbox_ptr + idx_a * box_size, bbox_ptr + idx_b * box_size, normalized);
max_iou = std::max(max_iou, iou);
iou_matrix[i * (i - 1) / 2 + j] = iou;
}
iou_max[i] = max_iou;
}
if (score_ptr[perm[0]] > post_threshold) {
selected_indices->push_back(perm[0]);
decayed_scores->push_back(score_ptr[perm[0]]);
}
decay_score<T, gaussian> decay_fn;
for (int64_t i = 1; i < num_pre; i++) {
T min_decay = 1.;
for (int64_t j = 0; j < i; j++) {
auto max_iou = iou_max[j];
auto iou = iou_matrix[i * (i - 1) / 2 + j];
auto decay = decay_fn(iou, max_iou, sigma);
min_decay = std::min(min_decay, decay);
}
auto ds = min_decay * score_ptr[perm[i]];
if (ds <= post_threshold) continue;
selected_indices->push_back(perm[i]);
decayed_scores->push_back(ds);
}
}
template <typename T>
size_t MultiClassMatrixNMS(const Tensor& scores,
const Tensor& bboxes,
std::vector<T>* out,
std::vector<int>* indices,
int start,
int64_t background_label,
int64_t nms_top_k,
int64_t keep_top_k,
bool normalized,
T score_threshold,
T post_threshold,
bool use_gaussian,
float gaussian_sigma) {
std::vector<int> all_indices;
std::vector<T> all_scores;
std::vector<T> all_classes;
all_indices.reserve(scores.numel());
all_scores.reserve(scores.numel());
all_classes.reserve(scores.numel());
size_t num_det = 0;
auto class_num = scores.dims()[0];
Tensor score_slice;
for (int64_t c = 0; c < class_num; ++c) {
if (c == background_label) continue;
score_slice = scores.Slice<float>(c, c + 1);
if (use_gaussian) {
NMSMatrix<T, true>(bboxes,
score_slice,
score_threshold,
post_threshold,
gaussian_sigma,
nms_top_k,
normalized,
&all_indices,
&all_scores);
} else {
NMSMatrix<T, false>(bboxes,
score_slice,
score_threshold,
post_threshold,
gaussian_sigma,
nms_top_k,
normalized,
&all_indices,
&all_scores);
}
for (size_t i = 0; i < all_indices.size() - num_det; i++) {
all_classes.push_back(static_cast<T>(c));
}
num_det = all_indices.size();
}
if (num_det <= 0) {
return num_det;
}
if (keep_top_k > -1) {
auto k = static_cast<size_t>(keep_top_k);
if (num_det > k) num_det = k;
}
std::vector<int32_t> perm(all_indices.size());
std::iota(perm.begin(), perm.end(), 0);
std::partial_sort(perm.begin(),
perm.begin() + num_det,
perm.end(),
[&all_scores](int lhs, int rhs) {
return all_scores[lhs] > all_scores[rhs];
});
for (size_t i = 0; i < num_det; i++) {
auto p = perm[i];
auto idx = all_indices[p];
auto cls = all_classes[p];
auto score = all_scores[p];
auto bbox = bboxes.data<T>() + idx * bboxes.dims()[1];
(*indices).push_back(start + idx);
(*out).push_back(cls);
(*out).push_back(score);
for (int j = 0; j < bboxes.dims()[1]; j++) {
(*out).push_back(bbox[j]);
}
}
return num_det;
}
void MatrixNmsCompute::Run() {
auto& param = Param<operators::MatrixNmsParam>();
auto* boxes = param.bboxes;
auto* scores = param.scores;
auto* outs = param.out;
auto* index = param.index;
auto background_label = param.background_label;
auto nms_top_k = param.nms_top_k;
auto keep_top_k = param.keep_top_k;
auto normalized = param.normalized;
auto score_threshold = param.score_threshold;
auto post_threshold = param.post_threshold;
auto use_gaussian = param.use_gaussian;
auto gaussian_sigma = param.gaussian_sigma;
auto score_dims = scores->dims();
auto batch_size = score_dims[0];
auto num_boxes = score_dims[2];
auto box_dim = boxes->dims()[2];
auto out_dim = box_dim + 2;
Tensor boxes_slice, scores_slice;
size_t num_out = 0;
std::vector<size_t> offsets = {0};
std::vector<float> detections;
std::vector<int> indices;
detections.reserve(out_dim * num_boxes * batch_size);
indices.reserve(num_boxes * batch_size);
for (int i = 0; i < batch_size; ++i) {
scores_slice = scores->Slice<float>(i, i + 1);
scores_slice.Resize({score_dims[1], score_dims[2]});
boxes_slice = boxes->Slice<float>(i, i + 1);
boxes_slice.Resize({score_dims[2], box_dim});
int start = i * score_dims[2];
num_out = MultiClassMatrixNMS(scores_slice,
boxes_slice,
&detections,
&indices,
start,
background_label,
nms_top_k,
keep_top_k,
normalized,
score_threshold,
post_threshold,
use_gaussian,
gaussian_sigma);
offsets.push_back(offsets.back() + num_out);
}
size_t num_kept = offsets.back();
if (num_kept == 0) {
outs->Resize({0, out_dim});
index->Resize({0, 1});
} else {
outs->Resize({static_cast<int64_t>(num_kept), out_dim});
index->Resize({static_cast<int64_t>(num_kept), 1});
std::copy(
detections.begin(), detections.end(), outs->mutable_data<float>());
std::copy(indices.begin(), indices.end(), index->mutable_data<int>());
}
LoD lod;
lod.emplace_back(offsets);
outs->set_lod(lod);
index->set_lod(lod);
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(matrix_nms,
kHost,
kFloat,
kNCHW,
paddle::lite::kernels::host::MatrixNmsCompute,
def)
.BindInput("BBoxes", {LiteType::GetTensorTy(TARGET(kHost))})
.BindInput("Scores", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))})
.BindOutput("Index",
{LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
class MatrixNmsCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~MatrixNmsCompute() = default;
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -121,6 +121,7 @@ add_operator(max_pool_with_index_op extra SRCS max_pool_with_index_op.cc DEPS ${ ...@@ -121,6 +121,7 @@ add_operator(max_pool_with_index_op extra SRCS max_pool_with_index_op.cc DEPS ${
add_operator(pixel_shuffle_op extra SRCS pixel_shuffle_op.cc DEPS ${op_DEPS}) add_operator(pixel_shuffle_op extra SRCS pixel_shuffle_op.cc DEPS ${op_DEPS})
add_operator(clip_op extra SRCS clip_op.cc DEPS ${op_DEPS}) add_operator(clip_op extra SRCS clip_op.cc DEPS ${op_DEPS})
add_operator(print_op extra SRCS print_op.cc DEPS ${op_DEPS}) add_operator(print_op extra SRCS print_op.cc DEPS ${op_DEPS})
add_operator(matrix_nms_op_lite extra SRCS matrix_nms_op.cc DEPS ${op_DEPS})
# for OCR specific # for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/matrix_nms_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool MatrixNmsOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.bboxes);
CHECK_OR_FALSE(param_.scores);
CHECK_OR_FALSE(param_.out);
auto box_dims = param_.bboxes->dims();
auto score_dims = param_.scores->dims();
auto score_size = score_dims.size();
CHECK_OR_FALSE(score_size == 3);
CHECK_OR_FALSE(box_dims.size() == 3);
CHECK_OR_FALSE(box_dims[2] == 4);
CHECK_OR_FALSE(box_dims[1] == score_dims[2]);
return true;
}
bool MatrixNmsOpLite::InferShapeImpl() const {
// InferShape is useless for matrix_nms
// out's dim is not sure before the end of calculation
return true;
}
bool MatrixNmsOpLite::AttachImpl(const cpp::OpDesc& opdesc,
lite::Scope* scope) {
auto bboxes_name = opdesc.Input("BBoxes").front();
auto scores_name = opdesc.Input("Scores").front();
auto out_name = opdesc.Output("Out").front();
auto index_name = opdesc.Output("Index").front();
param_.bboxes = GetVar<lite::Tensor>(scope, bboxes_name);
param_.scores = GetVar<lite::Tensor>(scope, scores_name);
param_.out = GetMutableVar<lite::Tensor>(scope, out_name);
param_.index = GetMutableVar<lite::Tensor>(scope, index_name);
param_.background_label = opdesc.GetAttr<int>("background_label");
param_.score_threshold = opdesc.GetAttr<float>("score_threshold");
param_.post_threshold = opdesc.GetAttr<float>("post_threshold");
param_.nms_top_k = opdesc.GetAttr<int>("nms_top_k");
param_.keep_top_k = opdesc.GetAttr<int>("keep_top_k");
param_.normalized = opdesc.GetAttr<bool>("normalized");
param_.use_gaussian = opdesc.GetAttr<bool>("use_gaussian");
param_.gaussian_sigma = opdesc.GetAttr<float>("gaussian_sigma");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(matrix_nms, paddle::lite::operators::MatrixNmsOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace operators {
class MatrixNmsOpLite : public OpLite {
public:
MatrixNmsOpLite() {}
explicit MatrixNmsOpLite(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "matrix_nms"; }
private:
mutable MatrixNmsParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
...@@ -880,6 +880,22 @@ struct MulticlassNmsParam : ParamBase { ...@@ -880,6 +880,22 @@ struct MulticlassNmsParam : ParamBase {
bool normalized{true}; bool normalized{true};
}; };
/// ----------------------- matrix_nms operators ----------------------
struct MatrixNmsParam : ParamBase {
const lite::Tensor* bboxes{};
const lite::Tensor* scores{};
lite::Tensor* out{};
lite::Tensor* index{};
int background_label{0};
float score_threshold{};
float post_threshold{0.0f};
int nms_top_k{};
int keep_top_k;
bool normalized{true};
bool use_gaussian{false};
float gaussian_sigma{2.0f};
};
/// ----------------------- priorbox operators ---------------------- /// ----------------------- priorbox operators ----------------------
struct PriorBoxParam : ParamBase { struct PriorBoxParam : ParamBase {
lite::Tensor* input{}; lite::Tensor* input{};
......
...@@ -34,6 +34,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM AND NOT LIT ...@@ -34,6 +34,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM AND NOT LIT
lite_cc_test(test_kernel_softmax_compute SRCS softmax_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_softmax_compute SRCS softmax_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_mul_compute SRCS mul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_mul_compute SRCS mul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_multiclass_nms_compute SRCS multiclass_nms_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_multiclass_nms_compute SRCS multiclass_nms_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_matrix_nms_compute SRCS matrix_nms_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_batch_norm_compute SRCS batch_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_batch_norm_compute SRCS batch_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_pool_compute SRCS pool_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_pool_compute SRCS pool_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_fill_constant_compute SRCS fill_constant_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_fill_constant_compute SRCS fill_constant_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <cmath>
#include <string>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
template <class T>
static T BBoxArea(const T* box, const bool normalized) {
if (box[2] < box[0] || box[3] < box[1]) {
// If coordinate values are is invalid
// (e.g. xmax < xmin or ymax < ymin), return 0.
return static_cast<T>(0.);
} else {
const T w = box[2] - box[0];
const T h = box[3] - box[1];
if (normalized) {
return w * h;
} else {
// If coordinate values are not within range [0, 1].
return (w + 1) * (h + 1);
}
}
}
template <class T>
static T JaccardOverlap(const T* box1, const T* box2, const bool normalized) {
if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
box2[3] < box1[1]) {
return static_cast<T>(0.);
} else {
const T inter_xmin = std::max(box1[0], box2[0]);
const T inter_ymin = std::max(box1[1], box2[1]);
const T inter_xmax = std::min(box1[2], box2[2]);
const T inter_ymax = std::min(box1[3], box2[3]);
T norm = normalized ? static_cast<T>(0.) : static_cast<T>(1.);
T inter_w = inter_xmax - inter_xmin + norm;
T inter_h = inter_ymax - inter_ymin + norm;
const T inter_area = inter_w * inter_h;
const T bbox1_area = BBoxArea<T>(box1, normalized);
const T bbox2_area = BBoxArea<T>(box2, normalized);
return inter_area / (bbox1_area + bbox2_area - inter_area);
}
}
template <class T>
T PolyIoU(const T* box1,
const T* box2,
const size_t box_size,
const bool normalized) {
LOG(FATAL) << "PolyIoU not implement.";
return *box1;
}
template <typename T, bool gaussian>
struct decay_score;
template <typename T>
struct decay_score<T, true> {
T operator()(T iou, T max_iou, T sigma) {
return std::exp((max_iou * max_iou - iou * iou) * sigma);
}
};
template <typename T>
struct decay_score<T, false> {
T operator()(T iou, T max_iou, T sigma) {
return (1. - iou) / (1. - max_iou);
}
};
template <typename T, bool gaussian>
void NMSMatrix(const Tensor& bbox,
const Tensor& scores,
const T score_threshold,
const T post_threshold,
const float sigma,
const int64_t top_k,
const bool normalized,
std::vector<int>* selected_indices,
std::vector<T>* decayed_scores) {
int64_t num_boxes = bbox.dims()[0];
int64_t box_size = bbox.dims()[1];
auto score_ptr = scores.data<T>();
auto bbox_ptr = bbox.data<T>();
std::vector<int32_t> perm(num_boxes);
std::iota(perm.begin(), perm.end(), 0);
auto end = std::remove_if(
perm.begin(), perm.end(), [&score_ptr, score_threshold](int32_t idx) {
return score_ptr[idx] <= score_threshold;
});
auto sort_fn = [&score_ptr](int32_t lhs, int32_t rhs) {
return score_ptr[lhs] > score_ptr[rhs];
};
int64_t num_pre = std::distance(perm.begin(), end);
if (num_pre <= 0) {
return;
}
if (top_k > -1 && num_pre > top_k) {
num_pre = top_k;
}
std::partial_sort(perm.begin(), perm.begin() + num_pre, end, sort_fn);
std::vector<T> iou_matrix((num_pre * (num_pre - 1)) >> 1);
std::vector<T> iou_max(num_pre);
iou_max[0] = 0.;
for (int64_t i = 1; i < num_pre; i++) {
T max_iou = 0.;
auto idx_a = perm[i];
for (int64_t j = 0; j < i; j++) {
auto idx_b = perm[j];
auto iou = JaccardOverlap<T>(
bbox_ptr + idx_a * box_size, bbox_ptr + idx_b * box_size, normalized);
max_iou = std::max(max_iou, iou);
iou_matrix[i * (i - 1) / 2 + j] = iou;
}
iou_max[i] = max_iou;
}
if (score_ptr[perm[0]] > post_threshold) {
selected_indices->push_back(perm[0]);
decayed_scores->push_back(score_ptr[perm[0]]);
}
decay_score<T, gaussian> decay_fn;
for (int64_t i = 1; i < num_pre; i++) {
T min_decay = 1.;
for (int64_t j = 0; j < i; j++) {
auto max_iou = iou_max[j];
auto iou = iou_matrix[i * (i - 1) / 2 + j];
auto decay = decay_fn(iou, max_iou, sigma);
min_decay = std::min(min_decay, decay);
}
auto ds = min_decay * score_ptr[perm[i]];
if (ds <= post_threshold) continue;
selected_indices->push_back(perm[i]);
decayed_scores->push_back(ds);
}
}
template <typename T>
size_t MultiClassMatrixNMS(const Tensor& scores,
const Tensor& bboxes,
std::vector<T>* out,
std::vector<int>* indices,
int start,
int64_t background_label,
int64_t nms_top_k,
int64_t keep_top_k,
bool normalized,
T score_threshold,
T post_threshold,
bool use_gaussian,
float gaussian_sigma) {
std::vector<int> all_indices;
std::vector<T> all_scores;
std::vector<T> all_classes;
all_indices.reserve(scores.numel());
all_scores.reserve(scores.numel());
all_classes.reserve(scores.numel());
size_t num_det = 0;
auto class_num = scores.dims()[0];
Tensor score_slice;
for (int64_t c = 0; c < class_num; ++c) {
if (c == background_label) continue;
score_slice = scores.Slice<float>(c, c + 1);
if (use_gaussian) {
NMSMatrix<T, true>(bboxes,
score_slice,
score_threshold,
post_threshold,
gaussian_sigma,
nms_top_k,
normalized,
&all_indices,
&all_scores);
} else {
NMSMatrix<T, false>(bboxes,
score_slice,
score_threshold,
post_threshold,
gaussian_sigma,
nms_top_k,
normalized,
&all_indices,
&all_scores);
}
for (size_t i = 0; i < all_indices.size() - num_det; i++) {
all_classes.push_back(static_cast<T>(c));
}
num_det = all_indices.size();
}
if (num_det <= 0) {
return num_det;
}
if (keep_top_k > -1) {
auto k = static_cast<size_t>(keep_top_k);
if (num_det > k) num_det = k;
}
std::vector<int32_t> perm(all_indices.size());
std::iota(perm.begin(), perm.end(), 0);
std::partial_sort(perm.begin(),
perm.begin() + num_det,
perm.end(),
[&all_scores](int lhs, int rhs) {
return all_scores[lhs] > all_scores[rhs];
});
for (size_t i = 0; i < num_det; i++) {
auto p = perm[i];
auto idx = all_indices[p];
auto cls = all_classes[p];
auto score = all_scores[p];
auto bbox = bboxes.data<T>() + idx * bboxes.dims()[1];
(*indices).push_back(start + idx);
(*out).push_back(cls);
(*out).push_back(score);
for (int j = 0; j < bboxes.dims()[1]; j++) {
(*out).push_back(bbox[j]);
}
}
return num_det;
}
class MatrixNmsComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string type_ = "matrix_nms";
std::string bboxes_ = "bboxes";
std::string scores_ = "scores";
std::string out_ = "out";
std::string index_ = "index";
DDim bboxes_dims_{};
DDim scores_dims_{};
int background_label_{0};
float score_threshold_{0.01f};
float post_threshold_{0.0f};
int nms_top_k_{1};
int keep_top_k_{2};
bool normalized_{false};
bool use_gaussian_{true};
float gaussian_sigma_{2.0f};
public:
MatrixNmsComputeTester(const Place& place,
const std::string& alias,
DDim bboxes_dims,
DDim scores_dims,
int background_label = 0,
float score_threshold = 0.01f,
float post_threshold = 0.0f,
int nms_top_k = 1,
int keep_top_k = 2,
bool normalized = false,
bool use_gaussian = true,
float gaussian_sigma = 2.0f)
: TestCase(place, alias),
bboxes_dims_(bboxes_dims),
scores_dims_(scores_dims),
background_label_(background_label),
score_threshold_(score_threshold),
post_threshold_(post_threshold),
nms_top_k_(nms_top_k),
keep_top_k_(keep_top_k),
normalized_(normalized),
use_gaussian_(use_gaussian),
gaussian_sigma_(gaussian_sigma) {}
void RunBaseline(Scope* scope) override {
auto* boxes = scope->FindTensor(bboxes_);
auto* scores = scope->FindTensor(scores_);
auto* outs = scope->NewTensor(out_);
auto* index = scope->NewTensor(index_);
CHECK(outs);
outs->set_precision(PRECISION(kFloat));
CHECK(index);
index->set_precision(PRECISION(kInt32));
auto score_dims = scores->dims();
auto batch_size = score_dims[0];
auto num_boxes = score_dims[2];
auto box_dim = boxes->dims()[2];
auto out_dim = box_dim + 2;
Tensor boxes_slice, scores_slice;
size_t num_out = 0;
std::vector<size_t> offsets = {0};
std::vector<float> detections;
std::vector<int> indices;
detections.reserve(out_dim * num_boxes * batch_size);
indices.reserve(num_boxes * batch_size);
for (int i = 0; i < batch_size; ++i) {
scores_slice = scores->Slice<float>(i, i + 1);
scores_slice.Resize({score_dims[1], score_dims[2]});
boxes_slice = boxes->Slice<float>(i, i + 1);
boxes_slice.Resize({score_dims[2], box_dim});
int start = i * score_dims[2];
num_out = MultiClassMatrixNMS(scores_slice,
boxes_slice,
&detections,
&indices,
start,
background_label_,
nms_top_k_,
keep_top_k_,
normalized_,
score_threshold_,
post_threshold_,
use_gaussian_,
gaussian_sigma_);
offsets.push_back(offsets.back() + num_out);
}
size_t num_kept = offsets.back();
if (num_kept == 0) {
outs->Resize({0, out_dim});
index->Resize({0, 1});
} else {
outs->Resize({static_cast<int64_t>(num_kept), out_dim});
index->Resize({static_cast<int64_t>(num_kept), 1});
std::copy(
detections.begin(), detections.end(), outs->mutable_data<float>());
std::copy(indices.begin(), indices.end(), index->mutable_data<int>());
}
LoD lod;
lod.emplace_back(offsets);
outs->set_lod(lod);
index->set_lod(lod);
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType(type_);
op_desc->SetInput("BBoxes", {bboxes_});
op_desc->SetInput("Scores", {scores_});
op_desc->SetOutput("Out", {out_});
op_desc->SetOutput("Index", {index_});
op_desc->SetAttr("background_label", background_label_);
op_desc->SetAttr("score_threshold", score_threshold_);
op_desc->SetAttr("post_threshold", post_threshold_);
op_desc->SetAttr("nms_top_k", nms_top_k_);
op_desc->SetAttr("keep_top_k", keep_top_k_);
op_desc->SetAttr("normalized", normalized_);
op_desc->SetAttr("use_gaussian", use_gaussian_);
op_desc->SetAttr("gaussian_sigma", gaussian_sigma_);
}
void PrepareData() override {
std::vector<float> bboxes(bboxes_dims_.production());
for (int i = 0; i < bboxes_dims_.production(); ++i) {
bboxes[i] = i * 1. / bboxes_dims_.production();
}
SetCommonTensor(bboxes_, bboxes_dims_, bboxes.data());
std::vector<float> scores(scores_dims_.production());
for (int i = 0; i < scores_dims_.production(); ++i) {
scores[i] = i * 1. / scores_dims_.production();
}
SetCommonTensor(scores_, scores_dims_, scores.data());
}
};
void TestMatrixNms(Place place, float abs_error) {
int N = 3;
int M = 2500;
for (int class_num : {2, 4, 10}) {
std::vector<int64_t> bbox_shape{N, M, 4};
std::vector<int64_t> score_shape{N, class_num, M};
std::unique_ptr<arena::TestCase> tester(new MatrixNmsComputeTester(
place, "def", DDim(bbox_shape), DDim(score_shape)));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
TEST(matrix_nms, precision) {
float abs_error = 2e-5;
Place place;
#if defined(LITE_WITH_ARM)
place = TARGET(kHost);
#elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL)
place = TARGET(kXPU);
#else
return;
#endif
TestMatrixNms(place, abs_error);
}
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册