提交 53788640 编写于 作者: D dangqingqing

Fix the output order and add more unit test cases.

上级 35dec3d7
......@@ -201,7 +201,7 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
}
}
// Keep top k results per image.
std::sort(score_index_pairs.begin(), score_index_pairs.end(),
std::stable_sort(score_index_pairs.begin(), score_index_pairs.end(),
SortScorePairDescend<std::pair<int, int>>);
score_index_pairs.resize(keep_top_k);
......@@ -269,7 +269,8 @@ class MulticlassNMSKernel : public framework::OpKernel<T> {
int num_kept = batch_starts.back();
if (num_kept == 0) {
outs->Resize({0, 0});
T* od = outs->mutable_data<T>({1}, ctx.GetPlace());
od[0] = -1;
} else {
outs->mutable_data<T>({num_kept, kOutputDim}, ctx.GetPlace());
for (int64_t i = 0; i < batch_size; ++i) {
......@@ -349,11 +350,16 @@ is larger than -1. Then this operator pruns away boxes that have high IOU
(intersection over union) overlap with already selected boxes by adaptive
threshold NMS based on parameters of nms_threshold and nms_eta.
Aftern NMS step, only at most keep_top_k number of total bboxes are to be kept
Aftern NMS step, at most keep_top_k number of total bboxes are to be kept
per image if keep_top_k is larger than -1.
This operator support multi-class and batched inputs. It applying NMS
independently for each class.
independently for each class. The outputs is a 2-D LoDTenosr, for each
image, the offsets in first dimension of LoDTensor are called LoD, the number
of offset is N + 1, where N is the batch size. If LoD[i + 1] - LoD[i] == 0,
means there is no detected bbox for this image. If there is no detected boxes
for all images, all the elements in LoD are 0, and the Out only contains one
value which is -1.
)DOC");
}
};
......
......@@ -56,8 +56,12 @@ def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0):
Args:
boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
scores: (tensor) The class predscores for the img, Shape:[num_priors].
overlap: (float) The overlap thresh for suppressing unnecessary boxes.
top_k: (int) The Maximum number of box preds to consider.
score_threshold: (float) The confidence thresh for filtering low
confidence boxes.
nms_threshold: (float) The overlap thresh for suppressing unnecessary
boxes.
top_k: (int) The maximum number of box preds to consider.
eta: (float) The parameter for adaptive NMS.
Return:
The indices of the kept boxes with respect to num_priors.
"""
......@@ -67,7 +71,7 @@ def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0):
selected_indices = selected_indices.flatten()
all_scores = all_scores[selected_indices]
sorted_indices = np.argsort(-all_scores, axis=0)
sorted_indices = np.argsort(-all_scores, axis=0, kind='mergesort')
sorted_scores = all_scores[sorted_indices]
if top_k > -1 and top_k < sorted_indices.shape[0]:
sorted_indices = sorted_indices[:top_k]
......@@ -97,29 +101,33 @@ def multiclass_nms(boxes, scores, background, score_threshold, nms_threshold,
class_num = scores.shape[0]
priorbox_num = scores.shape[1]
selected_indices = []
selected_indices = {}
num_det = 0
for c in range(class_num):
if c == background: continue
indices = nms(boxes, scores[c], score_threshold, nms_threshold,
nms_top_k)
for idx in indices:
selected_indices.append((c, idx))
selected_indices[c] = indices
num_det += len(indices)
if keep_top_k > -1 and num_det > keep_top_k:
score_index = []
for c, idx in selected_indices:
for c, indices in selected_indices.iteritems():
for idx in indices:
score_index.append((scores[c][idx], c, idx))
sorted_score_index = sorted(
score_index, key=lambda tup: tup[0], reverse=True)
sorted_score_index = sorted_score_index[:keep_top_k]
selected_indices = []
selected_indices = {}
for _, c, _ in sorted_score_index:
selected_indices[c] = []
for s, c, idx in sorted_score_index:
selected_indices.append((c, idx))
selected_indices[c].append(idx)
num_det = keep_top_k
return selected_indices
return selected_indices, num_det
def batched_multiclass_nms(boxes, scores, background, score_threshold,
......@@ -129,28 +137,36 @@ def batched_multiclass_nms(boxes, scores, background, score_threshold,
det_outs = []
lod = [0]
for n in range(batch_size):
nmsed_outs = multiclass_nms(boxes, scores[n], background,
score_threshold, nms_threshold, nms_top_k,
keep_top_k)
lod.append(lod[-1] + len(nmsed_outs))
if len(nmsed_outs) == 0: continue
for c, idx in nmsed_outs:
nmsed_outs, nmsed_num = multiclass_nms(boxes, scores[n], background,
score_threshold, nms_threshold,
nms_top_k, keep_top_k)
lod.append(lod[-1] + nmsed_num)
if nmsed_num == 0: continue
for c, indices in nmsed_outs.iteritems():
for idx in indices:
xmin, ymin, xmax, ymax = boxes[idx][:]
det_outs.append([c, scores[n][c][idx], xmin, ymin, xmax, ymax])
return det_outs, lod
class TestMulticlassNMSOp(OpTest):
def set_argument(self):
self.score_threshold = 0.01
def setUp(self):
self.set_argument()
N = 7
M = 1240
M = 1200
C = 21
BOX_SIZE = 4
background = 0
nms_threshold = 0.3
nms_top_k = 400
keep_top_k = 200
score_threshold = 0.01
score_threshold = self.score_threshold
scores = np.random.random((N * M, C)).astype('float32')
......@@ -165,11 +181,12 @@ class TestMulticlassNMSOp(OpTest):
boxes = np.random.random((M, BOX_SIZE)).astype('float32')
boxes[:, 0:2] = boxes[:, 0:2] * 0.5
boxes[:, 2:4] = boxes[:, 0:2] * 0.5 + 0.5
boxes[:, 2:4] = boxes[:, 2:4] * 0.5 + 0.5
nmsed_outs, lod = batched_multiclass_nms(boxes, scores, background,
score_threshold, nms_threshold,
nms_top_k, keep_top_k)
nmsed_outs = [-1] if not nmsed_outs else nmsed_outs
nmsed_outs = np.array(nmsed_outs).astype('float32')
self.op_type = 'multiclass_nms'
......@@ -188,6 +205,13 @@ class TestMulticlassNMSOp(OpTest):
self.check_output()
class TestMulticlassNMSOpNoOutput(TestMulticlassNMSOp):
def set_argument(self):
# Here set 2.0 to test the case there is no outputs.
# In practical use, 0.0 < score_threshold < 1.0
self.score_threshold = 2.0
class TestIOU(unittest.TestCase):
def test_iou(self):
box1 = np.array([4.0, 3.0, 7.0, 5.0]).astype('float32')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册