提交 53788640 编写于 作者: D dangqingqing

Fix the output order and add more unit test cases.

上级 35dec3d7
...@@ -201,8 +201,8 @@ class MulticlassNMSKernel : public framework::OpKernel<T> { ...@@ -201,8 +201,8 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
} }
} }
// Keep top k results per image. // Keep top k results per image.
std::sort(score_index_pairs.begin(), score_index_pairs.end(), std::stable_sort(score_index_pairs.begin(), score_index_pairs.end(),
SortScorePairDescend<std::pair<int, int>>); SortScorePairDescend<std::pair<int, int>>);
score_index_pairs.resize(keep_top_k); score_index_pairs.resize(keep_top_k);
// Store the new indices. // Store the new indices.
...@@ -269,7 +269,8 @@ class MulticlassNMSKernel : public framework::OpKernel<T> { ...@@ -269,7 +269,8 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
int num_kept = batch_starts.back(); int num_kept = batch_starts.back();
if (num_kept == 0) { if (num_kept == 0) {
outs->Resize({0, 0}); T* od = outs->mutable_data<T>({1}, ctx.GetPlace());
od[0] = -1;
} else { } else {
outs->mutable_data<T>({num_kept, kOutputDim}, ctx.GetPlace()); outs->mutable_data<T>({num_kept, kOutputDim}, ctx.GetPlace());
for (int64_t i = 0; i < batch_size; ++i) { for (int64_t i = 0; i < batch_size; ++i) {
...@@ -349,11 +350,16 @@ is larger than -1. Then this operator pruns away boxes that have high IOU ...@@ -349,11 +350,16 @@ is larger than -1. Then this operator pruns away boxes that have high IOU
(intersection over union) overlap with already selected boxes by adaptive (intersection over union) overlap with already selected boxes by adaptive
threshold NMS based on parameters of nms_threshold and nms_eta. threshold NMS based on parameters of nms_threshold and nms_eta.
Aftern NMS step, only at most keep_top_k number of total bboxes are to be kept Aftern NMS step, at most keep_top_k number of total bboxes are to be kept
per image if keep_top_k is larger than -1. per image if keep_top_k is larger than -1.
This operator support multi-class and batched inputs. It applying NMS This operator support multi-class and batched inputs. It applying NMS
independently for each class. independently for each class. The outputs is a 2-D LoDTenosr, for each
image, the offsets in first dimension of LoDTensor are called LoD, the number
of offset is N + 1, where N is the batch size. If LoD[i + 1] - LoD[i] == 0,
means there is no detected bbox for this image. If there is no detected boxes
for all images, all the elements in LoD are 0, and the Out only contains one
value which is -1.
)DOC"); )DOC");
} }
}; };
......
...@@ -56,8 +56,12 @@ def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0): ...@@ -56,8 +56,12 @@ def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0):
Args: Args:
boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
scores: (tensor) The class predscores for the img, Shape:[num_priors]. scores: (tensor) The class predscores for the img, Shape:[num_priors].
overlap: (float) The overlap thresh for suppressing unnecessary boxes. score_threshold: (float) The confidence thresh for filtering low
top_k: (int) The Maximum number of box preds to consider. confidence boxes.
nms_threshold: (float) The overlap thresh for suppressing unnecessary
boxes.
top_k: (int) The maximum number of box preds to consider.
eta: (float) The parameter for adaptive NMS.
Return: Return:
The indices of the kept boxes with respect to num_priors. The indices of the kept boxes with respect to num_priors.
""" """
...@@ -67,7 +71,7 @@ def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0): ...@@ -67,7 +71,7 @@ def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0):
selected_indices = selected_indices.flatten() selected_indices = selected_indices.flatten()
all_scores = all_scores[selected_indices] all_scores = all_scores[selected_indices]
sorted_indices = np.argsort(-all_scores, axis=0) sorted_indices = np.argsort(-all_scores, axis=0, kind='mergesort')
sorted_scores = all_scores[sorted_indices] sorted_scores = all_scores[sorted_indices]
if top_k > -1 and top_k < sorted_indices.shape[0]: if top_k > -1 and top_k < sorted_indices.shape[0]:
sorted_indices = sorted_indices[:top_k] sorted_indices = sorted_indices[:top_k]
...@@ -97,29 +101,33 @@ def multiclass_nms(boxes, scores, background, score_threshold, nms_threshold, ...@@ -97,29 +101,33 @@ def multiclass_nms(boxes, scores, background, score_threshold, nms_threshold,
class_num = scores.shape[0] class_num = scores.shape[0]
priorbox_num = scores.shape[1] priorbox_num = scores.shape[1]
selected_indices = [] selected_indices = {}
num_det = 0 num_det = 0
for c in range(class_num): for c in range(class_num):
if c == background: continue if c == background: continue
indices = nms(boxes, scores[c], score_threshold, nms_threshold, indices = nms(boxes, scores[c], score_threshold, nms_threshold,
nms_top_k) nms_top_k)
for idx in indices: selected_indices[c] = indices
selected_indices.append((c, idx))
num_det += len(indices) num_det += len(indices)
if keep_top_k > -1 and num_det > keep_top_k: if keep_top_k > -1 and num_det > keep_top_k:
score_index = [] score_index = []
for c, idx in selected_indices: for c, indices in selected_indices.iteritems():
score_index.append((scores[c][idx], c, idx)) for idx in indices:
score_index.append((scores[c][idx], c, idx))
sorted_score_index = sorted( sorted_score_index = sorted(
score_index, key=lambda tup: tup[0], reverse=True) score_index, key=lambda tup: tup[0], reverse=True)
sorted_score_index = sorted_score_index[:keep_top_k] sorted_score_index = sorted_score_index[:keep_top_k]
selected_indices = [] selected_indices = {}
for _, c, _ in sorted_score_index:
selected_indices[c] = []
for s, c, idx in sorted_score_index: for s, c, idx in sorted_score_index:
selected_indices.append((c, idx)) selected_indices[c].append(idx)
num_det = keep_top_k
return selected_indices return selected_indices, num_det
def batched_multiclass_nms(boxes, scores, background, score_threshold, def batched_multiclass_nms(boxes, scores, background, score_threshold,
...@@ -129,28 +137,36 @@ def batched_multiclass_nms(boxes, scores, background, score_threshold, ...@@ -129,28 +137,36 @@ def batched_multiclass_nms(boxes, scores, background, score_threshold,
det_outs = [] det_outs = []
lod = [0] lod = [0]
for n in range(batch_size): for n in range(batch_size):
nmsed_outs = multiclass_nms(boxes, scores[n], background, nmsed_outs, nmsed_num = multiclass_nms(boxes, scores[n], background,
score_threshold, nms_threshold, nms_top_k, score_threshold, nms_threshold,
keep_top_k) nms_top_k, keep_top_k)
lod.append(lod[-1] + len(nmsed_outs)) lod.append(lod[-1] + nmsed_num)
if len(nmsed_outs) == 0: continue if nmsed_num == 0: continue
for c, idx in nmsed_outs:
xmin, ymin, xmax, ymax = boxes[idx][:] for c, indices in nmsed_outs.iteritems():
det_outs.append([c, scores[n][c][idx], xmin, ymin, xmax, ymax]) for idx in indices:
xmin, ymin, xmax, ymax = boxes[idx][:]
det_outs.append([c, scores[n][c][idx], xmin, ymin, xmax, ymax])
return det_outs, lod return det_outs, lod
class TestMulticlassNMSOp(OpTest): class TestMulticlassNMSOp(OpTest):
def set_argument(self):
self.score_threshold = 0.01
def setUp(self): def setUp(self):
self.set_argument()
N = 7 N = 7
M = 1240 M = 1200
C = 21 C = 21
BOX_SIZE = 4 BOX_SIZE = 4
background = 0 background = 0
nms_threshold = 0.3 nms_threshold = 0.3
nms_top_k = 400 nms_top_k = 400
keep_top_k = 200 keep_top_k = 200
score_threshold = 0.01 score_threshold = self.score_threshold
scores = np.random.random((N * M, C)).astype('float32') scores = np.random.random((N * M, C)).astype('float32')
...@@ -165,11 +181,12 @@ class TestMulticlassNMSOp(OpTest): ...@@ -165,11 +181,12 @@ class TestMulticlassNMSOp(OpTest):
boxes = np.random.random((M, BOX_SIZE)).astype('float32') boxes = np.random.random((M, BOX_SIZE)).astype('float32')
boxes[:, 0:2] = boxes[:, 0:2] * 0.5 boxes[:, 0:2] = boxes[:, 0:2] * 0.5
boxes[:, 2:4] = boxes[:, 0:2] * 0.5 + 0.5 boxes[:, 2:4] = boxes[:, 2:4] * 0.5 + 0.5
nmsed_outs, lod = batched_multiclass_nms(boxes, scores, background, nmsed_outs, lod = batched_multiclass_nms(boxes, scores, background,
score_threshold, nms_threshold, score_threshold, nms_threshold,
nms_top_k, keep_top_k) nms_top_k, keep_top_k)
nmsed_outs = [-1] if not nmsed_outs else nmsed_outs
nmsed_outs = np.array(nmsed_outs).astype('float32') nmsed_outs = np.array(nmsed_outs).astype('float32')
self.op_type = 'multiclass_nms' self.op_type = 'multiclass_nms'
...@@ -188,6 +205,13 @@ class TestMulticlassNMSOp(OpTest): ...@@ -188,6 +205,13 @@ class TestMulticlassNMSOp(OpTest):
self.check_output() self.check_output()
class TestMulticlassNMSOpNoOutput(TestMulticlassNMSOp):
def set_argument(self):
# Here set 2.0 to test the case there is no outputs.
# In practical use, 0.0 < score_threshold < 1.0
self.score_threshold = 2.0
class TestIOU(unittest.TestCase): class TestIOU(unittest.TestCase):
def test_iou(self): def test_iou(self):
box1 = np.array([4.0, 3.0, 7.0, 5.0]).astype('float32') box1 = np.array([4.0, 3.0, 7.0, 5.0]).astype('float32')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册