#  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import print_function
import unittest
import numpy as np
import copy
from op_test import OpTest


def iou(box_a, box_b, norm):
    """Return the intersection-over-union overlap of two boxes.

    Each box is read as four coordinates ``[x0, y0, x1, y1]``. The two x
    values and the two y values are ordered with min/max first, so the corner
    order inside a box does not matter.

    Args:
        box_a: First box (indexable, four coordinates).
        box_b: Second box (indexable, four coordinates).
        norm: When ``norm == False`` every width/height is extended by 1;
            for any other value no offset is added.

    Returns:
        ``0.0`` when both boxes have a non-positive area, otherwise the
        intersection area divided by the union area.
    """
    # ``norm == False`` is kept on purpose (instead of ``not norm``) so that
    # only a value equal to False switches the +1 extent on.
    offset = 1 if norm == False else 0  # noqa: E712

    def _span(low, high):
        # Extent of one axis, including the offset for un-normalized boxes.
        return high - low + offset

    left_a, right_a = min(box_a[0], box_a[2]), max(box_a[0], box_a[2])
    top_a, bottom_a = min(box_a[1], box_a[3]), max(box_a[1], box_a[3])
    left_b, right_b = min(box_b[0], box_b[2]), max(box_b[0], box_b[2])
    top_b, bottom_b = min(box_b[1], box_b[3]), max(box_b[1], box_b[3])

    size_a = _span(top_a, bottom_a) * _span(left_a, right_a)
    size_b = _span(top_b, bottom_b) * _span(left_b, right_b)
    if size_a <= 0 and size_b <= 0:
        return 0.0

    # Corners of the overlapping rectangle; negative extents are clamped to 0.
    overlap_w = max(_span(max(left_a, left_b), min(right_a, right_b)), 0.0)
    overlap_h = max(_span(max(top_a, top_b), min(bottom_a, bottom_b)), 0.0)
    shared_area = overlap_w * overlap_h

    return shared_area / (size_a + size_b - shared_area)


def nms(boxes,
        scores,
        score_threshold,
        nms_threshold,
        top_k=200,
        normalized=True,
        eta=1.0):
    """Apply greedy non-maximum suppression to the boxes of one class.

    Boxes whose score does not exceed ``score_threshold`` are discarded, the
    rest are visited from highest to lowest score, and a box is kept only if
    its IoU with every already-kept box is at most the current threshold.

    Args:
        boxes: Location predictions, shape [num_priors, 4].
        scores: Class scores, shape [num_priors] (flattened before use).
        score_threshold: (float) Boxes with ``score <= score_threshold`` are
            ignored.
        nms_threshold: (float) Initial IoU threshold above which a box is
            suppressed by an already-kept box.
        top_k: (int) Maximum number of highest-scoring candidates to
            consider; a value of -1 or lower disables the limit.
        normalized: (bool) Forwarded to ``iou`` as its ``norm`` argument.
        eta: (float) Adaptive-NMS factor. While ``eta < 1`` and the current
            threshold is above 0.5, the threshold is multiplied by ``eta``
            after each kept box.

    Returns:
        A list with the indices of the kept boxes with respect to
        ``num_priors`` (i.e. valid indices into ``boxes``/``scores``), in
        descending score order.
    """
    all_scores = np.asarray(scores).flatten()

    # Positions in the ORIGINAL score array that pass the score filter.
    candidate_indices = np.flatnonzero(all_scores > score_threshold)
    candidate_scores = all_scores[candidate_indices]

    # Stable sort so equal scores keep their original relative order.
    order = np.argsort(-candidate_scores, axis=0, kind='mergesort')
    # ``order`` indexes the filtered array; translate it back to positions in
    # ``boxes`` so the IoU lookups and the returned indices refer to the
    # right boxes even when some scores were filtered out.
    sorted_indices = candidate_indices[order]
    if -1 < top_k < sorted_indices.shape[0]:
        sorted_indices = sorted_indices[:top_k]

    selected_indices = []
    adaptive_threshold = nms_threshold
    for idx in sorted_indices:
        # all() stops at the first kept box that overlaps too much.
        keep = all(
            iou(boxes[idx], boxes[kept_idx], normalized) <= adaptive_threshold
            for kept_idx in selected_indices)
        if not keep:
            continue
        selected_indices.append(idx)
        if eta < 1 and adaptive_threshold > 0.5:
            adaptive_threshold *= eta
    return selected_indices


def multiclass_nms(boxes, scores, background, score_threshold, nms_threshold,
                   nms_top_k, keep_top_k, normalized, shared):
    """Run per-class NMS for one sample and optionally keep the best results.

    Args:
        boxes: Box coordinates. With ``shared`` the same ``boxes`` are passed
            to every class; otherwise class ``c`` uses ``boxes[:, c, :]``.
        scores: Scores indexed ``[class][box]`` when ``shared`` is true and
            ``[box][class]`` otherwise.
        background: Class label that is skipped.
        score_threshold: Forwarded to ``nms``.
        nms_threshold: Forwarded to ``nms``.
        nms_top_k: Forwarded to ``nms`` as ``top_k``.
        keep_top_k: When greater than -1 and smaller than the number of
            detections, only the ``keep_top_k`` highest-scoring detections
            over all classes are kept.
        normalized: Forwarded to ``nms``.
        shared: Selects the score/box layout described above.

    Returns:
        ``(selected, num_det)`` where ``selected`` maps a class label to the
        list of kept box indices and ``num_det`` is the number of detections.
    """
    class_num = scores.shape[0 if shared else 1]

    per_class = {}
    total = 0
    for label in range(class_num):
        if label == background:
            continue
        if shared:
            class_boxes, class_scores = boxes, scores[label]
        else:
            class_boxes, class_scores = boxes[:, label, :], scores[:, label]
        kept = nms(class_boxes, class_scores, score_threshold, nms_threshold,
                   nms_top_k, normalized)
        per_class[label] = kept
        total += len(kept)

    # Nothing to trim: the limit is disabled or not exceeded.
    if keep_top_k <= -1 or total <= keep_top_k:
        return per_class, total

    # Rank every detection of every class by its score.
    ranked = []
    for label, kept in per_class.items():
        for idx in kept:
            score = scores[label][idx] if shared else scores[idx][label]
            ranked.append((score, label, idx))
    ranked.sort(key=lambda entry: entry[0], reverse=True)

    trimmed = {}
    for _, label, idx in ranked[:keep_top_k]:
        trimmed.setdefault(label, []).append(idx)

    if not shared:
        # In the non-shared layout each class lists its indices ascending.
        for kept in trimmed.values():
            kept.sort()

    return trimmed, keep_top_k


def lod_multiclass_nms(boxes, scores, background, score_threshold,
                       nms_threshold, nms_top_k, keep_top_k, box_lod,
                       normalized):
    """Reference multiclass NMS for variable-length (LoD) input.

    ``box_lod[0]`` lists how many consecutive rows of ``boxes``/``scores``
    belong to each sample. Every sample is processed with
    ``multiclass_nms(..., shared=False)``.

    Args:
        boxes: Boxes of all samples, indexed ``[row, class, coordinate]``.
        scores: Scores of all samples, indexed ``[row][class]``.
        background: Class label to skip.
        score_threshold: Forwarded to ``multiclass_nms``.
        nms_threshold: Forwarded to ``multiclass_nms``.
        nms_top_k: Forwarded to ``multiclass_nms``.
        keep_top_k: Forwarded to ``multiclass_nms``.
        box_lod: Nested list whose first entry holds the per-sample lengths.
        normalized: Forwarded to ``multiclass_nms``.

    Returns:
        ``(det_outs, lod)``: ``det_outs`` is a list of
        ``[label, score, xmin, ymin, xmax, ymax]`` rows, ordered by label
        within each sample; ``lod`` holds the detection count of every sample
        that produced detections, or ``[1]`` if none did.
    """
    detections = []
    counts = []
    start = 0
    for length in box_lod[0]:
        stop = start + length
        sample_boxes = boxes[start:stop]
        sample_scores = scores[start:stop]
        start = stop

        kept, kept_num = multiclass_nms(
            sample_boxes,
            sample_scores,
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
            normalized,
            shared=False)
        # Samples without detections contribute nothing, not even a count.
        if kept_num == 0:
            continue
        counts.append(kept_num)

        rows = []
        for label, indices in kept.items():
            for idx in indices:
                xmin, ymin, xmax, ymax = sample_boxes[idx, label, :]
                rows.append(
                    [label, sample_scores[idx][label], xmin, ymin, xmax, ymax])
        # Stable sort: rows of one label keep their relative order.
        rows.sort(key=lambda row: row[0])
        detections.extend(rows)

    if not counts:
        counts.append(1)

    return detections, counts


def batched_multiclass_nms(boxes,
                           scores,
                           background,
                           score_threshold,
                           nms_threshold,
                           nms_top_k,
                           keep_top_k,
                           normalized=True):
    """Reference multiclass NMS for batched input.

    Every sample ``n`` is processed with
    ``multiclass_nms(boxes[n], scores[n], ..., shared=True)``.

    Args:
        boxes: Boxes indexed ``[sample][box]``, four coordinates each.
        scores: Scores indexed ``[sample][class][box]``; ``scores.shape[0]``
            is the batch size.
        background: Class label to skip.
        score_threshold: Forwarded to ``multiclass_nms``.
        nms_threshold: Forwarded to ``multiclass_nms``.
        nms_top_k: Forwarded to ``multiclass_nms``.
        keep_top_k: Forwarded to ``multiclass_nms``.
        normalized: Forwarded to ``multiclass_nms``.

    Returns:
        ``(det_outs, lod)``: ``det_outs`` is a list of
        ``[label, score, xmin, ymin, xmax, ymax]`` rows, ordered by label
        within each sample; ``lod`` holds the detection count of every sample
        that produced detections, or ``[1]`` if none did.
    """
    detections = []
    counts = []
    for n in range(scores.shape[0]):
        sample_boxes = boxes[n]
        sample_scores = scores[n]
        kept, kept_num = multiclass_nms(
            sample_boxes,
            sample_scores,
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
            normalized,
            shared=True)
        # Samples without detections contribute nothing, not even a count.
        if kept_num == 0:
            continue
        counts.append(kept_num)

        rows = []
        for label, indices in kept.items():
            for idx in indices:
                xmin, ymin, xmax, ymax = sample_boxes[idx][:]
                rows.append(
                    [label, sample_scores[label][idx], xmin, ymin, xmax, ymax])
        # Stable sort: rows of one label keep their relative order.
        rows.sort(key=lambda row: row[0])
        detections.extend(rows)

    if not counts:
        counts += [1]
    return detections, counts


class TestMulticlassNMSOp(OpTest):
    """Test case for the ``multiclass_nms`` op with batched input.

    The expected output is produced by the Python reference
    ``batched_multiclass_nms`` on random scores and boxes.
    """

    def set_argument(self):
        """Set the score threshold; subclasses override this hook."""
        self.score_threshold = 0.01

    def setUp(self):
        self.set_argument()
        batch_size, num_boxes, num_classes, box_dim = 7, 1200, 21, 4
        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold

        def softmax(x):
            shifted = x - np.max(x).clip(-64.)
            exps = np.exp(shifted)
            return exps / np.sum(exps)

        # Per-box class probabilities, then rearranged to
        # (batch, class, box), the layout indexed by the reference code.
        scores = np.random.random(
            (batch_size * num_boxes, num_classes)).astype('float32')
        scores = np.apply_along_axis(softmax, 1, scores)
        scores = np.reshape(scores, (batch_size, num_boxes, num_classes))
        scores = np.transpose(scores, (0, 2, 1))

        # First two coordinates fall in [0, 0.5), last two in [0.5, 1).
        boxes = np.random.random(
            (batch_size, num_boxes, box_dim)).astype('float32')
        boxes[:, :, 0:2] = boxes[:, :, 0:2] * 0.5
        boxes[:, :, 2:4] = boxes[:, :, 2:4] * 0.5 + 0.5

        nmsed_outs, lod = batched_multiclass_nms(
            boxes, scores, background, score_threshold, nms_threshold,
            nms_top_k, keep_top_k)
        # With no detections the expected output is the single value -1.
        if not nmsed_outs:
            nmsed_outs = [-1]
        nmsed_outs = np.array(nmsed_outs).astype('float32')

        self.op_type = 'multiclass_nms'
        self.inputs = {'BBoxes': boxes, 'Scores': scores}
        self.outputs = {'Out': (nmsed_outs, [lod])}
        self.attrs = {
            'background_label': background,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
            'normalized': True,
        }

    def test_check_output(self):
        # Delegates to OpTest; presumably compares the op's result with
        # self.outputs.
        self.check_output()


class TestMulticlassNMSOpNoOutput(TestMulticlassNMSOp):
    """Batched test variant covering the case where there are no outputs."""

    def set_argument(self):
        """Pick a score threshold that exercises the no-output case.

        2.0 is deliberately outside the practical range
        (0.0 < score_threshold < 1.0).
        """
        self.score_threshold = 2.0


class TestMulticlassNMSLoDInput(OpTest):
    """Test case for the ``multiclass_nms`` op with LoD input.

    The expected output is produced by the Python reference
    ``lod_multiclass_nms`` on random, un-normalized boxes.
    """

    def set_argument(self):
        """Set the score threshold; subclasses may override this hook."""
        self.score_threshold = 0.01

    def setUp(self):
        self.set_argument()
        num_boxes, num_classes, box_dim = 1200, 21, 4
        box_lod = [[1200]]
        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold
        normalized = False

        def softmax(x):
            shifted = x - np.max(x).clip(-64.)
            exps = np.exp(shifted)
            return exps / np.sum(exps)

        # One probability row per box.
        scores = np.random.random((num_boxes, num_classes)).astype('float32')
        scores = np.apply_along_axis(softmax, 1, scores)

        # One box per (row, class); first two coordinates in [0, 10),
        # last two in [10, 20).
        boxes = np.random.random(
            (num_boxes, num_classes, box_dim)).astype('float32')
        boxes[:, :, 0:2] = boxes[:, :, 0:2] * 10
        boxes[:, :, 2:4] = boxes[:, :, 2:4] * 10 + 10

        nmsed_outs, lod = lod_multiclass_nms(
            boxes, scores, background, score_threshold, nms_threshold,
            nms_top_k, keep_top_k, box_lod, normalized)
        # With no detections the expected output is the single value -1.
        if not nmsed_outs:
            nmsed_outs = [-1]
        nmsed_outs = np.array(nmsed_outs).astype('float32')

        self.op_type = 'multiclass_nms'
        self.inputs = {
            'BBoxes': (boxes, box_lod),
            'Scores': (scores, box_lod),
        }
        self.outputs = {'Out': (nmsed_outs, [lod])}
        self.attrs = {
            'background_label': background,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
            'normalized': normalized,
        }

    def test_check_output(self):
        # Delegates to OpTest; presumably compares the op's result with
        # self.outputs.
        self.check_output()


class TestIOU(unittest.TestCase):
    """Sanity check for the reference ``iou`` helper."""

    def test_iou(self):
        # The boxes overlap in a 2 x 1 rectangle; their areas are 6 and 12,
        # so the union is 16.
        first = np.array([4.0, 3.0, 7.0, 5.0], dtype='float32')
        second = np.array([3.0, 4.0, 6.0, 8.0], dtype='float32')

        expected = np.array([2.0 / 16.0], dtype='float32')
        actual = np.array([iou(first, second, True)], dtype='float32')
        self.assertTrue(np.allclose(actual, expected))


# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()