未验证 提交 7554f428 编写于 作者: R RichardWooSJTU 提交者: GitHub

Add nms op and batched_nms api (#40962)

* add nms op and batched_nms api
上级 510347f9
......@@ -66,6 +66,7 @@ detection_library(yolo_box_op SRCS yolo_box_op.cc)
detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu)
detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc sigmoid_focal_loss_op.cu)
detection_library(retinanet_detection_output_op SRCS retinanet_detection_output_op.cc)
detection_library(nms_op SRCS nms_op.cc nms_op.cu)
if(WITH_GPU OR WITH_ROCM)
set(TMPDEPS memory)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detection/nms_op.h"
#include <vector>
namespace paddle {
namespace operators {
using framework::Tensor;
// Declares the inputs, outputs, attributes and user-facing documentation of
// the `nms` operator.
//
// The documented box layout is [x1, y1, x2, y2]: this is the order in which
// CalculateIoU reads the four coordinates (index 0/1 = first corner,
// index 2/3 = second corner).
class NMSOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Boxes",
             "(Tensor) "
             "Boxes is a Tensor with shape [N, 4] "
             "N is the number of boxes "
             "in last dimension in format [x1, y1, x2, y2] "
             "the relation should be ``0 <= x1 < x2 && 0 <= y1 < y2``.");
    AddOutput("KeepBoxesIdxs",
              "(Tensor) "
              "KeepBoxesIdxs is a Tensor with shape [N] ");
    AddAttr<float>(
        "iou_threshold",
        "iou_threshold is a threshold value used to compress similar boxes "
        "boxes with IoU > iou_threshold will be considered as overlapping "
        "and just one of them can be kept.")
        .SetDefault(1.0f)
        // Reject thresholds outside the closed interval [0, 1].
        .AddCustomChecker([](const float& iou_threshold) {
          PADDLE_ENFORCE_LE(iou_threshold, 1.0f,
                            platform::errors::InvalidArgument(
                                "iou_threshold should less equal than 1.0 "
                                "but got %f",
                                iou_threshold));
          PADDLE_ENFORCE_GE(iou_threshold, 0.0f,
                            platform::errors::InvalidArgument(
                                "iou_threshold should greater equal than 0.0 "
                                "but got %f",
                                iou_threshold));
        });
    // The kernels never look at scores: a box can only be suppressed by a box
    // with a smaller index, kept indices are written in ascending order and
    // the unused tail of the output is zero-filled.
    AddComment(R"DOC(
                NMS Operator.
                This Operator is used to perform Non-Maximum Suppression for input boxes.
                Boxes are expected to be sorted by score in descending order.
                Indices of boxes kept by NMS are output in ascending order and the
                remaining slots of the output are filled with 0.
            )DOC");
  }
};
// Operator definition for `nms`: validates the input shape and selects the
// kernel data type from the "Boxes" input.
class NMSOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Checks that "Boxes" is a [N, 4] tensor and sets the shape of
  // "KeepBoxesIdxs" to [N].
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Boxes"), "Input", "Boxes", "NMS");
    OP_INOUT_CHECK(ctx->HasOutput("KeepBoxesIdxs"), "Output", "KeepBoxesIdxs",
                   "NMS");
    auto boxes_dim = ctx->GetInputDim("Boxes");
    PADDLE_ENFORCE_EQ(boxes_dim.size(), 2,
                      platform::errors::InvalidArgument(
                          "The Input Boxes must be 2-dimension "
                          "whose shape must be [N, 4] "
                          "N is the number of boxes "
                          "in last dimension in format [x1, y1, x2, y2]. "));
    // The kernels address box i as boxes_data[i * 4 + k], so any other width
    // would read the wrong elements. A negative extent is treated as "not
    // known yet" and skipped.
    if (boxes_dim[1] >= 0) {
      PADDLE_ENFORCE_EQ(boxes_dim[1], 4,
                        platform::errors::InvalidArgument(
                            "The last dimension of Input Boxes must be 4 "
                            "in format [x1, y1, x2, y2], but received %d.",
                            boxes_dim[1]));
    }
    auto num_boxes = boxes_dim[0];
    ctx->SetOutputDim("KeepBoxesIdxs", {num_boxes});
  }

 protected:
  // The kernel is chosen by the data type of "Boxes" and the current place.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Boxes"), ctx.GetPlace());
  }
};
// Greedy non-maximum suppression on the host.
//
// boxes_data : num_boxes * 4 coordinates, four consecutive values per box.
// output_data: receives num_boxes entries; the indices of the kept boxes in
//              ascending order, followed by zeros.
// threshold  : forwarded to CalculateIoU, which decides whether two boxes
//              overlap.
// A box is dropped when an earlier, still-kept box overlaps it.
template <typename T>
static void NMS(const T* boxes_data, int64_t* output_data, float threshold,
                int64_t num_boxes) {
  // One bit per box, packed into 64-bit words; a set bit marks a suppressed
  // box.
  std::vector<uint64_t> suppressed(CeilDivide(num_boxes, 64), 0);
  const auto bit_of = [](int64_t idx) { return 1ULL << (idx % 64); };

  for (int64_t ref = 0; ref < num_boxes; ++ref) {
    if (suppressed[ref / 64] & bit_of(ref)) continue;
    const T* ref_box = boxes_data + ref * 4;
    for (int64_t cand = ref + 1; cand < num_boxes; ++cand) {
      if (suppressed[cand / 64] & bit_of(cand)) continue;
      if (CalculateIoU<T>(ref_box, boxes_data + cand * 4, threshold)) {
        suppressed[cand / 64] |= bit_of(cand);
      }
    }
  }

  // Compact the surviving indices to the front of the output.
  int64_t kept = 0;
  for (int64_t idx = 0; idx < num_boxes; ++idx) {
    if (!(suppressed[idx / 64] & bit_of(idx))) {
      output_data[kept++] = idx;
    }
  }
  // Pad the rest so that the output always holds num_boxes values.
  while (kept < num_boxes) {
    output_data[kept++] = 0;
  }
}
// CPU kernel of the `nms` operator: reads "Boxes", runs the host NMS routine
// and writes the int64 result into "KeepBoxesIdxs".
template <typename T>
class NMSKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* in_boxes = context.Input<Tensor>("Boxes");
    Tensor* kept_idxs = context.Output<Tensor>("KeepBoxesIdxs");
    int64_t* kept_idxs_data =
        kept_idxs->mutable_data<int64_t>(context.GetPlace());
    const float iou_threshold = context.template Attr<float>("iou_threshold");
    const T* in_boxes_data = in_boxes->data<T>();
    // The first dimension of "Boxes" is the number of boxes.
    NMS<T>(in_boxes_data, kept_idxs_data, iou_threshold, in_boxes->dims()[0]);
  }
};
} // namespace operators
} // namespace paddle
// Register the `nms` operator with empty gradient makers (for both OpDesc and
// imperative OpBase), and its CPU kernels for float and double boxes.
namespace ops = paddle::operators;
REGISTER_OPERATOR(
nms, ops::NMSOp, ops::NMSOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(nms, ops::NMSKernel<float>, ops::NMSKernel<double>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/operators/detection/nms_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
// Number of threads per CUDA block; equal to the number of bits in an int64_t
// (64), so that one block-sized tile of boxes maps onto one 64-bit mask word.
static const int64_t threadsPerBlock = sizeof(int64_t) * 8;
namespace paddle {
namespace operators {
using framework::Tensor;
// CUDA kernel that fills the pairwise overlap bit-matrix used by NMS.
//
// The boxes are split into tiles of threadsPerBlock boxes. The grid block at
// (blockIdx.y, blockIdx.x) handles the pair (row tile, column tile); each
// thread owns one box of the row tile and compares it against every box of
// the column tile. Bit i of the resulting word is set when CalculateIoU
// reports an overlap with the i-th box of the column tile.
//
// boxes_data: num_boxes * 4 coordinates, four consecutive values per box.
// masks     : output, indexed as masks[box * blocks_per_line + column_tile].
//             Only words with column_tile >= row tile of `box` are written;
//             the others are left untouched.
template <typename T>
static __global__ void NMS(const T* boxes_data, float threshold,
int64_t num_boxes, uint64_t* masks) {
auto raw_start = blockIdx.y;
auto col_start = blockIdx.x;
// Tile pairs below the diagonal are skipped entirely.
if (raw_start > col_start) return;
// Number of valid boxes in the row / column tile (the last tile may be
// partially filled).
const int raw_last_storage =
min(num_boxes - raw_start * threadsPerBlock, threadsPerBlock);
const int col_last_storage =
min(num_boxes - col_start * threadsPerBlock, threadsPerBlock);
if (threadIdx.x < raw_last_storage) {
uint64_t mask = 0;
auto current_box_idx = raw_start * threadsPerBlock + threadIdx.x;
const T* current_box = boxes_data + current_box_idx * 4;
// Note: when both tiles are the same, a box is also compared with itself
// and with boxes of smaller index.
for (int i = 0; i < col_last_storage; ++i) {
const T* target_box = boxes_data + (col_start * threadsPerBlock + i) * 4;
if (CalculateIoU<T>(current_box, target_box, threshold)) {
mask |= 1ULL << i;
}
}
// One mask word per (box, column tile) pair.
const int blocks_per_line = CeilDivide(num_boxes, threadsPerBlock);
masks[current_box_idx * blocks_per_line + col_start] = mask;
}
}
// GPU kernel of the `nms` operator.
//
// The pairwise overlap bit-matrix is computed on the device, copied to the
// host, reduced there into the list of kept indices (ascending, zero padded
// to num_boxes entries) and the result is copied back into "KeepBoxesIdxs".
template <typename T>
class NMSCudaKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* boxes = context.Input<Tensor>("Boxes");
    Tensor* output = context.Output<Tensor>("KeepBoxesIdxs");
    auto* output_data = output->mutable_data<int64_t>(context.GetPlace());
    auto threshold = context.template Attr<float>("iou_threshold");
    const int64_t num_boxes = boxes->dims()[0];
    // With no boxes there is nothing to compute, and launching the kernel
    // below with a 0 x 0 grid would be an invalid launch configuration.
    if (num_boxes == 0) {
      return;
    }

    const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock);
    dim3 block(threadsPerBlock);
    dim3 grid(blocks_per_line, blocks_per_line);

    // One 64-bit word per (box, column tile) pair.
    auto mask_data =
        memory::Alloc(context.cuda_device_context(),
                      num_boxes * blocks_per_line * sizeof(uint64_t));
    uint64_t* mask_dev = reinterpret_cast<uint64_t*>(mask_data->ptr());
    NMS<T><<<grid, block, 0, context.cuda_device_context().stream()>>>(
        boxes->data<T>(), threshold, num_boxes, mask_dev);

    // NOTE(review): the host loop below reads mask_host right after this
    // stream-ordered copy without an explicit synchronization; confirm that
    // memory::Copy has completed the transfer when it returns.
    std::vector<uint64_t> mask_host(num_boxes * blocks_per_line);
    memory::Copy(platform::CPUPlace(), mask_host.data(), context.GetPlace(),
                 mask_dev, num_boxes * blocks_per_line * sizeof(uint64_t),
                 context.cuda_device_context().stream());

    // Bitset of suppressed boxes. It is unsigned so that it matches the
    // uint64_t mask words it is combined with.
    std::vector<uint64_t> remv(blocks_per_line, 0);
    // Value-initialized, so entries past the kept indices stay 0.
    std::vector<int64_t> keep_boxes_idxs(num_boxes);
    int64_t* output_host = keep_boxes_idxs.data();
    int64_t last_box_num = 0;
    for (int64_t i = 0; i < num_boxes; ++i) {
      auto remv_element_id = i / threadsPerBlock;
      auto remv_bit_id = i % threadsPerBlock;
      if (!(remv[remv_element_id] & 1ULL << remv_bit_id)) {
        output_host[last_box_num++] = i;
        // Box i is kept: mark everything it overlaps as suppressed. Only the
        // words from the box's own tile onwards were written by the kernel.
        uint64_t* current_mask = mask_host.data() + i * blocks_per_line;
        for (auto j = remv_element_id; j < blocks_per_line; ++j) {
          remv[j] |= current_mask[j];
        }
      }
    }
    memory::Copy(context.GetPlace(), output_data, platform::CPUPlace(),
                 output_host, sizeof(int64_t) * num_boxes,
                 context.cuda_device_context().stream());
  }
};
} // namespace operators
} // namespace paddle
// Register the CUDA kernels of the `nms` operator for float and double boxes.
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(nms, ops::NMSCudaKernel<float>,
ops::NMSCudaKernel<double>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
// Integer division of n by m, rounded up for non-negative n and positive m.
HOSTDEVICE static inline int64_t CeilDivide(int64_t n, int64_t m) {
  const int64_t padded = n + m - 1;
  return padded / m;
}
// Returns true when the intersection-over-union of the two boxes is strictly
// greater than `threshold`.
//
// Each box is four values read as (x0, y0, x1, y1). A negative overlap extent
// is clamped to 0. If the union area is 0 the ratio is not a number for
// floating-point T and the comparison yields false.
template <typename T>
HOSTDEVICE inline bool CalculateIoU(const T* const box_1, const T* const box_2,
                                    const float threshold) {
  // Corners of the intersection rectangle.
  const auto left = box_1[0] > box_2[0] ? box_1[0] : box_2[0];
  const auto top = box_1[1] > box_2[1] ? box_1[1] : box_2[1];
  const auto right = box_1[2] < box_2[2] ? box_1[2] : box_2[2];
  const auto bottom = box_1[3] < box_2[3] ? box_1[3] : box_2[3];

  auto overlap_w = right - left;
  if (!(overlap_w > 0)) {
    overlap_w = 0;
  }
  auto overlap_h = bottom - top;
  if (!(overlap_h > 0)) {
    overlap_h = 0;
  }

  const auto overlap_area = overlap_w * overlap_h;
  const auto total_area = (box_1[2] - box_1[0]) * (box_1[3] - box_1[1]) +
                          (box_2[2] - box_2[0]) * (box_2[3] - box_2[1]) -
                          overlap_area;
  return overlap_area / total_area > threshold;
}
} // namespace operators
} // namespace paddle
......@@ -234,6 +234,7 @@ endif()
if(WIN32)
LIST(REMOVE_ITEM TEST_OPS test_complex_matmul)
LIST(REMOVE_ITEM TEST_OPS test_ops_nms)
endif()
LIST(REMOVE_ITEM TEST_OPS test_fleet_checkpoint)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
def iou(box_a, box_b):
    """Return the intersection-over-union ratio of two boxes.

    Each box is a sequence of four numbers holding two opposite corners as
    (x, y, x, y); the corners may be given in either order. When both boxes
    have no area the ratio is defined as 0.0.
    """

    def _corners(box):
        # Normalize to (xmin, ymin, xmax, ymax).
        return (min(box[0], box[2]), min(box[1], box[3]),
                max(box[0], box[2]), max(box[1], box[3]))

    left_a, top_a, right_a, bottom_a = _corners(box_a)
    left_b, top_b, right_b, bottom_b = _corners(box_b)

    area_a = (bottom_a - top_a) * (right_a - left_a)
    area_b = (bottom_b - top_b) * (right_b - left_b)
    if area_a <= 0 and area_b <= 0:
        return 0.0

    overlap_w = max(min(right_a, right_b) - max(left_a, left_b), 0.0)
    overlap_h = max(min(bottom_a, bottom_b) - max(top_a, top_b), 0.0)
    inter_area = overlap_w * overlap_h
    return inter_area / (area_a + area_b - inter_area)
def nms(boxes, nms_threshold):
    """Reference (NumPy) implementation of greedy non-maximum suppression.

    Args:
        boxes (np.ndarray): Array whose rows are boxes, in the order in which
            they should be considered (earlier rows win).
        nms_threshold (float): A later box is dropped when its IoU with an
            earlier kept box is not ``<= nms_threshold``.

    Returns:
        np.ndarray: int64 array with one entry per input box. It holds the
        indices of the kept boxes in ascending order, followed by zeros.
    """
    num_boxes = boxes.shape[0]
    selected_indices = np.zeros(num_boxes, dtype=np.int64)
    keep = np.ones(num_boxes, dtype=int)
    cnt = 0
    for i in range(num_boxes):
        if keep[i] == 0:
            continue
        selected_indices[cnt] = i
        cnt += 1
        for j in range(i + 1, num_boxes):
            # Already suppressed boxes need no further comparison.
            if not keep[j]:
                continue
            overlap = iou(boxes[i], boxes[j])
            keep[j] = 1 if overlap <= nms_threshold else 0
    return selected_indices
class TestNMSOp(OpTest):
    """Checks the `nms` operator against the NumPy reference `nms`."""

    def setUp(self):
        self.op_type = 'nms'
        self.dtype = np.float64
        self.init_dtype_type()

        # Random (x, y, w, h) rows turned into (x1, y1, x2, y2) corners.
        boxes = np.random.rand(32, 4).astype(self.dtype)
        boxes[:, 2:] += boxes[:, :2]

        threshold = 0.5
        self.inputs = {'Boxes': boxes}
        self.attrs = {'iou_threshold': threshold}
        self.outputs = {'KeepBoxesIdxs': nms(boxes, threshold)}

    def init_dtype_type(self):
        # Hook for subclasses that want to change self.dtype.
        pass

    def test_check_output(self):
        self.check_output()
# Run this module's unit tests when it is executed directly as a script.
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from test_nms_op import nms
def _find(condition):
"""
Find the indices of elements saticfied the condition.
Args:
condition(Tensor[N] or np.ndarray([N,])): Element should be bool type.
Returns:
Tensor: Indices of True element.
"""
res = []
for i in range(condition.shape[0]):
if condition[i]:
res.append(i)
return np.array(res)
def multiclass_nms(boxes, scores, category_idxs, iou_threshold, top_k):
    """Reference implementation of per-category NMS.

    NMS is run separately on the boxes of every category present in
    ``category_idxs`` (highest score first). The indices of all kept boxes
    are then ordered by descending score and cut to the first ``top_k``.
    """
    kept_mask = np.zeros_like(scores)
    for category_id in np.unique(category_idxs):
        member_idxs = _find(category_idxs == category_id)
        # Positions inside this category, best score first.
        order = np.argsort(-scores[member_idxs])
        ranked_boxes = boxes[member_idxs][order]
        kept_sub_idxs = order[nms(ranked_boxes, iou_threshold)]
        kept_mask[member_idxs[kept_sub_idxs]] = True

    keep_boxes_idxs = _find(kept_mask == True)
    best_first = np.argsort(-scores[keep_boxes_idxs])[:top_k]
    return keep_boxes_idxs[best_first]
def gen_args(num_boxes, dtype):
    """Create random inputs for the NMS tests.

    Returns a tuple ``(boxes, scores, category_idxs, categories)`` where
    ``boxes`` has shape [num_boxes, 4] in corner format, ``scores`` has shape
    [num_boxes], ``category_idxs`` assigns one of ``categories`` to each box.
    """
    # Draw (x, y, w, h) and turn the last two columns into the second corner.
    boxes = np.random.rand(num_boxes, 4).astype(dtype)
    boxes[:, 2:] += boxes[:, :2]

    scores = np.random.rand(num_boxes).astype(dtype)

    categories = [0, 1, 2, 3]
    category_idxs = np.random.choice(categories, num_boxes)
    return boxes, scores, category_idxs, categories
# Tests of the paddle.vision.ops.nms API against the NumPy reference
# implementations (nms / multiclass_nms), on CPU and, when paddle is compiled
# with CUDA, on GPU.
class TestOpsNMS(unittest.TestCase):
def setUp(self):
self.num_boxes = 64
self.threshold = 0.5
self.topk = 20
self.dtypes = ['float32']
self.devices = ['cpu']
if paddle.is_compiled_with_cuda():
self.devices.append('gpu')
# Plain NMS in dynamic mode.
def test_nms(self):
for device in self.devices:
for dtype in self.dtypes:
boxes, scores, category_idxs, categories = gen_args(
self.num_boxes, dtype)
paddle.set_device(device)
# NOTE(review): the result of this call with scores is overwritten by the
# next call and never compared against a reference.
out = paddle.vision.ops.nms(
paddle.to_tensor(boxes), self.threshold,
paddle.to_tensor(scores))
out = paddle.vision.ops.nms(
paddle.to_tensor(boxes), self.threshold)
out_py = nms(boxes, self.threshold)
self.assertTrue(
np.array_equal(out.numpy(), out_py),
"paddle out: {}\n py out: {}\n".format(out, out_py))
# Per-category NMS with top_k in dynamic mode.
def test_multiclass_nms_dynamic(self):
for device in self.devices:
for dtype in self.dtypes:
boxes, scores, category_idxs, categories = gen_args(
self.num_boxes, dtype)
paddle.set_device(device)
out = paddle.vision.ops.nms(
paddle.to_tensor(boxes), self.threshold,
paddle.to_tensor(scores),
paddle.to_tensor(category_idxs), categories, self.topk)
out_py = multiclass_nms(boxes, scores, category_idxs,
self.threshold, self.topk)
self.assertTrue(
np.array_equal(out.numpy(), out_py),
"paddle out: {}\n py out: {}\n".format(out, out_py))
# Per-category NMS with top_k in static mode, run through an Executor.
def test_multiclass_nms_static(self):
for device in self.devices:
for dtype in self.dtypes:
# Static mode is switched on here and off again after the program ran.
paddle.enable_static()
boxes, scores, category_idxs, categories = gen_args(
self.num_boxes, dtype)
boxes_static = paddle.static.data(
shape=boxes.shape, dtype=boxes.dtype, name="boxes")
scores_static = paddle.static.data(
shape=scores.shape, dtype=scores.dtype, name="scores")
category_idxs_static = paddle.static.data(
shape=category_idxs.shape,
dtype=category_idxs.dtype,
name="category_idxs")
out = paddle.vision.ops.nms(boxes_static, self.threshold,
scores_static, category_idxs_static,
categories, self.topk)
place = paddle.CPUPlace()
if device == 'gpu':
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
out = exe.run(paddle.static.default_main_program(),
feed={
'boxes': boxes,
'scores': scores,
'category_idxs': category_idxs
},
fetch_list=[out])
paddle.disable_static()
out_py = multiclass_nms(boxes, scores, category_idxs,
self.threshold, self.topk)
# exe.run returned the fetched values; convert them to an array and drop
# size-1 axes before comparing.
out = np.array(out)
out = np.squeeze(out)
self.assertTrue(
np.array_equal(out, out_py),
"paddle out: {}\n py out: {}\n".format(out, out_py))
# Saves a function calling nms with paddle.jit.save, loads it back and checks
# that the loaded function gives the same result as the original one.
def test_multiclass_nms_dynamic_to_static(self):
for device in self.devices:
for dtype in self.dtypes:
paddle.set_device(device)
def fun(x):
# Fixed scores and categories: 4 categories with 16 boxes each.
scores = np.arange(0, 64).astype('float32')
categories = np.array([0, 1, 2, 3])
category_idxs = categories.repeat(16)
out = paddle.vision.ops.nms(x, 0.1,
paddle.to_tensor(scores),
paddle.to_tensor(category_idxs),
categories, 10)
return out
# The model is written to files with this prefix in the working directory.
path = "./net"
boxes = np.random.rand(64, 4).astype('float32')
boxes[:, 2] = boxes[:, 0] + boxes[:, 2]
boxes[:, 3] = boxes[:, 1] + boxes[:, 3]
origin = fun(paddle.to_tensor(boxes))
paddle.jit.save(
fun,
path,
input_spec=[
paddle.static.InputSpec(
shape=[None, 4], dtype='float32', name='x')
], )
load_func = paddle.jit.load(path)
res = load_func(paddle.to_tensor(boxes))
self.assertTrue(
np.array_equal(origin, res),
"origin out: {}\n inference model out: {}\n".format(origin,
res))
# Run this module's unit tests when it is executed directly as a script.
if __name__ == '__main__':
unittest.main()
......@@ -36,6 +36,7 @@ __all__ = [ #noqa
'PSRoIPool',
'roi_align',
'RoIAlign',
'nms',
]
......@@ -1357,3 +1358,151 @@ class ConvNormActivation(Sequential):
if activation_layer is not None:
layers.append(activation_layer())
super().__init__(*layers)
def nms(boxes,
        iou_threshold=0.3,
        scores=None,
        category_idxs=None,
        categories=None,
        top_k=None):
    r"""
    This operator implements non-maximum suppression. Non-maximum suppression (NMS)
    is used to select one bounding box out of many overlapping bounding boxes in object detection.
    Boxes with IoU > iou_threshold will be considered as overlapping boxes,
    just one with highest score can be kept. Here IoU is Intersection Over Union,
    which can be computed by:

    .. math::

        IoU = \frac{intersection\_area(box1, box2)}{union\_area(box1, box2)}

    If scores are provided, input boxes will be sorted by their scores firstly.
    The returned indices always refer to positions in the original ``boxes``.

    If category_idxs and categories are provided, NMS will be performed with a batched style,
    which means NMS will be applied to each category respectively and results of each category
    will be concatenated and sorted by scores.

    If top_k is provided, only the first k elements will be returned. Otherwise, all box indices sorted by scores will be returned.

    Args:
        boxes(Tensor): The input boxes data to be computed, it's a 2D-Tensor with
            the shape of [num_boxes, 4]. If scores is not provided, boxes should
            already be sorted by their confidence scores. The data type is float32 or float64.
            Given as [[x1, y1, x2, y2], …], (x1, y1) is the top left coordinates,
            and (x2, y2) is the bottom right coordinates.
            Their relation should be ``0 <= x1 < x2 && 0 <= y1 < y2``.
        iou_threshold(float32): IoU threshold for determine overlapping boxes. Default value: 0.3.
        scores(Tensor, optional): Scores corresponding to boxes, it's a 1D-Tensor with
            shape of [num_boxes]. The data type is float32 or float64.
        category_idxs(Tensor, optional): Category indices corresponding to boxes.
            it's a 1D-Tensor with shape of [num_boxes]. The data type is int64.
        categories(List, optional): A list of unique id of all categories. The data type is int64.
        top_k(int64, optional): The top K boxes who has higher score and kept by NMS preds to
            consider. top_k should be smaller equal than num_boxes.

    Returns:
        Tensor: 1D-Tensor with the shape of [num_boxes]. Indices of boxes kept by NMS.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            boxes = np.random.rand(4, 4).astype('float32')
            boxes[:, 2] = boxes[:, 0] + boxes[:, 2]
            boxes[:, 3] = boxes[:, 1] + boxes[:, 3]
            # [[0.06287421 0.5809351  0.3443958  0.8713329 ]
            #  [0.0749094  0.9713205  0.99241287 1.2799143 ]
            #  [0.46246734 0.6753201  1.346266   1.3821303 ]
            #  [0.8984796  0.5619834  1.1254641  1.0201943 ]]

            out = paddle.vision.ops.nms(paddle.to_tensor(boxes), 0.1)
            # [0, 1, 3, 0]

            scores = np.random.rand(4).astype('float32')
            # [0.98015213 0.3156527  0.8199343  0.874901 ]

            categories = [0, 1, 2, 3]
            category_idxs = np.random.choice(categories, 4)
            # [2 0 0 3]

            out = paddle.vision.ops.nms(paddle.to_tensor(boxes),
                                        0.1,
                                        paddle.to_tensor(scores),
                                        paddle.to_tensor(category_idxs),
                                        categories,
                                        4)
            # [0, 3, 2]
    """

    def _nms(boxes, iou_threshold):
        # Thin wrapper around the `nms` operator for both execution modes.
        if _non_static_mode():
            return _C_ops.nms(boxes, 'iou_threshold', iou_threshold)

        helper = LayerHelper('nms', **locals())
        out = helper.create_variable_for_type_inference('int64')
        helper.append_op(
            type='nms',
            inputs={'Boxes': boxes},
            outputs={'KeepBoxesIdxs': out},
            attrs={'iou_threshold': iou_threshold})
        return out

    if scores is None:
        return _nms(boxes, iou_threshold)

    import paddle
    if category_idxs is None:
        sorted_global_indices = paddle.argsort(scores, descending=True)
        # The operator returns positions inside the score-sorted boxes; map
        # them back so that callers get indices into the original `boxes`,
        # exactly as the batched branch below does.
        sorted_keep_boxes_indices = _nms(boxes[sorted_global_indices],
                                         iou_threshold)
        return sorted_global_indices[sorted_keep_boxes_indices]

    if top_k is not None:
        assert top_k <= scores.shape[
            0], "top_k should be smaller equal than the number of boxes"
    assert categories is not None, "if category_idxs is given, categories which is a list of unique id of all categories is necessary"

    # mask[i] becomes 1 when box i survives NMS inside its own category.
    mask = paddle.zeros_like(scores, dtype=paddle.int32)

    for category_id in categories:
        cur_category_boxes_idxs = paddle.where(category_idxs == category_id)[0]
        shape = cur_category_boxes_idxs.shape[0]
        cur_category_boxes_idxs = paddle.reshape(cur_category_boxes_idxs,
                                                 [shape])
        if shape == 0:
            continue
        elif shape == 1:
            # A single box cannot be suppressed by anything.
            mask[cur_category_boxes_idxs] = 1
            continue
        cur_category_boxes = boxes[cur_category_boxes_idxs]
        cur_category_scores = scores[cur_category_boxes_idxs]
        cur_category_sorted_indices = paddle.argsort(
            cur_category_scores, descending=True)
        cur_category_sorted_boxes = cur_category_boxes[
            cur_category_sorted_indices]

        # Map the operator's result (positions in the sorted boxes) back to
        # positions inside this category.
        cur_category_keep_boxes_sub_idxs = cur_category_sorted_indices[_nms(
            cur_category_sorted_boxes, iou_threshold)]

        updates = paddle.ones_like(
            cur_category_boxes_idxs[cur_category_keep_boxes_sub_idxs],
            dtype=paddle.int32)
        mask = paddle.scatter(
            mask,
            cur_category_boxes_idxs[cur_category_keep_boxes_sub_idxs],
            updates,
            overwrite=True)
    keep_boxes_idxs = paddle.where(mask)[0]
    shape = keep_boxes_idxs.shape[0]
    keep_boxes_idxs = paddle.reshape(keep_boxes_idxs, [shape])
    sorted_sub_indices = paddle.argsort(
        scores[keep_boxes_idxs], descending=True)

    if top_k is None:
        return keep_boxes_idxs[sorted_sub_indices]

    if _non_static_mode():
        # Fewer boxes than top_k may have been kept.
        top_k = shape if shape < top_k else top_k
        _, topk_sub_indices = paddle.topk(scores[keep_boxes_idxs], top_k)
        return keep_boxes_idxs[topk_sub_indices]

    return keep_boxes_idxs[sorted_sub_indices][:top_k]
......@@ -349,6 +349,7 @@ STATIC_MODE_TESTING_LIST = [
'test_nearest_interp_v2_op',
'test_network_with_dtype',
'test_nll_loss',
'test_nms_op',
'test_nn_functional_embedding_static',
'test_nn_functional_hot_op',
'test_nonzero_api',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册