diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt
index 568c7982cfc7c07b9c7f840ccaa32e4025225122..f10c8019199993f4b1f8d880edc2185bf33082d0 100644
--- a/paddle/fluid/operators/detection/CMakeLists.txt
+++ b/paddle/fluid/operators/detection/CMakeLists.txt
@@ -66,6 +66,7 @@ detection_library(yolo_box_op SRCS yolo_box_op.cc)
 detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu)
 detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc sigmoid_focal_loss_op.cu)
 detection_library(retinanet_detection_output_op SRCS retinanet_detection_output_op.cc)
+detection_library(nms_op SRCS nms_op.cc nms_op.cu)
 
 if(WITH_GPU OR WITH_ROCM)
   set(TMPDEPS memory)
diff --git a/paddle/fluid/operators/detection/nms_op.cc b/paddle/fluid/operators/detection/nms_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f6dc44eb5fc2d969c4b4a379c9c4f95167613730
--- /dev/null
+++ b/paddle/fluid/operators/detection/nms_op.cc
@@ -0,0 +1,147 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/detection/nms_op.h"
+
+#include <vector>
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class NMSOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Boxes",
+             "(Tensor) "
+             "Boxes is a Tensor with shape [N, 4], "
+             "where N is the number of boxes. "
+             "The last dimension is in the format [x1, y1, x2, y2], "
+             "and the coordinates should satisfy "
+             "``0 <= x1 < x2 && 0 <= y1 < y2``.");
+
+    AddOutput("KeepBoxesIdxs",
+              "(Tensor) "
+              "KeepBoxesIdxs is a Tensor with shape [N].");
+    AddAttr<float>(
+        "iou_threshold",
+        "The IoU threshold used to suppress similar boxes: boxes with "
+        "IoU > iou_threshold are considered overlapping, and only one of "
+        "them is kept.")
+        .SetDefault(1.0f)
+        .AddCustomChecker([](const float& iou_threshold) {
+          PADDLE_ENFORCE_LE(iou_threshold, 1.0f,
+                            platform::errors::InvalidArgument(
+                                "iou_threshold should not be greater than "
+                                "1.0, but got %f",
+                                iou_threshold));
+          PADDLE_ENFORCE_GE(iou_threshold, 0.0f,
+                            platform::errors::InvalidArgument(
+                                "iou_threshold should not be less than 0.0, "
+                                "but got %f",
+                                iou_threshold));
+        });
+    AddComment(R"DOC(
+                NMS Operator.
+                This Operator performs Non-Maximum Suppression (NMS) on the input boxes.
+                Indices of the boxes kept by NMS will be sorted by score and output.
+            )DOC");
+  }
+};
+
+class NMSOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Boxes"), "Input", "Boxes", "NMS");
+    OP_INOUT_CHECK(ctx->HasOutput("KeepBoxesIdxs"), "Output", "KeepBoxesIdxs",
+                   "NMS");
+
+    auto boxes_dim = ctx->GetInputDim("Boxes");
+    PADDLE_ENFORCE_EQ(boxes_dim.size(), 2,
+                      platform::errors::InvalidArgument(
+                          "The Input Boxes must be a 2-D Tensor "
+                          "with shape [N, 4], where N is the number of boxes "
+                          "and the last dimension is in the format "
+                          "[x1, y1, x2, y2]."));
+    auto num_boxes = boxes_dim[0];
+
+    ctx->SetOutputDim("KeepBoxesIdxs", {num_boxes});
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Boxes"), ctx.GetPlace());
+  }
+};
+
+template <typename T>
+static void NMS(const T* boxes_data, int64_t* output_data, float threshold,
+                int64_t num_boxes) {
+  auto num_masks = CeilDivide(num_boxes, 64);
+  std::vector<uint64_t> masks(num_masks, 0);
+
+  for (int64_t i = 0; i < num_boxes; ++i) {
+    if (masks[i / 64] & 1ULL << (i % 64)) continue;
+    T box_1[4];
+    for (int k = 0; k < 4; ++k) {
+      box_1[k] = boxes_data[i * 4 + k];
+    }
+    for (int64_t j = i + 1; j < num_boxes; ++j) {
+      if (masks[j / 64] & 1ULL << (j % 64)) continue;
+      T box_2[4];
+      for (int k = 0; k < 4; ++k) {
+        box_2[k] = boxes_data[j * 4 + k];
+      }
+      bool is_overlap = CalculateIoU<T>(box_1, box_2, threshold);
+      if (is_overlap) {
+        masks[j / 64] |= 1ULL << (j % 64);
+      }
+    }
+  }
+
+  int64_t output_data_idx = 0;
+  for (int64_t i = 0; i < num_boxes; ++i) {
+    if (masks[i / 64] & 1ULL << (i % 64)) continue;
+    output_data[output_data_idx++] = i;
+  }
+
+  for (; output_data_idx < num_boxes; ++output_data_idx) {
+    output_data[output_data_idx] = 0;
+  }
+}
+
+template <typename T>
+class NMSKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* boxes = context.Input<Tensor>("Boxes");
+    Tensor* output = context.Output<Tensor>("KeepBoxesIdxs");
+    int64_t* output_data = output->mutable_data<int64_t>(context.GetPlace());
+    auto threshold = context.template Attr<float>("iou_threshold");
+    NMS<T>(boxes->data<T>(), output_data, threshold, boxes->dims()[0]);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(
+    nms, ops::NMSOp, ops::NMSOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OP_CPU_KERNEL(nms, ops::NMSKernel<float>, ops::NMSKernel<double>);
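The CPU kernel above is a plain O(N^2) greedy pass over boxes that are assumed to be pre-sorted by score. A minimal NumPy sketch of the same loop (the helper names `iou_xyxy` and `greedy_nms` are illustrative, not part of this patch) shows the expected shape of the zero-padded `KeepBoxesIdxs` output:

```python
import numpy as np


def iou_xyxy(a, b):
    """IoU of two [x1, y1, x2, y2] boxes; assumes x1 < x2 and y1 < y2."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(ix2 - ix1, 0.0) * max(iy2 - iy1, 0.0)
    union = ((a[2] - a[0]) * (a[3] - a[1]) +
             (b[2] - b[0]) * (b[3] - b[1]) - inter)
    return inter / union


def greedy_nms(boxes, iou_threshold):
    """boxes: float array [N, 4], already sorted by score (descending)."""
    num_boxes = boxes.shape[0]
    suppressed = np.zeros(num_boxes, dtype=bool)  # plays the role of `masks`
    out = np.zeros(num_boxes, dtype=np.int64)     # zero-padded like KeepBoxesIdxs
    idx = 0
    for i in range(num_boxes):
        if suppressed[i]:
            continue
        out[idx] = i                              # keep box i
        idx += 1
        for j in range(i + 1, num_boxes):         # suppress later overlapping boxes
            if not suppressed[j] and iou_xyxy(boxes[i], boxes[j]) > iou_threshold:
                suppressed[j] = True
    return out, idx                               # kept indices live in out[:idx]
```

Kept indices occupy the first `idx` slots; the trailing zeros mirror the padding written by the kernel.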
diff --git a/paddle/fluid/operators/detection/nms_op.cu b/paddle/fluid/operators/detection/nms_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..b6027e67d6ced6f825f5f68383f4baa8ccb4bc9b
--- /dev/null
+++ b/paddle/fluid/operators/detection/nms_op.cu
@@ -0,0 +1,108 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <vector>
+
+#include "paddle/fluid/operators/detection/nms_op.h"
+#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+
+static const int64_t threadsPerBlock = sizeof(int64_t) * 8;
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+template <typename T>
+static __global__ void NMS(const T* boxes_data, float threshold,
+                           int64_t num_boxes, uint64_t* masks) {
+  auto raw_start = blockIdx.y;
+  auto col_start = blockIdx.x;
+  if (raw_start > col_start) return;
+
+  const int raw_last_storage =
+      min(num_boxes - raw_start * threadsPerBlock, threadsPerBlock);
+  const int col_last_storage =
+      min(num_boxes - col_start * threadsPerBlock, threadsPerBlock);
+
+  if (threadIdx.x < raw_last_storage) {
+    uint64_t mask = 0;
+    auto current_box_idx = raw_start * threadsPerBlock + threadIdx.x;
+    const T* current_box = boxes_data + current_box_idx * 4;
+    for (int i = 0; i < col_last_storage; ++i) {
+      const T* target_box = boxes_data + (col_start * threadsPerBlock + i) * 4;
+      if (CalculateIoU<T>(current_box, target_box, threshold)) {
+        mask |= 1ULL << i;
+      }
+    }
+    const int blocks_per_line = CeilDivide(num_boxes, threadsPerBlock);
+    masks[current_box_idx * blocks_per_line + col_start] = mask;
+  }
+}
+
+template <typename T>
+class NMSCudaKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* boxes = context.Input<Tensor>("Boxes");
+    Tensor* output = context.Output<Tensor>("KeepBoxesIdxs");
+    auto* output_data = output->mutable_data<int64_t>(context.GetPlace());
+
+    auto threshold = context.template Attr<float>("iou_threshold");
+    const int64_t num_boxes = boxes->dims()[0];
+    const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock);
+
+    dim3 block(threadsPerBlock);
+    dim3 grid(blocks_per_line, blocks_per_line);
+
+    auto mask_data =
+        memory::Alloc(context.cuda_device_context(),
+                      num_boxes * blocks_per_line * sizeof(uint64_t));
+    uint64_t* mask_dev = reinterpret_cast<uint64_t*>(mask_data->ptr());
+    NMS<T><<<grid, block, 0, context.cuda_device_context().stream()>>>(
+        boxes->data<T>(), threshold, num_boxes, mask_dev);
+
+    std::vector<uint64_t> mask_host(num_boxes * blocks_per_line);
+    memory::Copy(platform::CPUPlace(), mask_host.data(), context.GetPlace(),
+                 mask_dev, num_boxes * blocks_per_line * sizeof(uint64_t),
+                 context.cuda_device_context().stream());
+
+    std::vector<uint64_t> remv(blocks_per_line);
+
+    std::vector<int64_t> keep_boxes_idxs(num_boxes);
+    int64_t* output_host = keep_boxes_idxs.data();
+
+    int64_t last_box_num = 0;
+    for (int64_t i = 0; i < num_boxes; ++i) {
+      auto remv_element_id = i / threadsPerBlock;
+      auto remv_bit_id = i % threadsPerBlock;
+      if (!(remv[remv_element_id] & 1ULL << remv_bit_id)) {
+        output_host[last_box_num++] = i;
+        uint64_t* current_mask = mask_host.data() + i * blocks_per_line;
+        for (auto j = remv_element_id; j < blocks_per_line; ++j) {
+          remv[j] |= current_mask[j];
+        }
+      }
+    }
+    memory::Copy(context.GetPlace(), output_data, platform::CPUPlace(),
                 output_host, sizeof(int64_t) * num_boxes,
+                 context.cuda_device_context().stream());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(nms, ops::NMSCudaKernel<float>,
+                        ops::NMSCudaKernel<double>);
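The CUDA path splits the work: the device kernel only fills a `num_boxes x blocks_per_line` matrix of 64-bit overlap masks (one bit per candidate pair; block pairs with `raw_start > col_start` are skipped), and the sequential suppression happens on the host. A small NumPy model of that host reduction, using a boolean matrix instead of packed `uint64_t` words (names are illustrative, not part of this patch):

```python
import numpy as np


def reduce_overlap_matrix(overlap):
    """Model of the host-side loop that consumes `mask_host` in nms_op.cu.

    `overlap[i, j]` is True when IoU(box_i, box_j) > threshold; a boolean
    matrix stands in for the packed 64-bit words, and boxes are assumed to be
    pre-sorted by score.
    """
    num_boxes = overlap.shape[0]
    removed = np.zeros(num_boxes, dtype=bool)  # plays the role of `remv`
    keep = []
    for i in range(num_boxes):
        if removed[i]:
            continue
        keep.append(i)
        # Fold box i's row into the running "removed" set. The real kernel only
        # fills (and the host only folds) the 64-bit blocks at or to the right
        # of box i's own block; folding the whole row here gives the same
        # result, because earlier indices have already been decided.
        removed |= overlap[i]
    return np.array(keep, dtype=np.int64)
```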
diff --git a/paddle/fluid/operators/detection/nms_op.h b/paddle/fluid/operators/detection/nms_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..dce8f47f0174efb3c77f478b828aee1abe17fb9f
--- /dev/null
+++ b/paddle/fluid/operators/detection/nms_op.h
@@ -0,0 +1,51 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+
+namespace paddle {
+namespace operators {
+
+HOSTDEVICE static inline int64_t CeilDivide(int64_t n, int64_t m) {
+  return (n + m - 1) / m;
+}
+
+template <typename T>
+HOSTDEVICE inline bool CalculateIoU(const T* const box_1, const T* const box_2,
+                                    const float threshold) {
+  auto box_1_x0 = box_1[0], box_1_y0 = box_1[1];
+  auto box_1_x1 = box_1[2], box_1_y1 = box_1[3];
+  auto box_2_x0 = box_2[0], box_2_y0 = box_2[1];
+  auto box_2_x1 = box_2[2], box_2_y1 = box_2[3];
+
+  auto inter_box_x0 = box_1_x0 > box_2_x0 ? box_1_x0 : box_2_x0;
+  auto inter_box_y0 = box_1_y0 > box_2_y0 ? box_1_y0 : box_2_y0;
+  auto inter_box_x1 = box_1_x1 < box_2_x1 ? box_1_x1 : box_2_x1;
+  auto inter_box_y1 = box_1_y1 < box_2_y1 ? box_1_y1 : box_2_y1;
+
+  auto inter_width =
+      inter_box_x1 - inter_box_x0 > 0 ? inter_box_x1 - inter_box_x0 : 0;
+  auto inter_height =
+      inter_box_y1 - inter_box_y0 > 0 ? inter_box_y1 - inter_box_y0 : 0;
+  auto inter_area = inter_width * inter_height;
+  auto union_area = (box_1_x1 - box_1_x0) * (box_1_y1 - box_1_y0) +
+                    (box_2_x1 - box_2_x0) * (box_2_y1 - box_2_y0) - inter_area;
+  return inter_area / union_area > threshold;
+}
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 8b84a9c524adf583b1c02496f8bec09c32b10678..b4d6f9b941d4fb6cd53afcdbbbda3980787b164d 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -234,6 +234,7 @@ endif()
 
 if(WIN32)
     LIST(REMOVE_ITEM TEST_OPS test_complex_matmul)
+    LIST(REMOVE_ITEM TEST_OPS test_ops_nms)
 endif()
 
 LIST(REMOVE_ITEM TEST_OPS test_fleet_checkpoint)
diff --git a/python/paddle/fluid/tests/unittests/test_nms_op.py b/python/paddle/fluid/tests/unittests/test_nms_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b5ac1f1337d09506a0ee08707fc195383b78e18
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_nms_op.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
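As a quick numeric check of `CalculateIoU` above: two unit squares offset by half a side intersect in a 0.5 x 1 strip, so IoU = 0.5 / (1 + 1 - 0.5) = 1/3, which counts as overlapping at `iou_threshold = 0.3` but not at 0.5. The same arithmetic in a few lines of NumPy (a standalone sketch, not part of this patch):

```python
import numpy as np


def calculate_iou(box_1, box_2, threshold):
    # Same formula as CalculateIoU in nms_op.h; boxes are [x1, y1, x2, y2].
    x0, y0 = max(box_1[0], box_2[0]), max(box_1[1], box_2[1])
    x1, y1 = min(box_1[2], box_2[2]), min(box_1[3], box_2[3])
    inter = max(x1 - x0, 0.0) * max(y1 - y0, 0.0)
    union = ((box_1[2] - box_1[0]) * (box_1[3] - box_1[1]) +
             (box_2[2] - box_2[0]) * (box_2[3] - box_2[1]) - inter)
    return inter / union > threshold


a = np.array([0.0, 0.0, 1.0, 1.0])
b = np.array([0.5, 0.0, 1.5, 1.0])  # the same square shifted right by 0.5
print(calculate_iou(a, b, 0.3))     # True:  IoU = 1/3 > 0.3, b would be suppressed
print(calculate_iou(a, b, 0.5))     # False: IoU = 1/3 <= 0.5, b would be kept
```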
+ +import unittest +import numpy as np +from op_test import OpTest + + +def iou(box_a, box_b): + """Apply intersection-over-union overlap between box_a and box_b + """ + xmin_a = min(box_a[0], box_a[2]) + ymin_a = min(box_a[1], box_a[3]) + xmax_a = max(box_a[0], box_a[2]) + ymax_a = max(box_a[1], box_a[3]) + + xmin_b = min(box_b[0], box_b[2]) + ymin_b = min(box_b[1], box_b[3]) + xmax_b = max(box_b[0], box_b[2]) + ymax_b = max(box_b[1], box_b[3]) + + area_a = (ymax_a - ymin_a) * (xmax_a - xmin_a) + area_b = (ymax_b - ymin_b) * (xmax_b - xmin_b) + if area_a <= 0 and area_b <= 0: + return 0.0 + + xa = max(xmin_a, xmin_b) + ya = max(ymin_a, ymin_b) + xb = min(xmax_a, xmax_b) + yb = min(ymax_a, ymax_b) + + inter_area = max(xb - xa, 0.0) * max(yb - ya, 0.0) + + iou_ratio = inter_area / (area_a + area_b - inter_area) + return iou_ratio + + +def nms(boxes, nms_threshold): + selected_indices = np.zeros(boxes.shape[0], dtype=np.int64) + keep = np.ones(boxes.shape[0], dtype=int) + io_ratio = np.ones((boxes.shape[0], boxes.shape[0]), dtype=np.float64) + cnt = 0 + for i in range(boxes.shape[0]): + if keep[i] == 0: + continue + selected_indices[cnt] = i + cnt += 1 + for j in range(i + 1, boxes.shape[0]): + io_ratio[i][j] = iou(boxes[i], boxes[j]) + if keep[j]: + overlap = iou(boxes[i], boxes[j]) + keep[j] = 1 if overlap <= nms_threshold else 0 + else: + continue + + return selected_indices + + +class TestNMSOp(OpTest): + def setUp(self): + self.op_type = 'nms' + self.dtype = np.float64 + self.init_dtype_type() + boxes = np.random.rand(32, 4).astype(self.dtype) + boxes[:, 2] = boxes[:, 0] + boxes[:, 2] + boxes[:, 3] = boxes[:, 1] + boxes[:, 3] + + self.inputs = {'Boxes': boxes} + self.attrs = {'iou_threshold': 0.5} + out_py = nms(boxes, self.attrs['iou_threshold']) + self.outputs = {'KeepBoxesIdxs': out_py} + + def init_dtype_type(self): + pass + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_ops_nms.py b/python/paddle/fluid/tests/unittests/test_ops_nms.py new file mode 100644 index 0000000000000000000000000000000000000000..c0bbe82d3581a0c93bd696ee27bfe4c0388ffc8e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_ops_nms.py @@ -0,0 +1,190 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle +from test_nms_op import nms + + +def _find(condition): + """ + Find the indices of elements saticfied the condition. + + Args: + condition(Tensor[N] or np.ndarray([N,])): Element should be bool type. + + Returns: + Tensor: Indices of True element. 
+ """ + res = [] + for i in range(condition.shape[0]): + if condition[i]: + res.append(i) + return np.array(res) + + +def multiclass_nms(boxes, scores, category_idxs, iou_threshold, top_k): + mask = np.zeros_like(scores) + + for category_id in np.unique(category_idxs): + cur_category_boxes_idxs = _find(category_idxs == category_id) + cur_category_boxes = boxes[cur_category_boxes_idxs] + cur_category_scores = scores[cur_category_boxes_idxs] + cur_category_sorted_indices = np.argsort(-cur_category_scores) + cur_category_sorted_boxes = cur_category_boxes[ + cur_category_sorted_indices] + + cur_category_keep_boxes_sub_idxs = cur_category_sorted_indices[nms( + cur_category_sorted_boxes, iou_threshold)] + + mask[cur_category_boxes_idxs[cur_category_keep_boxes_sub_idxs]] = True + + keep_boxes_idxs = _find(mask == True) + topK_sub_indices = np.argsort(-scores[keep_boxes_idxs])[:top_k] + return keep_boxes_idxs[topK_sub_indices] + + +def gen_args(num_boxes, dtype): + boxes = np.random.rand(num_boxes, 4).astype(dtype) + boxes[:, 2] = boxes[:, 0] + boxes[:, 2] + boxes[:, 3] = boxes[:, 1] + boxes[:, 3] + + scores = np.random.rand(num_boxes).astype(dtype) + + categories = [0, 1, 2, 3] + category_idxs = np.random.choice(categories, num_boxes) + + return boxes, scores, category_idxs, categories + + +class TestOpsNMS(unittest.TestCase): + def setUp(self): + self.num_boxes = 64 + self.threshold = 0.5 + self.topk = 20 + self.dtypes = ['float32'] + self.devices = ['cpu'] + if paddle.is_compiled_with_cuda(): + self.devices.append('gpu') + + def test_nms(self): + for device in self.devices: + for dtype in self.dtypes: + boxes, scores, category_idxs, categories = gen_args( + self.num_boxes, dtype) + paddle.set_device(device) + out = paddle.vision.ops.nms( + paddle.to_tensor(boxes), self.threshold, + paddle.to_tensor(scores)) + out = paddle.vision.ops.nms( + paddle.to_tensor(boxes), self.threshold) + out_py = nms(boxes, self.threshold) + + self.assertTrue( + np.array_equal(out.numpy(), out_py), + "paddle out: {}\n py out: {}\n".format(out, out_py)) + + def test_multiclass_nms_dynamic(self): + for device in self.devices: + for dtype in self.dtypes: + boxes, scores, category_idxs, categories = gen_args( + self.num_boxes, dtype) + paddle.set_device(device) + out = paddle.vision.ops.nms( + paddle.to_tensor(boxes), self.threshold, + paddle.to_tensor(scores), + paddle.to_tensor(category_idxs), categories, self.topk) + out_py = multiclass_nms(boxes, scores, category_idxs, + self.threshold, self.topk) + + self.assertTrue( + np.array_equal(out.numpy(), out_py), + "paddle out: {}\n py out: {}\n".format(out, out_py)) + + def test_multiclass_nms_static(self): + for device in self.devices: + for dtype in self.dtypes: + paddle.enable_static() + boxes, scores, category_idxs, categories = gen_args( + self.num_boxes, dtype) + boxes_static = paddle.static.data( + shape=boxes.shape, dtype=boxes.dtype, name="boxes") + scores_static = paddle.static.data( + shape=scores.shape, dtype=scores.dtype, name="scores") + category_idxs_static = paddle.static.data( + shape=category_idxs.shape, + dtype=category_idxs.dtype, + name="category_idxs") + out = paddle.vision.ops.nms(boxes_static, self.threshold, + scores_static, category_idxs_static, + categories, self.topk) + place = paddle.CPUPlace() + if device == 'gpu': + place = paddle.CUDAPlace(0) + exe = paddle.static.Executor(place) + out = exe.run(paddle.static.default_main_program(), + feed={ + 'boxes': boxes, + 'scores': scores, + 'category_idxs': category_idxs + }, + fetch_list=[out]) + 
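+                # exe.run returns one ndarray per entry in fetch_list, so `out`
+                # is a single-element list here; it is converted and squeezed
+                # back to a 1-D index array below before being compared with
+                # the NumPy reference.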
paddle.disable_static() + out_py = multiclass_nms(boxes, scores, category_idxs, + self.threshold, self.topk) + out = np.array(out) + out = np.squeeze(out) + self.assertTrue( + np.array_equal(out, out_py), + "paddle out: {}\n py out: {}\n".format(out, out_py)) + + def test_multiclass_nms_dynamic_to_static(self): + for device in self.devices: + for dtype in self.dtypes: + paddle.set_device(device) + + def fun(x): + scores = np.arange(0, 64).astype('float32') + categories = np.array([0, 1, 2, 3]) + category_idxs = categories.repeat(16) + out = paddle.vision.ops.nms(x, 0.1, + paddle.to_tensor(scores), + paddle.to_tensor(category_idxs), + categories, 10) + return out + + path = "./net" + boxes = np.random.rand(64, 4).astype('float32') + boxes[:, 2] = boxes[:, 0] + boxes[:, 2] + boxes[:, 3] = boxes[:, 1] + boxes[:, 3] + + origin = fun(paddle.to_tensor(boxes)) + paddle.jit.save( + fun, + path, + input_spec=[ + paddle.static.InputSpec( + shape=[None, 4], dtype='float32', name='x') + ], ) + load_func = paddle.jit.load(path) + res = load_func(paddle.to_tensor(boxes)) + self.assertTrue( + np.array_equal(origin, res), + "origin out: {}\n inference model out: {}\n".format(origin, + res)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/vision/ops.py b/python/paddle/vision/ops.py index b510b7c8bdfe8f7855c1ebf97fc2464a9b8ac583..7797909e3b52c623547032cd936116004b1d4372 100644 --- a/python/paddle/vision/ops.py +++ b/python/paddle/vision/ops.py @@ -36,6 +36,7 @@ __all__ = [ #noqa 'PSRoIPool', 'roi_align', 'RoIAlign', + 'nms', ] @@ -1357,3 +1358,151 @@ class ConvNormActivation(Sequential): if activation_layer is not None: layers.append(activation_layer()) super().__init__(*layers) + + +def nms(boxes, + iou_threshold=0.3, + scores=None, + category_idxs=None, + categories=None, + top_k=None): + r""" + This operator implements non-maximum suppression. Non-maximum suppression (NMS) + is used to select one bounding box out of many overlapping bounding boxes in object detection. + Boxes with IoU > iou_threshold will be considered as overlapping boxes, + just one with highest score can be kept. Here IoU is Intersection Over Union, + which can be computed by: + + .. math:: + + IoU = \frac{intersection\_area(box1, box2)}{union\_area(box1, box2)} + + If scores are provided, input boxes will be sorted by their scores firstly. + If category_idxs and categories are provided, NMS will be performed with a batched style, + which means NMS will be applied to each category respectively and results of each category + will be concated and sorted by scores. + If K is provided, only the first k elements will be returned. Otherwise, all box indices sorted by scores will be returned. + + Args: + boxes(Tensor): The input boxes data to be computed, it's a 2D-Tensor with + the shape of [num_boxes, 4] and boxes should be sorted by their + confidence scores. The data type is float32 or float64. + Given as [[x1, y1, x2, y2], …], (x1, y1) is the top left coordinates, + and (x2, y2) is the bottom right coordinates. + Their relation should be ``0 <= x1 < x2 && 0 <= y1 < y2``. + iou_threshold(float32): IoU threshold for determine overlapping boxes. Default value: 0.3. + scores(Tensor, optional): Scores corresponding to boxes, it's a 1D-Tensor with + shape of [num_boxes]. The data type is float32 or float64. + category_idxs(Tensor, optional): Category indices corresponding to boxes. + it's a 1D-Tensor with shape of [num_boxes]. The data type is int64. 
+ categories(List, optional): A list of unique id of all categories. The data type is int64. + top_k(int64, optional): The top K boxes who has higher score and kept by NMS preds to + consider. top_k should be smaller equal than num_boxes. + + Returns: + Tensor: 1D-Tensor with the shape of [num_boxes]. Indices of boxes kept by NMS. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + boxes = np.random.rand(4, 4).astype('float32') + boxes[:, 2] = boxes[:, 0] + boxes[:, 2] + boxes[:, 3] = boxes[:, 1] + boxes[:, 3] + # [[0.06287421 0.5809351 0.3443958 0.8713329 ] + # [0.0749094 0.9713205 0.99241287 1.2799143 ] + # [0.46246734 0.6753201 1.346266 1.3821303 ] + # [0.8984796 0.5619834 1.1254641 1.0201943 ]] + + out = paddle.vision.ops.nms(paddle.to_tensor(boxes), 0.1) + # [0, 1, 3, 0] + + scores = np.random.rand(4).astype('float32') + # [0.98015213 0.3156527 0.8199343 0.874901 ] + + categories = [0, 1, 2, 3] + category_idxs = np.random.choice(categories, 4) + # [2 0 0 3] + + out = paddle.vision.ops.nms(paddle.to_tensor(boxes), + 0.1, + paddle.to_tensor(scores), + paddle.to_tensor(category_idxs), + categories, + 4) + # [0, 3, 2] + """ + + def _nms(boxes, iou_threshold): + if _non_static_mode(): + return _C_ops.nms(boxes, 'iou_threshold', iou_threshold) + + helper = LayerHelper('nms', **locals()) + out = helper.create_variable_for_type_inference('int64') + helper.append_op( + type='nms', + inputs={'Boxes': boxes}, + outputs={'KeepBoxesIdxs': out}, + attrs={'iou_threshold': iou_threshold}) + return out + + if scores is None: + return _nms(boxes, iou_threshold) + + import paddle + if category_idxs is None: + sorted_global_indices = paddle.argsort(scores, descending=True) + return _nms(boxes[sorted_global_indices], iou_threshold) + + if top_k is not None: + assert top_k <= scores.shape[ + 0], "top_k should be smaller equal than the number of boxes" + assert categories is not None, "if category_idxs is given, categories which is a list of unique id of all categories is necessary" + + mask = paddle.zeros_like(scores, dtype=paddle.int32) + + for category_id in categories: + cur_category_boxes_idxs = paddle.where(category_idxs == category_id)[0] + shape = cur_category_boxes_idxs.shape[0] + cur_category_boxes_idxs = paddle.reshape(cur_category_boxes_idxs, + [shape]) + if shape == 0: + continue + elif shape == 1: + mask[cur_category_boxes_idxs] = 1 + continue + cur_category_boxes = boxes[cur_category_boxes_idxs] + cur_category_scores = scores[cur_category_boxes_idxs] + cur_category_sorted_indices = paddle.argsort( + cur_category_scores, descending=True) + cur_category_sorted_boxes = cur_category_boxes[ + cur_category_sorted_indices] + + cur_category_keep_boxes_sub_idxs = cur_category_sorted_indices[_nms( + cur_category_sorted_boxes, iou_threshold)] + + updates = paddle.ones_like( + cur_category_boxes_idxs[cur_category_keep_boxes_sub_idxs], + dtype=paddle.int32) + mask = paddle.scatter( + mask, + cur_category_boxes_idxs[cur_category_keep_boxes_sub_idxs], + updates, + overwrite=True) + keep_boxes_idxs = paddle.where(mask)[0] + shape = keep_boxes_idxs.shape[0] + keep_boxes_idxs = paddle.reshape(keep_boxes_idxs, [shape]) + sorted_sub_indices = paddle.argsort( + scores[keep_boxes_idxs], descending=True) + + if top_k is None: + return keep_boxes_idxs[sorted_sub_indices] + + if _non_static_mode(): + top_k = shape if shape < top_k else top_k + _, topk_sub_indices = paddle.topk(scores[keep_boxes_idxs], top_k) + return keep_boxes_idxs[topk_sub_indices] + + return 
keep_boxes_idxs[sorted_sub_indices][:top_k] diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 365047f7e8382afa1646df2d4ff491471fa829c2..f907d51e4d038c9eb2bdd2bc59f1444b038b50d0 100755 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -349,6 +349,7 @@ STATIC_MODE_TESTING_LIST = [ 'test_nearest_interp_v2_op', 'test_network_with_dtype', 'test_nll_loss', + 'test_nms_op', 'test_nn_functional_embedding_static', 'test_nn_functional_hot_op', 'test_nonzero_api',
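Taken together, a short end-to-end sketch of how the new `paddle.vision.ops.nms` API relates to the NumPy references used by the tests above. It assumes a Paddle build that includes this patch and that it is run from the unittests directory, so that `test_nms_op` and `test_ops_nms` are importable:

```python
import numpy as np
import paddle
from test_nms_op import nms as nms_ref
from test_ops_nms import gen_args, multiclass_nms

paddle.set_device('cpu')
np.random.seed(2022)
boxes, scores, category_idxs, categories = gen_args(64, 'float32')

# Single-class NMS without scores: the op consumes the boxes as given and
# returns the zero-padded [N] index tensor, matching the NumPy reference in
# test_nms_op.py (this mirrors TestOpsNMS.test_nms).
out = paddle.vision.ops.nms(paddle.to_tensor(boxes), 0.5)
assert np.array_equal(out.numpy(), nms_ref(boxes, 0.5))

# Batched (multiclass) NMS: per-category suppression followed by a global
# top_k by score (this mirrors TestOpsNMS.test_multiclass_nms_dynamic).
out = paddle.vision.ops.nms(
    paddle.to_tensor(boxes), 0.5, paddle.to_tensor(scores),
    paddle.to_tensor(category_idxs), categories, 20)
assert np.array_equal(out.numpy(),
                      multiclass_nms(boxes, scores, category_idxs, 0.5, 20))
```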