From 02414aac52fe2e4b25b64e1093883f010fb4fda6 Mon Sep 17 00:00:00 2001 From: zhiboniu <31800336+zhiboniu@users.noreply.github.com> Date: Wed, 3 Aug 2022 10:46:18 +0800 Subject: [PATCH] Phi matrixnums (#44437) * phi_matrix_nms * remove old kernels and add optest check_eager * reoder args * reoder args in infermate * update * get back legacy dygraph --- .../operators/detection/matrix_nms_op.cc | 322 +----------------- paddle/phi/api/yaml/legacy_api.yaml | 8 + paddle/phi/infermeta/binary.cc | 58 ++++ paddle/phi/infermeta/binary.h | 15 + paddle/phi/kernels/cpu/matrix_nms_kernel.cc | 321 +++++++++++++++++ paddle/phi/kernels/matrix_nms_kernel.h | 37 ++ paddle/phi/ops/compat/matrix_nms_sig.cc | 35 ++ python/paddle/fluid/layers/detection.py | 16 +- .../tests/unittests/test_matrix_nms_op.py | 90 +++-- python/paddle/vision/ops.py | 10 + 10 files changed, 564 insertions(+), 348 deletions(-) create mode 100644 paddle/phi/kernels/cpu/matrix_nms_kernel.cc create mode 100644 paddle/phi/kernels/matrix_nms_kernel.h create mode 100644 paddle/phi/ops/compat/matrix_nms_sig.cc diff --git a/paddle/fluid/operators/detection/matrix_nms_op.cc b/paddle/fluid/operators/detection/matrix_nms_op.cc index 7a11c8ed2e..feacea63e3 100644 --- a/paddle/fluid/operators/detection/matrix_nms_op.cc +++ b/paddle/fluid/operators/detection/matrix_nms_op.cc @@ -11,9 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. limitations under the License. */ +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/operators/detection/nms_util.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { @@ -25,55 +27,6 @@ class MatrixNMSOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("BBoxes"), "Input", "BBoxes", "MatrixNMS"); - OP_INOUT_CHECK(ctx->HasInput("Scores"), "Input", "Scores", "MatrixNMS"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "MatrixNMS"); - auto box_dims = ctx->GetInputDim("BBoxes"); - auto score_dims = ctx->GetInputDim("Scores"); - auto score_size = score_dims.size(); - - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ(score_size == 3, - true, - platform::errors::InvalidArgument( - "The rank of Input(Scores) must be 3. " - "But received rank = %d.", - score_size)); - PADDLE_ENFORCE_EQ(box_dims.size(), - 3, - platform::errors::InvalidArgument( - "The rank of Input(BBoxes) must be 3." - "But received rank = %d.", - box_dims.size())); - PADDLE_ENFORCE_EQ(box_dims[2] == 4, - true, - platform::errors::InvalidArgument( - "The last dimension of Input (BBoxes) must be 4, " - "represents the layout of coordinate " - "[xmin, ymin, xmax, ymax].")); - PADDLE_ENFORCE_EQ( - box_dims[1], - score_dims[2], - platform::errors::InvalidArgument( - "The 2nd dimension of Input(BBoxes) must be equal to " - "last dimension of Input(Scores), which represents the " - "predicted bboxes." - "But received box_dims[1](%s) != socre_dims[2](%s)", - box_dims[1], - score_dims[2])); - } - ctx->SetOutputDim("Out", {box_dims[1], box_dims[2] + 2}); - ctx->SetOutputDim("Index", {box_dims[1], 1}); - if (ctx->HasOutput("RoisNum")) { - ctx->SetOutputDim("RoisNum", {-1}); - } - if (!ctx->IsRuntime()) { - ctx->SetLoDLevel("Out", std::max(ctx->GetLoDLevel("BBoxes"), 1)); - ctx->SetLoDLevel("Index", std::max(ctx->GetLoDLevel("BBoxes"), 1)); - } - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -83,266 +36,6 @@ class MatrixNMSOp : public framework::OperatorWithKernel { } }; -template -struct decay_score; - -template -struct decay_score { - T operator()(T iou, T max_iou, T sigma) { - return std::exp((max_iou * max_iou - iou * iou) * sigma); - } -}; - -template -struct decay_score { - T operator()(T iou, T max_iou, T sigma) { - return (1. - iou) / (1. - max_iou); - } -}; - -template -void NMSMatrix(const Tensor& bbox, - const Tensor& scores, - const T score_threshold, - const T post_threshold, - const float sigma, - const int64_t top_k, - const bool normalized, - std::vector* selected_indices, - std::vector* decayed_scores) { - int64_t num_boxes = bbox.dims()[0]; - int64_t box_size = bbox.dims()[1]; - - auto score_ptr = scores.data(); - auto bbox_ptr = bbox.data(); - - std::vector perm(num_boxes); - std::iota(perm.begin(), perm.end(), 0); - auto end = std::remove_if( - perm.begin(), perm.end(), [&score_ptr, score_threshold](int32_t idx) { - return score_ptr[idx] <= score_threshold; - }); - - auto sort_fn = [&score_ptr](int32_t lhs, int32_t rhs) { - return score_ptr[lhs] > score_ptr[rhs]; - }; - - int64_t num_pre = std::distance(perm.begin(), end); - if (num_pre <= 0) { - return; - } - if (top_k > -1 && num_pre > top_k) { - num_pre = top_k; - } - std::partial_sort(perm.begin(), perm.begin() + num_pre, end, sort_fn); - - std::vector iou_matrix((num_pre * (num_pre - 1)) >> 1); - std::vector iou_max(num_pre); - - iou_max[0] = 0.; - for (int64_t i = 1; i < num_pre; i++) { - T max_iou = 0.; - auto idx_a = perm[i]; - for (int64_t j = 0; j < i; j++) { - auto idx_b = perm[j]; - auto iou = JaccardOverlap( - bbox_ptr + idx_a * box_size, bbox_ptr + idx_b * box_size, normalized); - max_iou = std::max(max_iou, iou); - iou_matrix[i * (i - 1) / 2 + j] = iou; - } - iou_max[i] = max_iou; - } - - if (score_ptr[perm[0]] > post_threshold) { - selected_indices->push_back(perm[0]); - decayed_scores->push_back(score_ptr[perm[0]]); - } - - decay_score decay_fn; - for (int64_t i = 1; i < num_pre; i++) { - T min_decay = 1.; - for (int64_t j = 0; j < i; j++) { - auto max_iou = iou_max[j]; - auto iou = iou_matrix[i * (i - 1) / 2 + j]; - auto decay = decay_fn(iou, max_iou, sigma); - min_decay = std::min(min_decay, decay); - } - auto ds = min_decay * score_ptr[perm[i]]; - if (ds <= post_threshold) continue; - selected_indices->push_back(perm[i]); - decayed_scores->push_back(ds); - } -} - -template -class MatrixNMSKernel : public framework::OpKernel { - public: - size_t MultiClassMatrixNMS(const Tensor& scores, - const Tensor& bboxes, - std::vector* out, - std::vector* indices, - int start, - int64_t background_label, - int64_t nms_top_k, - int64_t keep_top_k, - bool normalized, - T score_threshold, - T post_threshold, - bool use_gaussian, - float gaussian_sigma) const { - std::vector all_indices; - std::vector all_scores; - std::vector all_classes; - all_indices.reserve(scores.numel()); - all_scores.reserve(scores.numel()); - all_classes.reserve(scores.numel()); - - size_t num_det = 0; - auto class_num = scores.dims()[0]; - Tensor score_slice; - for (int64_t c = 0; c < class_num; ++c) { - if (c == background_label) continue; - score_slice = scores.Slice(c, c + 1); - if (use_gaussian) { - NMSMatrix(bboxes, - score_slice, - score_threshold, - post_threshold, - gaussian_sigma, - nms_top_k, - normalized, - &all_indices, - &all_scores); - } else { - NMSMatrix(bboxes, - score_slice, - score_threshold, - post_threshold, - gaussian_sigma, - nms_top_k, - normalized, - &all_indices, - &all_scores); - } - for (size_t i = 0; i < all_indices.size() - num_det; i++) { - all_classes.push_back(static_cast(c)); - } - num_det = all_indices.size(); - } - - if (num_det <= 0) { - return num_det; - } - - if (keep_top_k > -1) { - auto k = static_cast(keep_top_k); - if (num_det > k) num_det = k; - } - - std::vector perm(all_indices.size()); - std::iota(perm.begin(), perm.end(), 0); - - std::partial_sort(perm.begin(), - perm.begin() + num_det, - perm.end(), - [&all_scores](int lhs, int rhs) { - return all_scores[lhs] > all_scores[rhs]; - }); - - for (size_t i = 0; i < num_det; i++) { - auto p = perm[i]; - auto idx = all_indices[p]; - auto cls = all_classes[p]; - auto score = all_scores[p]; - auto bbox = bboxes.data() + idx * bboxes.dims()[1]; - (*indices).push_back(start + idx); - (*out).push_back(cls); - (*out).push_back(score); - for (int j = 0; j < bboxes.dims()[1]; j++) { - (*out).push_back(bbox[j]); - } - } - - return num_det; - } - - void Compute(const framework::ExecutionContext& ctx) const override { - auto* boxes = ctx.Input("BBoxes"); - auto* scores = ctx.Input("Scores"); - auto* outs = ctx.Output("Out"); - auto* index = ctx.Output("Index"); - - auto background_label = ctx.Attr("background_label"); - auto nms_top_k = ctx.Attr("nms_top_k"); - auto keep_top_k = ctx.Attr("keep_top_k"); - auto normalized = ctx.Attr("normalized"); - auto score_threshold = ctx.Attr("score_threshold"); - auto post_threshold = ctx.Attr("post_threshold"); - auto use_gaussian = ctx.Attr("use_gaussian"); - auto gaussian_sigma = ctx.Attr("gaussian_sigma"); - - auto score_dims = scores->dims(); - auto batch_size = score_dims[0]; - auto num_boxes = score_dims[2]; - auto box_dim = boxes->dims()[2]; - auto out_dim = box_dim + 2; - - Tensor boxes_slice, scores_slice; - size_t num_out = 0; - std::vector offsets = {0}; - std::vector detections; - std::vector indices; - std::vector num_per_batch; - detections.reserve(out_dim * num_boxes * batch_size); - indices.reserve(num_boxes * batch_size); - num_per_batch.reserve(batch_size); - for (int i = 0; i < batch_size; ++i) { - scores_slice = scores->Slice(i, i + 1); - scores_slice.Resize({score_dims[1], score_dims[2]}); - boxes_slice = boxes->Slice(i, i + 1); - boxes_slice.Resize({score_dims[2], box_dim}); - int start = i * score_dims[2]; - num_out = MultiClassMatrixNMS(scores_slice, - boxes_slice, - &detections, - &indices, - start, - background_label, - nms_top_k, - keep_top_k, - normalized, - score_threshold, - post_threshold, - use_gaussian, - gaussian_sigma); - offsets.push_back(offsets.back() + num_out); - num_per_batch.emplace_back(num_out); - } - - int64_t num_kept = offsets.back(); - if (num_kept == 0) { - outs->mutable_data({0, out_dim}, ctx.GetPlace()); - index->mutable_data({0, 1}, ctx.GetPlace()); - } else { - outs->mutable_data({num_kept, out_dim}, ctx.GetPlace()); - index->mutable_data({num_kept, 1}, ctx.GetPlace()); - std::copy(detections.begin(), detections.end(), outs->data()); - std::copy(indices.begin(), indices.end(), index->data()); - } - - if (ctx.HasOutput("RoisNum")) { - auto* rois_num = ctx.Output("RoisNum"); - rois_num->mutable_data({batch_size}, ctx.GetPlace()); - std::copy( - num_per_batch.begin(), num_per_batch.end(), rois_num->data()); - } - framework::LoD lod; - lod.emplace_back(offsets); - outs->set_lod(lod); - index->set_lod(lod); - } -}; - class MatrixNMSOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -435,16 +128,19 @@ https://arxiv.org/abs/2003.10152 } // namespace operators } // namespace paddle +DECLARE_INFER_SHAPE_FUNCTOR(matrix_nms, + MatrixNMSInferShapeFunctor, + PD_INFER_META(phi::MatrixNMSInferMeta)); + namespace ops = paddle::operators; REGISTER_OPERATOR( matrix_nms, ops::MatrixNMSOp, ops::MatrixNMSOpMaker, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); -REGISTER_OP_CPU_KERNEL(matrix_nms, - ops::MatrixNMSKernel, - ops::MatrixNMSKernel); + paddle::framework::EmptyGradOpMaker, + MatrixNMSInferShapeFunctor); + REGISTER_OP_VERSION(matrix_nms) .AddCheckpoint(R"ROC(Upgrade matrix_nms: add a new output [RoisNum].)ROC", paddle::framework::compatible::OpVersionDesc().NewOutput( diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 3bd11fa8cd..e38f77613f 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -1501,6 +1501,14 @@ func : matmul backward : matmul_grad +- api : matrix_nms + args : (Tensor bboxes, Tensor scores, float score_threshold, int nms_top_k, int keep_top_k, float post_threshold=0., bool use_gaussian = false, float gaussian_sigma = 2.0, int background_label = 0, bool normalized = true) + output : Tensor(out), Tensor(index), Tensor(roisnum) + infer_meta : + func : MatrixNMSInferMeta + kernel : + func : matrix_nms + # matrix_power - api : matrix_power args : (Tensor x, int n) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index ebcc6e28b4..566a2a953d 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1687,6 +1687,64 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x, out->share_lod(x); } +void MatrixNMSInferMeta(const MetaTensor& bboxes, + const MetaTensor& scores, + float score_threshold, + int nms_top_k, + int keep_top_k, + float post_threshold, + bool use_gaussian, + float gaussian_sigma, + int background_label, + bool normalized, + MetaTensor* out, + MetaTensor* index, + MetaTensor* roisnum, + MetaConfig config) { + auto box_dims = bboxes.dims(); + auto score_dims = scores.dims(); + auto score_size = score_dims.size(); + + if (config.is_runtime) { + PADDLE_ENFORCE_EQ( + score_size == 3, + true, + errors::InvalidArgument("The rank of Input(Scores) must be 3. " + "But received rank = %d.", + score_size)); + PADDLE_ENFORCE_EQ( + box_dims.size(), + 3, + errors::InvalidArgument("The rank of Input(BBoxes) must be 3." + "But received rank = %d.", + box_dims.size())); + PADDLE_ENFORCE_EQ(box_dims[2] == 4, + true, + errors::InvalidArgument( + "The last dimension of Input (BBoxes) must be 4, " + "represents the layout of coordinate " + "[xmin, ymin, xmax, ymax].")); + PADDLE_ENFORCE_EQ( + box_dims[1], + score_dims[2], + errors::InvalidArgument( + "The 2nd dimension of Input(BBoxes) must be equal to " + "last dimension of Input(Scores), which represents the " + "predicted bboxes." + "But received box_dims[1](%s) != socre_dims[2](%s)", + box_dims[1], + score_dims[2])); + } + out->set_dims({box_dims[1], box_dims[2] + 2}); + out->set_dtype(bboxes.dtype()); + index->set_dims({box_dims[1], 1}); + index->set_dtype(phi::DataType::INT32); + if (roisnum != nullptr) { + roisnum->set_dims({-1}); + roisnum->set_dtype(phi::DataType::INT32); + } +} + void MatrixRankTolInferMeta(const MetaTensor& x, const MetaTensor& atol_tensor, bool use_default_tol, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 3662a7c316..3b03ce01a7 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -249,6 +249,21 @@ void MatmulWithFlattenInferMeta(const MetaTensor& x, int y_num_col_dims, MetaTensor* out); +void MatrixNMSInferMeta(const MetaTensor& bboxes, + const MetaTensor& scores, + float score_threshold, + int nms_top_k, + int keep_top_k, + float post_threshold, + bool use_gaussian, + float gaussian_sigma, + int background_label, + bool normalized, + MetaTensor* out, + MetaTensor* index, + MetaTensor* roisnum, + MetaConfig config = MetaConfig()); + void MatrixRankTolInferMeta(const MetaTensor& x, const MetaTensor& atol_tensor, bool use_default_tol, diff --git a/paddle/phi/kernels/cpu/matrix_nms_kernel.cc b/paddle/phi/kernels/cpu/matrix_nms_kernel.cc new file mode 100644 index 0000000000..aa9f778d1e --- /dev/null +++ b/paddle/phi/kernels/cpu/matrix_nms_kernel.cc @@ -0,0 +1,321 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/matrix_nms_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +static inline T BBoxArea(const T* box, const bool normalized) { + if (box[2] < box[0] || box[3] < box[1]) { + // If coordinate values are is invalid + // (e.g. xmax < xmin or ymax < ymin), return 0. + return static_cast(0.); + } else { + const T w = box[2] - box[0]; + const T h = box[3] - box[1]; + if (normalized) { + return w * h; + } else { + // If coordinate values are not within range [0, 1]. + return (w + 1) * (h + 1); + } + } +} + +template +static inline T JaccardOverlap(const T* box1, + const T* box2, + const bool normalized) { + if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] || + box2[3] < box1[1]) { + return static_cast(0.); + } else { + const T inter_xmin = std::max(box1[0], box2[0]); + const T inter_ymin = std::max(box1[1], box2[1]); + const T inter_xmax = std::min(box1[2], box2[2]); + const T inter_ymax = std::min(box1[3], box2[3]); + T norm = normalized ? static_cast(0.) : static_cast(1.); + T inter_w = inter_xmax - inter_xmin + norm; + T inter_h = inter_ymax - inter_ymin + norm; + const T inter_area = inter_w * inter_h; + const T bbox1_area = BBoxArea(box1, normalized); + const T bbox2_area = BBoxArea(box2, normalized); + return inter_area / (bbox1_area + bbox2_area - inter_area); + } +} + +template +struct decay_score; + +template +struct decay_score { + T operator()(T iou, T max_iou, T sigma) { + return std::exp((max_iou * max_iou - iou * iou) * sigma); + } +}; + +template +struct decay_score { + T operator()(T iou, T max_iou, T sigma) { + return (1. - iou) / (1. - max_iou); + } +}; + +template +void NMSMatrix(const DenseTensor& bbox, + const DenseTensor& scores, + const T score_threshold, + const T post_threshold, + const float sigma, + const int64_t top_k, + const bool normalized, + std::vector* selected_indices, + std::vector* decayed_scores) { + int64_t num_boxes = bbox.dims()[0]; + int64_t box_size = bbox.dims()[1]; + + auto score_ptr = scores.data(); + auto bbox_ptr = bbox.data(); + + std::vector perm(num_boxes); + std::iota(perm.begin(), perm.end(), 0); + auto end = std::remove_if( + perm.begin(), perm.end(), [&score_ptr, score_threshold](int32_t idx) { + return score_ptr[idx] <= score_threshold; + }); + + auto sort_fn = [&score_ptr](int32_t lhs, int32_t rhs) { + return score_ptr[lhs] > score_ptr[rhs]; + }; + + int64_t num_pre = std::distance(perm.begin(), end); + if (num_pre <= 0) { + return; + } + if (top_k > -1 && num_pre > top_k) { + num_pre = top_k; + } + std::partial_sort(perm.begin(), perm.begin() + num_pre, end, sort_fn); + + std::vector iou_matrix((num_pre * (num_pre - 1)) >> 1); + std::vector iou_max(num_pre); + + iou_max[0] = 0.; + for (int64_t i = 1; i < num_pre; i++) { + T max_iou = 0.; + auto idx_a = perm[i]; + for (int64_t j = 0; j < i; j++) { + auto idx_b = perm[j]; + auto iou = JaccardOverlap( + bbox_ptr + idx_a * box_size, bbox_ptr + idx_b * box_size, normalized); + max_iou = std::max(max_iou, iou); + iou_matrix[i * (i - 1) / 2 + j] = iou; + } + iou_max[i] = max_iou; + } + + if (score_ptr[perm[0]] > post_threshold) { + selected_indices->push_back(perm[0]); + decayed_scores->push_back(score_ptr[perm[0]]); + } + + decay_score decay_fn; + for (int64_t i = 1; i < num_pre; i++) { + T min_decay = 1.; + for (int64_t j = 0; j < i; j++) { + auto max_iou = iou_max[j]; + auto iou = iou_matrix[i * (i - 1) / 2 + j]; + auto decay = decay_fn(iou, max_iou, sigma); + min_decay = std::min(min_decay, decay); + } + auto ds = min_decay * score_ptr[perm[i]]; + if (ds <= post_threshold) continue; + selected_indices->push_back(perm[i]); + decayed_scores->push_back(ds); + } +} + +template +size_t MultiClassMatrixNMS(const DenseTensor& scores, + const DenseTensor& bboxes, + std::vector* out, + std::vector* indices, + int start, + int64_t background_label, + int64_t nms_top_k, + int64_t keep_top_k, + bool normalized, + T score_threshold, + T post_threshold, + bool use_gaussian, + float gaussian_sigma) { + std::vector all_indices; + std::vector all_scores; + std::vector all_classes; + all_indices.reserve(scores.numel()); + all_scores.reserve(scores.numel()); + all_classes.reserve(scores.numel()); + + size_t num_det = 0; + auto class_num = scores.dims()[0]; + DenseTensor score_slice; + for (int64_t c = 0; c < class_num; ++c) { + if (c == background_label) continue; + score_slice = scores.Slice(c, c + 1); + if (use_gaussian) { + NMSMatrix(bboxes, + score_slice, + score_threshold, + post_threshold, + gaussian_sigma, + nms_top_k, + normalized, + &all_indices, + &all_scores); + } else { + NMSMatrix(bboxes, + score_slice, + score_threshold, + post_threshold, + gaussian_sigma, + nms_top_k, + normalized, + &all_indices, + &all_scores); + } + for (size_t i = 0; i < all_indices.size() - num_det; i++) { + all_classes.push_back(static_cast(c)); + } + num_det = all_indices.size(); + } + + if (num_det <= 0) { + return num_det; + } + + if (keep_top_k > -1) { + auto k = static_cast(keep_top_k); + if (num_det > k) num_det = k; + } + + std::vector perm(all_indices.size()); + std::iota(perm.begin(), perm.end(), 0); + + std::partial_sort(perm.begin(), + perm.begin() + num_det, + perm.end(), + [&all_scores](int lhs, int rhs) { + return all_scores[lhs] > all_scores[rhs]; + }); + + for (size_t i = 0; i < num_det; i++) { + auto p = perm[i]; + auto idx = all_indices[p]; + auto cls = all_classes[p]; + auto score = all_scores[p]; + auto bbox = bboxes.data() + idx * bboxes.dims()[1]; + (*indices).push_back(start + idx); + (*out).push_back(cls); + (*out).push_back(score); + for (int j = 0; j < bboxes.dims()[1]; j++) { + (*out).push_back(bbox[j]); + } + } + + return num_det; +} + +template +void MatrixNMSKernel(const Context& ctx, + const DenseTensor& bboxes, + const DenseTensor& scores, + float score_threshold, + int nms_top_k, + int keep_top_k, + float post_threshold, + bool use_gaussian, + float gaussian_sigma, + int background_label, + bool normalized, + DenseTensor* out, + DenseTensor* index, + DenseTensor* roisnum) { + auto score_dims = scores.dims(); + auto batch_size = score_dims[0]; + auto num_boxes = score_dims[2]; + auto box_dim = bboxes.dims()[2]; + auto out_dim = box_dim + 2; + + DenseTensor boxes_slice, scores_slice; + size_t num_out = 0; + std::vector offsets = {0}; + std::vector detections; + std::vector indices; + std::vector num_per_batch; + detections.reserve(out_dim * num_boxes * batch_size); + indices.reserve(num_boxes * batch_size); + num_per_batch.reserve(batch_size); + for (int i = 0; i < batch_size; ++i) { + scores_slice = scores.Slice(i, i + 1); + scores_slice.Resize({score_dims[1], score_dims[2]}); + boxes_slice = bboxes.Slice(i, i + 1); + boxes_slice.Resize({score_dims[2], box_dim}); + int start = i * score_dims[2]; + num_out = MultiClassMatrixNMS(scores_slice, + boxes_slice, + &detections, + &indices, + start, + background_label, + nms_top_k, + keep_top_k, + normalized, + static_cast(score_threshold), + static_cast(post_threshold), + use_gaussian, + gaussian_sigma); + offsets.push_back(offsets.back() + num_out); + num_per_batch.emplace_back(num_out); + } + + int64_t num_kept = offsets.back(); + if (num_kept == 0) { + out->Resize(phi::make_ddim({0, out_dim})); + ctx.template Alloc(out); + index->Resize(phi::make_ddim({0, 1})); + ctx.template Alloc(index); + } else { + out->Resize(phi::make_ddim({num_kept, out_dim})); + ctx.template Alloc(out); + index->Resize(phi::make_ddim({num_kept, 1})); + ctx.template Alloc(index); + std::copy(detections.begin(), detections.end(), out->data()); + std::copy(indices.begin(), indices.end(), index->data()); + } + + if (roisnum != nullptr) { + roisnum->Resize(phi::make_ddim({batch_size})); + ctx.template Alloc(roisnum); + std::copy(num_per_batch.begin(), num_per_batch.end(), roisnum->data()); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + matrix_nms, CPU, ALL_LAYOUT, phi::MatrixNMSKernel, float, double) {} diff --git a/paddle/phi/kernels/matrix_nms_kernel.h b/paddle/phi/kernels/matrix_nms_kernel.h new file mode 100644 index 0000000000..895829a712 --- /dev/null +++ b/paddle/phi/kernels/matrix_nms_kernel.h @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void MatrixNMSKernel(const Context& ctx, + const DenseTensor& bboxes, + const DenseTensor& scores, + float score_threshold, + int nms_top_k, + int keep_top_k, + float post_threshold, + bool use_gaussian, + float gaussian_sigma, + int background_label, + bool normalized, + DenseTensor* out, + DenseTensor* index, + DenseTensor* roisnum); + +} // namespace phi diff --git a/paddle/phi/ops/compat/matrix_nms_sig.cc b/paddle/phi/ops/compat/matrix_nms_sig.cc new file mode 100644 index 0000000000..27ebeca7df --- /dev/null +++ b/paddle/phi/ops/compat/matrix_nms_sig.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature MatrixNMSOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("matrix_nms", + {"BBoxes", "Scores"}, + {"score_threshold", + "nms_top_k", + "keep_top_k", + "post_threshold", + "use_gaussian", + "gaussian_sigma", + "background_label", + "normalized"}, + {"Out", "Index", "RoisNum"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(matrix_nms, phi::MatrixNMSOpArgumentMapping); diff --git a/python/paddle/fluid/layers/detection.py b/python/paddle/fluid/layers/detection.py index 3540f69c04..4c8d219b27 100644 --- a/python/paddle/fluid/layers/detection.py +++ b/python/paddle/fluid/layers/detection.py @@ -3642,6 +3642,16 @@ def matrix_nms(bboxes, keep_top_k=200, normalized=False) """ + if in_dygraph_mode(): + attrs = (score_threshold, nms_top_k, keep_top_k, post_threshold, + use_gaussian, gaussian_sigma, background_label, normalized) + + out, index = _C_ops.final_state_matrix_nms(bboxes, scores, *attrs) + if return_index: + return out, index + else: + return out + check_variable_and_dtype(bboxes, 'BBoxes', ['float32', 'float64'], 'matrix_nms') check_variable_and_dtype(scores, 'Scores', ['float32', 'float64'], @@ -3664,13 +3674,13 @@ def matrix_nms(bboxes, 'Scores': scores }, attrs={ - 'background_label': background_label, 'score_threshold': score_threshold, 'post_threshold': post_threshold, 'nms_top_k': nms_top_k, - 'gaussian_sigma': gaussian_sigma, - 'use_gaussian': use_gaussian, 'keep_top_k': keep_top_k, + 'use_gaussian': use_gaussian, + 'gaussian_sigma': gaussian_sigma, + 'background_label': background_label, 'normalized': normalized }, outputs={ diff --git a/python/paddle/fluid/tests/unittests/test_matrix_nms_op.py b/python/paddle/fluid/tests/unittests/test_matrix_nms_op.py index c85b715b0c..aa5d5a08bf 100644 --- a/python/paddle/fluid/tests/unittests/test_matrix_nms_op.py +++ b/python/paddle/fluid/tests/unittests/test_matrix_nms_op.py @@ -22,6 +22,29 @@ from paddle.fluid import Program, program_guard import paddle +def python_matrix_nms(bboxes, + scores, + score_threshold, + nms_top_k, + keep_top_k, + post_threshold, + use_gaussian=False, + gaussian_sigma=2., + background_label=0, + normalized=True, + return_index=True, + return_rois_num=True): + out, rois_num, index = paddle.vision.ops.matrix_nms( + bboxes, scores, score_threshold, post_threshold, nms_top_k, keep_top_k, + use_gaussian, gaussian_sigma, background_label, normalized, + return_index, return_rois_num) + if not return_index: + index = None + if not return_rois_num: + rois_num = None + return out, index, rois_num + + def softmax(x): # clip to shiftx, otherwise, when calc loss with # log(exp(shiftx)), may get log(0)=INF @@ -167,6 +190,7 @@ class TestMatrixNMSOp(OpTest): def setUp(self): self.set_argument() + self.python_api = python_matrix_nms N = 7 M = 1200 C = 21 @@ -203,23 +227,23 @@ class TestMatrixNMSOp(OpTest): self.op_type = 'matrix_nms' self.inputs = {'BBoxes': boxes, 'Scores': scores} self.outputs = { - 'Out': (nmsed_outs, [lod]), - 'Index': (index_outs[:, None], [lod]), + 'Out': nmsed_outs, + 'Index': index_outs[:, None], 'RoisNum': np.array(lod).astype('int32') } self.attrs = { - 'background_label': 0, + 'score_threshold': score_threshold, 'nms_top_k': nms_top_k, 'keep_top_k': keep_top_k, - 'score_threshold': score_threshold, 'post_threshold': post_threshold, 'use_gaussian': use_gaussian, 'gaussian_sigma': gaussian_sigma, + 'background_label': 0, 'normalized': True, } def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) class TestMatrixNMSOpNoOutput(TestMatrixNMSOp): @@ -265,50 +289,51 @@ class TestMatrixNMSError(unittest.TestCase): # the bboxes type must be Variable fluid.layers.matrix_nms(bboxes=boxes_np, scores=scores_data, - nms_top_k=nms_top_k, - keep_top_k=keep_top_k, score_threshold=score_threshold, - post_threshold=post_threshold) + post_threshold=post_threshold, + nms_top_k=nms_top_k, + keep_top_k=keep_top_k) paddle.vision.ops.matrix_nms(bboxes=boxes_np, scores=scores_data, - nms_top_k=nms_top_k, - keep_top_k=keep_top_k, score_threshold=score_threshold, - post_threshold=post_threshold) + post_threshold=post_threshold, + nms_top_k=nms_top_k, + keep_top_k=keep_top_k) def test_scores_Variable(): # the scores type must be Variable fluid.layers.matrix_nms(bboxes=boxes_data, scores=scores_np, - nms_top_k=nms_top_k, - keep_top_k=keep_top_k, score_threshold=score_threshold, - post_threshold=post_threshold) + post_threshold=post_threshold, + nms_top_k=nms_top_k, + keep_top_k=keep_top_k) paddle.vision.ops.matrix_nms(bboxes=boxes_data, scores=scores_np, - nms_top_k=nms_top_k, - keep_top_k=keep_top_k, score_threshold=score_threshold, - post_threshold=post_threshold) + post_threshold=post_threshold, + nms_top_k=nms_top_k, + keep_top_k=keep_top_k) def test_empty(): # when all score are lower than threshold try: fluid.layers.matrix_nms(bboxes=boxes_data, scores=scores_data, + score_threshold=score_threshold, + post_threshold=post_threshold, nms_top_k=nms_top_k, - keep_top_k=keep_top_k, - score_threshold=10., - post_threshold=post_threshold) + keep_top_k=keep_top_k) except Exception as e: self.fail(e) try: - paddle.vision.ops.matrix_nms(bboxes=boxes_data, - scores=scores_data, - nms_top_k=nms_top_k, - keep_top_k=keep_top_k, - score_threshold=10., - post_threshold=post_threshold) + paddle.vision.ops.matrix_nms( + bboxes=boxes_data, + scores=scores_data, + score_threshold=score_threshold, + post_threshold=post_threshold, + nms_top_k=nms_top_k, + keep_top_k=keep_top_k) except Exception as e: self.fail(e) @@ -317,20 +342,20 @@ class TestMatrixNMSError(unittest.TestCase): try: fluid.layers.matrix_nms(bboxes=boxes_data, scores=scores_data, - nms_top_k=nms_top_k, - keep_top_k=keep_top_k, score_threshold=score_threshold, - post_threshold=post_threshold) + post_threshold=post_threshold, + nms_top_k=nms_top_k, + keep_top_k=keep_top_k) except Exception as e: self.fail(e) try: paddle.vision.ops.matrix_nms( bboxes=boxes_data, scores=scores_data, - nms_top_k=nms_top_k, - keep_top_k=keep_top_k, score_threshold=score_threshold, - post_threshold=post_threshold) + post_threshold=post_threshold, + nms_top_k=nms_top_k, + keep_top_k=keep_top_k) except Exception as e: self.fail(e) @@ -340,4 +365,5 @@ class TestMatrixNMSError(unittest.TestCase): if __name__ == '__main__': + paddle.enable_static() unittest.main() diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index c197d428ff..20ea141785 100755 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -1891,6 +1891,16 @@ def matrix_nms(bboxes, check_type(background_label, 'background_label', int, 'matrix_nms') if in_dygraph_mode(): + out, index, rois_num = _C_ops.final_state_matrix_nms( + bboxes, scores, score_threshold, nms_top_k, keep_top_k, + post_threshold, use_gaussian, gaussian_sigma, background_label, + normalized) + if not return_index: + index = None + if not return_rois_num: + rois_num = None + return out, rois_num, index + elif _in_legacy_dygraph(): attrs = ('background_label', background_label, 'score_threshold', score_threshold, 'post_threshold', post_threshold, 'nms_top_k', nms_top_k, 'gaussian_sigma', gaussian_sigma, 'use_gaussian', -- GitLab