From 6bc6bcb629c07507c48b20d2f890d99b24e43c10 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Thu, 10 Sep 2020 15:51:12 +0000 Subject: [PATCH] [OP][HOST][KERNEL]Add matrix_nms operator, and add host kernel for matrix_nms. test=develop --- docs/introduction/support_operation_list.md | 5 +- lite/kernels/host/CMakeLists.txt | 1 + lite/kernels/host/matrix_nms_compute.cc | 338 ++++++++++++++ lite/kernels/host/matrix_nms_compute.h | 35 ++ lite/operators/CMakeLists.txt | 1 + lite/operators/matrix_nms_op.cc | 69 +++ lite/operators/matrix_nms_op.h | 45 ++ lite/operators/op_params.h | 16 + lite/tests/kernels/CMakeLists.txt | 1 + lite/tests/kernels/matrix_nms_compute_test.cc | 420 ++++++++++++++++++ 10 files changed, 929 insertions(+), 2 deletions(-) create mode 100644 lite/kernels/host/matrix_nms_compute.cc create mode 100644 lite/kernels/host/matrix_nms_compute.h create mode 100644 lite/operators/matrix_nms_op.cc create mode 100644 lite/operators/matrix_nms_op.h create mode 100644 lite/tests/kernels/matrix_nms_compute_test.cc diff --git a/docs/introduction/support_operation_list.md b/docs/introduction/support_operation_list.md index 1a0ba4e3c6..793a47559d 100644 --- a/docs/introduction/support_operation_list.md +++ b/docs/introduction/support_operation_list.md @@ -1,6 +1,6 @@ # 支持算子 -当前Paddle-Lite共计支持算子204个,其中基础算子78个,附加算子126个。 +当前Paddle-Lite共计支持算子204个,其中基础算子78个,附加算子127个。 ### 基础算子 @@ -90,7 +90,7 @@ ### 附加算子 -附加算子共计126个,需要在编译时打开`--build_extra=ON`开关才会编译,具体请参考[参数详情](../source_compile/library)。 +附加算子共计127个,需要在编译时打开`--build_extra=ON`开关才会编译,具体请参考[参数详情](../source_compile/library)。 | OP Name | Host | X86 | CUDA | ARM | OpenCL | FPGA | 华为NPU | 百度XPU | 瑞芯微NPU | 联发科APU | |-:|-|-|-|-|-|-|-|-|-|-| @@ -220,3 +220,4 @@ | __xpu__resnet_cbam |   |   |   |   |   |   |   | Y |   |   | | __xpu__resnet50 |   |   |   |   |   |   |   | Y |   |   | | __xpu__sfa_head |   |   |   |   |   |   |   | Y |   |   | +| matrix_nms | Y |   |   |   |   |   |   |   |   |   | \ No newline at end of file diff --git a/lite/kernels/host/CMakeLists.txt b/lite/kernels/host/CMakeLists.txt index 3cbc585d78..638eb0d077 100644 --- a/lite/kernels/host/CMakeLists.txt +++ b/lite/kernels/host/CMakeLists.txt @@ -28,6 +28,7 @@ add_kernel(activation_grad_compute_host Host train SRCS activation_grad_compute. add_kernel(pixel_shuffle_compute_host Host extra SRCS pixel_shuffle_compute.cc DEPS ${lite_kernel_deps}) add_kernel(one_hot_compute_host Host extra SRCS one_hot_compute.cc DEPS ${lite_kernel_deps}) add_kernel(uniform_random_compute_host Host extra SRCS uniform_random_compute.cc DEPS ${lite_kernel_deps}) +add_kernel(matrix_nms_compute_host Host extra SRCS matrix_nms_compute.cc DEPS ${lite_kernel_deps}) if(LITE_BUILD_EXTRA AND LITE_WITH_x86) lite_cc_test(test_where_index_compute_host SRCS where_index_compute.cc DEPS where_index_compute_host) diff --git a/lite/kernels/host/matrix_nms_compute.cc b/lite/kernels/host/matrix_nms_compute.cc new file mode 100644 index 0000000000..d0f4785c67 --- /dev/null +++ b/lite/kernels/host/matrix_nms_compute.cc @@ -0,0 +1,338 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/host/matrix_nms_compute.h" +#include +#include +#include + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +template +static T BBoxArea(const T* box, const bool normalized) { + if (box[2] < box[0] || box[3] < box[1]) { + // If coordinate values are is invalid + // (e.g. xmax < xmin or ymax < ymin), return 0. + return static_cast(0.); + } else { + const T w = box[2] - box[0]; + const T h = box[3] - box[1]; + if (normalized) { + return w * h; + } else { + // If coordinate values are not within range [0, 1]. + return (w + 1) * (h + 1); + } + } +} + +template +static T JaccardOverlap(const T* box1, const T* box2, const bool normalized) { + if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] || + box2[3] < box1[1]) { + return static_cast(0.); + } else { + const T inter_xmin = std::max(box1[0], box2[0]); + const T inter_ymin = std::max(box1[1], box2[1]); + const T inter_xmax = std::min(box1[2], box2[2]); + const T inter_ymax = std::min(box1[3], box2[3]); + T norm = normalized ? static_cast(0.) : static_cast(1.); + T inter_w = inter_xmax - inter_xmin + norm; + T inter_h = inter_ymax - inter_ymin + norm; + const T inter_area = inter_w * inter_h; + const T bbox1_area = BBoxArea(box1, normalized); + const T bbox2_area = BBoxArea(box2, normalized); + return inter_area / (bbox1_area + bbox2_area - inter_area); + } +} + +template +T PolyIoU(const T* box1, + const T* box2, + const size_t box_size, + const bool normalized) { + LOG(FATAL) << "PolyIoU not implement."; + return *box1; +} + +template +struct decay_score; + +template +struct decay_score { + T operator()(T iou, T max_iou, T sigma) { + return std::exp((max_iou * max_iou - iou * iou) * sigma); + } +}; + +template +struct decay_score { + T operator()(T iou, T max_iou, T sigma) { + return (1. - iou) / (1. - max_iou); + } +}; + +template +void NMSMatrix(const Tensor& bbox, + const Tensor& scores, + const T score_threshold, + const T post_threshold, + const float sigma, + const int64_t top_k, + const bool normalized, + std::vector* selected_indices, + std::vector* decayed_scores) { + int64_t num_boxes = bbox.dims()[0]; + int64_t box_size = bbox.dims()[1]; + + auto score_ptr = scores.data(); + auto bbox_ptr = bbox.data(); + + std::vector perm(num_boxes); + std::iota(perm.begin(), perm.end(), 0); + auto end = std::remove_if( + perm.begin(), perm.end(), [&score_ptr, score_threshold](int32_t idx) { + return score_ptr[idx] <= score_threshold; + }); + auto sort_fn = [&score_ptr](int32_t lhs, int32_t rhs) { + return score_ptr[lhs] > score_ptr[rhs]; + }; + + int64_t num_pre = std::distance(perm.begin(), end); + if (num_pre <= 0) { + return; + } + if (top_k > -1 && num_pre > top_k) { + num_pre = top_k; + } + std::partial_sort(perm.begin(), perm.begin() + num_pre, end, sort_fn); + + std::vector iou_matrix((num_pre * (num_pre - 1)) >> 1); + std::vector iou_max(num_pre); + + iou_max[0] = 0.; + for (int64_t i = 1; i < num_pre; i++) { + T max_iou = 0.; + auto idx_a = perm[i]; + for (int64_t j = 0; j < i; j++) { + auto idx_b = perm[j]; + auto iou = JaccardOverlap( + bbox_ptr + idx_a * box_size, bbox_ptr + idx_b * box_size, normalized); + max_iou = std::max(max_iou, iou); + iou_matrix[i * (i - 1) / 2 + j] = iou; + } + iou_max[i] = max_iou; + } + + if (score_ptr[perm[0]] > post_threshold) { + selected_indices->push_back(perm[0]); + decayed_scores->push_back(score_ptr[perm[0]]); + } + + decay_score decay_fn; + for (int64_t i = 1; i < num_pre; i++) { + T min_decay = 1.; + for (int64_t j = 0; j < i; j++) { + auto max_iou = iou_max[j]; + auto iou = iou_matrix[i * (i - 1) / 2 + j]; + auto decay = decay_fn(iou, max_iou, sigma); + min_decay = std::min(min_decay, decay); + } + auto ds = min_decay * score_ptr[perm[i]]; + if (ds <= post_threshold) continue; + selected_indices->push_back(perm[i]); + decayed_scores->push_back(ds); + } +} + +template +size_t MultiClassMatrixNMS(const Tensor& scores, + const Tensor& bboxes, + std::vector* out, + std::vector* indices, + int start, + int64_t background_label, + int64_t nms_top_k, + int64_t keep_top_k, + bool normalized, + T score_threshold, + T post_threshold, + bool use_gaussian, + float gaussian_sigma) { + std::vector all_indices; + std::vector all_scores; + std::vector all_classes; + all_indices.reserve(scores.numel()); + all_scores.reserve(scores.numel()); + all_classes.reserve(scores.numel()); + + size_t num_det = 0; + auto class_num = scores.dims()[0]; + Tensor score_slice; + for (int64_t c = 0; c < class_num; ++c) { + if (c == background_label) continue; + score_slice = scores.Slice(c, c + 1); + if (use_gaussian) { + NMSMatrix(bboxes, + score_slice, + score_threshold, + post_threshold, + gaussian_sigma, + nms_top_k, + normalized, + &all_indices, + &all_scores); + } else { + NMSMatrix(bboxes, + score_slice, + score_threshold, + post_threshold, + gaussian_sigma, + nms_top_k, + normalized, + &all_indices, + &all_scores); + } + for (size_t i = 0; i < all_indices.size() - num_det; i++) { + all_classes.push_back(static_cast(c)); + } + num_det = all_indices.size(); + } + + if (num_det <= 0) { + return num_det; + } + + if (keep_top_k > -1) { + auto k = static_cast(keep_top_k); + if (num_det > k) num_det = k; + } + + std::vector perm(all_indices.size()); + std::iota(perm.begin(), perm.end(), 0); + + std::partial_sort(perm.begin(), + perm.begin() + num_det, + perm.end(), + [&all_scores](int lhs, int rhs) { + return all_scores[lhs] > all_scores[rhs]; + }); + + for (size_t i = 0; i < num_det; i++) { + auto p = perm[i]; + auto idx = all_indices[p]; + auto cls = all_classes[p]; + auto score = all_scores[p]; + auto bbox = bboxes.data() + idx * bboxes.dims()[1]; + (*indices).push_back(start + idx); + (*out).push_back(cls); + (*out).push_back(score); + for (int j = 0; j < bboxes.dims()[1]; j++) { + (*out).push_back(bbox[j]); + } + } + + return num_det; +} + +void MatrixNmsCompute::Run() { + auto& param = Param(); + auto* boxes = param.bboxes; + auto* scores = param.scores; + auto* outs = param.out; + auto* index = param.index; + + auto background_label = param.background_label; + auto nms_top_k = param.nms_top_k; + auto keep_top_k = param.keep_top_k; + auto normalized = param.normalized; + auto score_threshold = param.score_threshold; + auto post_threshold = param.post_threshold; + auto use_gaussian = param.use_gaussian; + auto gaussian_sigma = param.gaussian_sigma; + + auto score_dims = scores->dims(); + auto batch_size = score_dims[0]; + auto num_boxes = score_dims[2]; + auto box_dim = boxes->dims()[2]; + auto out_dim = box_dim + 2; + + Tensor boxes_slice, scores_slice; + size_t num_out = 0; + std::vector offsets = {0}; + std::vector detections; + std::vector indices; + detections.reserve(out_dim * num_boxes * batch_size); + indices.reserve(num_boxes * batch_size); + for (int i = 0; i < batch_size; ++i) { + scores_slice = scores->Slice(i, i + 1); + scores_slice.Resize({score_dims[1], score_dims[2]}); + boxes_slice = boxes->Slice(i, i + 1); + boxes_slice.Resize({score_dims[2], box_dim}); + int start = i * score_dims[2]; + num_out = MultiClassMatrixNMS(scores_slice, + boxes_slice, + &detections, + &indices, + start, + background_label, + nms_top_k, + keep_top_k, + normalized, + score_threshold, + post_threshold, + use_gaussian, + gaussian_sigma); + offsets.push_back(offsets.back() + num_out); + } + + size_t num_kept = offsets.back(); + if (num_kept == 0) { + outs->Resize({0, out_dim}); + index->Resize({0, 1}); + } else { + outs->Resize({static_cast(num_kept), out_dim}); + index->Resize({static_cast(num_kept), 1}); + std::copy( + detections.begin(), detections.end(), outs->mutable_data()); + std::copy(indices.begin(), indices.end(), index->mutable_data()); + } + + LoD lod; + lod.emplace_back(offsets); + outs->set_lod(lod); + index->set_lod(lod); +} + + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(matrix_nms, + kHost, + kFloat, + kNCHW, + paddle::lite::kernels::host::MatrixNmsCompute, + def) + .BindInput("BBoxes", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindInput("Scores", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) + .BindOutput("Index", + {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))}) + .Finalize(); diff --git a/lite/kernels/host/matrix_nms_compute.h b/lite/kernels/host/matrix_nms_compute.h new file mode 100644 index 0000000000..745a96fa90 --- /dev/null +++ b/lite/kernels/host/matrix_nms_compute.h @@ -0,0 +1,35 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace host { + +class MatrixNmsCompute : public KernelLite { + public: + void Run() override; + + virtual ~MatrixNmsCompute() = default; +}; + +} // namespace host +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/operators/CMakeLists.txt b/lite/operators/CMakeLists.txt index 02377aad49..3a5b856b20 100644 --- a/lite/operators/CMakeLists.txt +++ b/lite/operators/CMakeLists.txt @@ -121,6 +121,7 @@ add_operator(max_pool_with_index_op extra SRCS max_pool_with_index_op.cc DEPS ${ add_operator(pixel_shuffle_op extra SRCS pixel_shuffle_op.cc DEPS ${op_DEPS}) add_operator(clip_op extra SRCS clip_op.cc DEPS ${op_DEPS}) add_operator(print_op extra SRCS print_op.cc DEPS ${op_DEPS}) +add_operator(matrix_nms_op_lite extra SRCS matrix_nms_op.cc DEPS ${op_DEPS}) # for OCR specific add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS}) diff --git a/lite/operators/matrix_nms_op.cc b/lite/operators/matrix_nms_op.cc new file mode 100644 index 0000000000..9c3350be7d --- /dev/null +++ b/lite/operators/matrix_nms_op.cc @@ -0,0 +1,69 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/operators/matrix_nms_op.h" +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace operators { + +bool MatrixNmsOpLite::CheckShape() const { + CHECK_OR_FALSE(param_.bboxes); + CHECK_OR_FALSE(param_.scores); + CHECK_OR_FALSE(param_.out); + + auto box_dims = param_.bboxes->dims(); + auto score_dims = param_.scores->dims(); + auto score_size = score_dims.size(); + + CHECK_OR_FALSE(score_size == 3); + CHECK_OR_FALSE(box_dims.size() == 3); + CHECK_OR_FALSE(box_dims[2] == 4); + CHECK_OR_FALSE(box_dims[1] == score_dims[2]); + return true; +} + +bool MatrixNmsOpLite::InferShapeImpl() const { + // InferShape is useless for matrix_nms + // out's dim is not sure before the end of calculation + return true; +} + +bool MatrixNmsOpLite::AttachImpl(const cpp::OpDesc& opdesc, + lite::Scope* scope) { + auto bboxes_name = opdesc.Input("BBoxes").front(); + auto scores_name = opdesc.Input("Scores").front(); + auto out_name = opdesc.Output("Out").front(); + auto index_name = opdesc.Output("Index").front(); + param_.bboxes = GetVar(scope, bboxes_name); + param_.scores = GetVar(scope, scores_name); + param_.out = GetMutableVar(scope, out_name); + param_.index = GetMutableVar(scope, index_name); + param_.background_label = opdesc.GetAttr("background_label"); + param_.score_threshold = opdesc.GetAttr("score_threshold"); + param_.post_threshold = opdesc.GetAttr("post_threshold"); + param_.nms_top_k = opdesc.GetAttr("nms_top_k"); + param_.keep_top_k = opdesc.GetAttr("keep_top_k"); + param_.normalized = opdesc.GetAttr("normalized"); + param_.use_gaussian = opdesc.GetAttr("use_gaussian"); + param_.gaussian_sigma = opdesc.GetAttr("gaussian_sigma"); + return true; +} + +} // namespace operators +} // namespace lite +} // namespace paddle + +REGISTER_LITE_OP(matrix_nms, paddle::lite::operators::MatrixNmsOpLite); diff --git a/lite/operators/matrix_nms_op.h b/lite/operators/matrix_nms_op.h new file mode 100644 index 0000000000..4999d2debd --- /dev/null +++ b/lite/operators/matrix_nms_op.h @@ -0,0 +1,45 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/op_lite.h" + +namespace paddle { +namespace lite { +namespace operators { + +class MatrixNmsOpLite : public OpLite { + public: + MatrixNmsOpLite() {} + explicit MatrixNmsOpLite(const std::string &op_type) : OpLite(op_type) {} + + bool CheckShape() const override; + + bool InferShapeImpl() const override; + + bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; + + void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } + std::string DebugString() const override { return "matrix_nms"; } + + private: + mutable MatrixNmsParam param_; +}; + +} // namespace operators +} // namespace lite +} // namespace paddle diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 98b08a6b0d..691846d0ba 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -880,6 +880,22 @@ struct MulticlassNmsParam : ParamBase { bool normalized{true}; }; +/// ----------------------- matrix_nms operators ---------------------- +struct MatrixNmsParam : ParamBase { + const lite::Tensor* bboxes{}; + const lite::Tensor* scores{}; + lite::Tensor* out{}; + lite::Tensor* index{}; + int background_label{0}; + float score_threshold{}; + float post_threshold{0.0f}; + int nms_top_k{}; + int keep_top_k; + bool normalized{true}; + bool use_gaussian{false}; + float gaussian_sigma{2.0f}; +}; + /// ----------------------- priorbox operators ---------------------- struct PriorBoxParam : ParamBase { lite::Tensor* input{}; diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt index c2f0c2ba91..a2546c55ed 100644 --- a/lite/tests/kernels/CMakeLists.txt +++ b/lite/tests/kernels/CMakeLists.txt @@ -34,6 +34,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM AND NOT LIT lite_cc_test(test_kernel_softmax_compute SRCS softmax_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_mul_compute SRCS mul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_multiclass_nms_compute SRCS multiclass_nms_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) + lite_cc_test(test_kernel_matrix_nms_compute SRCS matrix_nms_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_batch_norm_compute SRCS batch_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_pool_compute SRCS pool_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_fill_constant_compute SRCS fill_constant_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${huawei_ascend_npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) diff --git a/lite/tests/kernels/matrix_nms_compute_test.cc b/lite/tests/kernels/matrix_nms_compute_test.cc new file mode 100644 index 0000000000..0731a49c1f --- /dev/null +++ b/lite/tests/kernels/matrix_nms_compute_test.cc @@ -0,0 +1,420 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "lite/api/paddle_use_kernels.h" +#include "lite/api/paddle_use_ops.h" +#include "lite/core/arena/framework.h" +#include "lite/tests/utils/fill_data.h" + +namespace paddle { +namespace lite { + +template +static T BBoxArea(const T* box, const bool normalized) { + if (box[2] < box[0] || box[3] < box[1]) { + // If coordinate values are is invalid + // (e.g. xmax < xmin or ymax < ymin), return 0. + return static_cast(0.); + } else { + const T w = box[2] - box[0]; + const T h = box[3] - box[1]; + if (normalized) { + return w * h; + } else { + // If coordinate values are not within range [0, 1]. + return (w + 1) * (h + 1); + } + } +} + +template +static T JaccardOverlap(const T* box1, const T* box2, const bool normalized) { + if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] || + box2[3] < box1[1]) { + return static_cast(0.); + } else { + const T inter_xmin = std::max(box1[0], box2[0]); + const T inter_ymin = std::max(box1[1], box2[1]); + const T inter_xmax = std::min(box1[2], box2[2]); + const T inter_ymax = std::min(box1[3], box2[3]); + T norm = normalized ? static_cast(0.) : static_cast(1.); + T inter_w = inter_xmax - inter_xmin + norm; + T inter_h = inter_ymax - inter_ymin + norm; + const T inter_area = inter_w * inter_h; + const T bbox1_area = BBoxArea(box1, normalized); + const T bbox2_area = BBoxArea(box2, normalized); + return inter_area / (bbox1_area + bbox2_area - inter_area); + } +} + +template +T PolyIoU(const T* box1, + const T* box2, + const size_t box_size, + const bool normalized) { + LOG(FATAL) << "PolyIoU not implement."; + return *box1; +} + +template +struct decay_score; + +template +struct decay_score { + T operator()(T iou, T max_iou, T sigma) { + return std::exp((max_iou * max_iou - iou * iou) * sigma); + } +}; + +template +struct decay_score { + T operator()(T iou, T max_iou, T sigma) { + return (1. - iou) / (1. - max_iou); + } +}; + +template +void NMSMatrix(const Tensor& bbox, + const Tensor& scores, + const T score_threshold, + const T post_threshold, + const float sigma, + const int64_t top_k, + const bool normalized, + std::vector* selected_indices, + std::vector* decayed_scores) { + int64_t num_boxes = bbox.dims()[0]; + int64_t box_size = bbox.dims()[1]; + + auto score_ptr = scores.data(); + auto bbox_ptr = bbox.data(); + + std::vector perm(num_boxes); + std::iota(perm.begin(), perm.end(), 0); + auto end = std::remove_if( + perm.begin(), perm.end(), [&score_ptr, score_threshold](int32_t idx) { + return score_ptr[idx] <= score_threshold; + }); + auto sort_fn = [&score_ptr](int32_t lhs, int32_t rhs) { + return score_ptr[lhs] > score_ptr[rhs]; + }; + + int64_t num_pre = std::distance(perm.begin(), end); + if (num_pre <= 0) { + return; + } + if (top_k > -1 && num_pre > top_k) { + num_pre = top_k; + } + std::partial_sort(perm.begin(), perm.begin() + num_pre, end, sort_fn); + + std::vector iou_matrix((num_pre * (num_pre - 1)) >> 1); + std::vector iou_max(num_pre); + + iou_max[0] = 0.; + for (int64_t i = 1; i < num_pre; i++) { + T max_iou = 0.; + auto idx_a = perm[i]; + for (int64_t j = 0; j < i; j++) { + auto idx_b = perm[j]; + auto iou = JaccardOverlap( + bbox_ptr + idx_a * box_size, bbox_ptr + idx_b * box_size, normalized); + max_iou = std::max(max_iou, iou); + iou_matrix[i * (i - 1) / 2 + j] = iou; + } + iou_max[i] = max_iou; + } + + if (score_ptr[perm[0]] > post_threshold) { + selected_indices->push_back(perm[0]); + decayed_scores->push_back(score_ptr[perm[0]]); + } + + decay_score decay_fn; + for (int64_t i = 1; i < num_pre; i++) { + T min_decay = 1.; + for (int64_t j = 0; j < i; j++) { + auto max_iou = iou_max[j]; + auto iou = iou_matrix[i * (i - 1) / 2 + j]; + auto decay = decay_fn(iou, max_iou, sigma); + min_decay = std::min(min_decay, decay); + } + auto ds = min_decay * score_ptr[perm[i]]; + if (ds <= post_threshold) continue; + selected_indices->push_back(perm[i]); + decayed_scores->push_back(ds); + } +} + +template +size_t MultiClassMatrixNMS(const Tensor& scores, + const Tensor& bboxes, + std::vector* out, + std::vector* indices, + int start, + int64_t background_label, + int64_t nms_top_k, + int64_t keep_top_k, + bool normalized, + T score_threshold, + T post_threshold, + bool use_gaussian, + float gaussian_sigma) { + std::vector all_indices; + std::vector all_scores; + std::vector all_classes; + all_indices.reserve(scores.numel()); + all_scores.reserve(scores.numel()); + all_classes.reserve(scores.numel()); + + size_t num_det = 0; + auto class_num = scores.dims()[0]; + Tensor score_slice; + for (int64_t c = 0; c < class_num; ++c) { + if (c == background_label) continue; + score_slice = scores.Slice(c, c + 1); + if (use_gaussian) { + NMSMatrix(bboxes, + score_slice, + score_threshold, + post_threshold, + gaussian_sigma, + nms_top_k, + normalized, + &all_indices, + &all_scores); + } else { + NMSMatrix(bboxes, + score_slice, + score_threshold, + post_threshold, + gaussian_sigma, + nms_top_k, + normalized, + &all_indices, + &all_scores); + } + for (size_t i = 0; i < all_indices.size() - num_det; i++) { + all_classes.push_back(static_cast(c)); + } + num_det = all_indices.size(); + } + + if (num_det <= 0) { + return num_det; + } + + if (keep_top_k > -1) { + auto k = static_cast(keep_top_k); + if (num_det > k) num_det = k; + } + + std::vector perm(all_indices.size()); + std::iota(perm.begin(), perm.end(), 0); + + std::partial_sort(perm.begin(), + perm.begin() + num_det, + perm.end(), + [&all_scores](int lhs, int rhs) { + return all_scores[lhs] > all_scores[rhs]; + }); + + for (size_t i = 0; i < num_det; i++) { + auto p = perm[i]; + auto idx = all_indices[p]; + auto cls = all_classes[p]; + auto score = all_scores[p]; + auto bbox = bboxes.data() + idx * bboxes.dims()[1]; + (*indices).push_back(start + idx); + (*out).push_back(cls); + (*out).push_back(score); + for (int j = 0; j < bboxes.dims()[1]; j++) { + (*out).push_back(bbox[j]); + } + } + + return num_det; +} + +class MatrixNmsComputeTester : public arena::TestCase { + protected: + // common attributes for this op. + std::string type_ = "matrix_nms"; + std::string bboxes_ = "bboxes"; + std::string scores_ = "scores"; + std::string out_ = "out"; + std::string index_ = "index"; + DDim bboxes_dims_{}; + DDim scores_dims_{}; + int background_label_{0}; + float score_threshold_{0.01f}; + float post_threshold_{0.0f}; + int nms_top_k_{1}; + int keep_top_k_{2}; + bool normalized_{false}; + bool use_gaussian_{true}; + float gaussian_sigma_{2.0f}; + + public: + MatrixNmsComputeTester(const Place& place, + const std::string& alias, + DDim bboxes_dims, + DDim scores_dims, + int background_label = 0, + float score_threshold = 0.01f, + float post_threshold = 0.0f, + int nms_top_k = 1, + int keep_top_k = 2, + bool normalized = false, + bool use_gaussian = true, + float gaussian_sigma = 2.0f) + : TestCase(place, alias), + bboxes_dims_(bboxes_dims), + scores_dims_(scores_dims), + background_label_(background_label), + score_threshold_(score_threshold), + post_threshold_(post_threshold), + nms_top_k_(nms_top_k), + keep_top_k_(keep_top_k), + normalized_(normalized), + use_gaussian_(use_gaussian), + gaussian_sigma_(gaussian_sigma) {} + + void RunBaseline(Scope* scope) override { + auto* boxes = scope->FindTensor(bboxes_); + auto* scores = scope->FindTensor(scores_); + auto* outs = scope->NewTensor(out_); + auto* index = scope->NewTensor(index_); + + CHECK(outs); + outs->set_precision(PRECISION(kFloat)); + CHECK(index); + index->set_precision(PRECISION(kInt32)); + + auto score_dims = scores->dims(); + auto batch_size = score_dims[0]; + auto num_boxes = score_dims[2]; + auto box_dim = boxes->dims()[2]; + auto out_dim = box_dim + 2; + + Tensor boxes_slice, scores_slice; + size_t num_out = 0; + std::vector offsets = {0}; + std::vector detections; + std::vector indices; + detections.reserve(out_dim * num_boxes * batch_size); + indices.reserve(num_boxes * batch_size); + for (int i = 0; i < batch_size; ++i) { + scores_slice = scores->Slice(i, i + 1); + scores_slice.Resize({score_dims[1], score_dims[2]}); + boxes_slice = boxes->Slice(i, i + 1); + boxes_slice.Resize({score_dims[2], box_dim}); + int start = i * score_dims[2]; + num_out = MultiClassMatrixNMS(scores_slice, + boxes_slice, + &detections, + &indices, + start, + background_label_, + nms_top_k_, + keep_top_k_, + normalized_, + score_threshold_, + post_threshold_, + use_gaussian_, + gaussian_sigma_); + offsets.push_back(offsets.back() + num_out); + } + + size_t num_kept = offsets.back(); + if (num_kept == 0) { + outs->Resize({0, out_dim}); + index->Resize({0, 1}); + } else { + outs->Resize({static_cast(num_kept), out_dim}); + index->Resize({static_cast(num_kept), 1}); + std::copy( + detections.begin(), detections.end(), outs->mutable_data()); + std::copy(indices.begin(), indices.end(), index->mutable_data()); + } + + LoD lod; + lod.emplace_back(offsets); + outs->set_lod(lod); + index->set_lod(lod); + } + + void PrepareOpDesc(cpp::OpDesc* op_desc) { + op_desc->SetType(type_); + op_desc->SetInput("BBoxes", {bboxes_}); + op_desc->SetInput("Scores", {scores_}); + op_desc->SetOutput("Out", {out_}); + op_desc->SetOutput("Index", {index_}); + op_desc->SetAttr("background_label", background_label_); + op_desc->SetAttr("score_threshold", score_threshold_); + op_desc->SetAttr("post_threshold", post_threshold_); + op_desc->SetAttr("nms_top_k", nms_top_k_); + op_desc->SetAttr("keep_top_k", keep_top_k_); + op_desc->SetAttr("normalized", normalized_); + op_desc->SetAttr("use_gaussian", use_gaussian_); + op_desc->SetAttr("gaussian_sigma", gaussian_sigma_); + } + + void PrepareData() override { + std::vector bboxes(bboxes_dims_.production()); + for (int i = 0; i < bboxes_dims_.production(); ++i) { + bboxes[i] = i * 1. / bboxes_dims_.production(); + } + SetCommonTensor(bboxes_, bboxes_dims_, bboxes.data()); + + std::vector scores(scores_dims_.production()); + for (int i = 0; i < scores_dims_.production(); ++i) { + scores[i] = i * 1. / scores_dims_.production(); + } + SetCommonTensor(scores_, scores_dims_, scores.data()); + } +}; + +void TestMatrixNms(Place place, float abs_error) { + int N = 3; + int M = 2500; + for (int class_num : {2, 4, 10}) { + std::vector bbox_shape{N, M, 4}; + std::vector score_shape{N, class_num, M}; + std::unique_ptr tester(new MatrixNmsComputeTester( + place, "def", DDim(bbox_shape), DDim(score_shape))); + arena::Arena arena(std::move(tester), place, abs_error); + arena.TestPrecision(); + } +} + +TEST(matrix_nms, precision) { + float abs_error = 2e-5; + Place place; +#if defined(LITE_WITH_ARM) + place = TARGET(kHost); +#elif defined(LITE_WITH_XPU) && defined(LITE_WITH_XTCL) + place = TARGET(kXPU); +#else + return; +#endif + + TestMatrixNms(place, abs_error); +} + +} // namespace lite +} // namespace paddle -- GitLab