test_multiclass_nms_op.py 11.7 KB
Newer Older
1
#  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
14 15

from __future__ import print_function
16 17 18
import unittest
import numpy as np
import copy
19
from op_test import OpTest
20 21


J
jerrywgz 已提交
22
def iou(box_a, box_b, norm):
    """Return the intersection-over-union of two boxes.

    Each box is given as (x1, y1, x2, y2). The corners may appear in either
    order, because min/max is taken per axis. When ``norm`` compares equal to
    False, the coordinates are treated as inclusive pixel indices, so 1 is
    added to every width and height. If both boxes have non-positive area,
    the result is 0.0.
    """
    def corners(box):
        # (xmin, ymin, xmax, ymax) regardless of how the corners were ordered.
        return (min(box[0], box[2]), min(box[1], box[3]),
                max(box[0], box[2]), max(box[1], box[3]))

    # Deliberately ``== False`` rather than ``not norm``: this keeps the
    # exact original semantics for non-bool values such as None.
    pad = (norm == False)  # noqa: E712

    ax0, ay0, ax1, ay1 = corners(box_a)
    bx0, by0, bx1, by1 = corners(box_b)

    area_a = (ay1 - ay0 + pad) * (ax1 - ax0 + pad)
    area_b = (by1 - by0 + pad) * (bx1 - bx0 + pad)
    if area_a <= 0 and area_b <= 0:
        return 0.0

    inter_w = max(min(ax1, bx1) - max(ax0, bx0) + pad, 0.0)
    inter_h = max(min(ay1, by1) - max(ay0, by0) + pad, 0.0)
    inter_area = inter_w * inter_h
    return inter_area / (area_a + area_b - inter_area)


J
jerrywgz 已提交
55 56 57 58 59 60 61
def nms(boxes,
        scores,
        score_threshold,
        nms_threshold,
        top_k=200,
        normalized=True,
        eta=1.0):
    """Apply greedy non-maximum suppression to the boxes of one class.

    Args:
        boxes: (ndarray) Box coordinates; ``boxes[i]`` is the box of prior
            ``i``, Shape: [num_priors, 4].
        scores: (ndarray) Class scores per prior, Shape: [num_priors]. The
            array is flattened before use.
        score_threshold: (float) Only boxes with a score strictly greater
            than this are considered.
        nms_threshold: (float) A candidate is suppressed when its IoU with an
            already kept box exceeds this value.
        top_k: (int) Consider at most this many highest-scoring candidates.
            Any value <= -1 disables the cap.
        normalized: (bool) Forwarded to ``iou``. When False, widths and
            heights get a +1 pixel offset.
        eta: (float) Adaptive NMS factor. When ``eta < 1``, the threshold is
            multiplied by ``eta`` after each kept box, as long as it is still
            above 0.5.

    Return:
        List of indices (into ``scores``/``boxes``) of the kept boxes, in
        descending score order.
    """
    flat_scores = np.asarray(scores).flatten()
    # Indices, in the ORIGINAL array, of candidates above the threshold.
    candidate_indices = np.argwhere(flat_scores > score_threshold).flatten()
    candidate_scores = flat_scores[candidate_indices]

    # mergesort is stable, so equal scores keep ascending index order.
    order = np.argsort(-candidate_scores, axis=0, kind='mergesort')
    # ``order`` indexes the filtered array. Map it back to original indices,
    # otherwise the wrong boxes are compared and returned whenever
    # score_threshold filters anything out.
    sorted_indices = candidate_indices[order]
    if -1 < top_k < sorted_indices.shape[0]:
        sorted_indices = sorted_indices[:top_k]

    selected_indices = []
    adaptive_threshold = nms_threshold
    for idx in sorted_indices:
        # all() short-circuits on the first kept box that overlaps too much.
        keep = all(
            iou(boxes[idx], boxes[kept_idx], normalized) <= adaptive_threshold
            for kept_idx in selected_indices)
        if keep:
            selected_indices.append(idx)
            if eta < 1 and adaptive_threshold > 0.5:
                adaptive_threshold *= eta
    return selected_indices


def multiclass_nms(boxes, scores, background, score_threshold, nms_threshold,
                   nms_top_k, keep_top_k, normalized, shared):
    """Reference multi-class NMS for a single image.

    ``nms`` runs independently for every class except ``background``. If the
    total number of detections then exceeds ``keep_top_k`` (and
    ``keep_top_k > -1``), only the ``keep_top_k`` highest-scoring detections
    across all classes are kept.

    Args:
        boxes: Box coordinates. When ``shared`` is True, all classes use the
            same boxes and the whole array is passed to ``nms``. Otherwise
            the boxes are per class: ``boxes[:, c, :]`` are those of class
            ``c``.
        scores: Class scores, indexed ``scores[c][i]`` when ``shared`` and
            ``scores[i][c]`` otherwise.
        background: Class label to skip.
        score_threshold, nms_threshold, nms_top_k, normalized: Forwarded to
            ``nms``.
        keep_top_k: Cap on the total number of detections; any value <= -1
            disables it.
        shared: Selects between the two layouts described above.

    Returns:
        ``(selected_indices, num_det)``. ``selected_indices`` maps a class
        label to its list of kept box indices. ``num_det`` is the total
        number of kept detections.
    """
    class_num = scores.shape[0] if shared else scores.shape[1]

    def score_of(c, idx):
        # Scores are class-major when boxes are shared, box-major otherwise.
        return scores[c][idx] if shared else scores[idx][c]

    selected_indices = {}
    num_det = 0
    for c in range(class_num):
        if c == background:
            continue
        if shared:
            indices = nms(boxes, scores[c], score_threshold, nms_threshold,
                          nms_top_k, normalized)
        else:
            indices = nms(boxes[:, c, :], scores[:, c], score_threshold,
                          nms_threshold, nms_top_k, normalized)
        selected_indices[c] = indices
        num_det += len(indices)

    if keep_top_k > -1 and num_det > keep_top_k:
        score_index = [(score_of(c, idx), c, idx)
                       for c, indices in selected_indices.items()
                       for idx in indices]
        # sorted() stays stable with reverse=True, so ties keep their order.
        sorted_score_index = sorted(
            score_index, key=lambda tup: tup[0], reverse=True)[:keep_top_k]

        # Regroup by class; keys are inserted in first-appearance order.
        selected_indices = {}
        for _, c, idx in sorted_score_index:
            selected_indices.setdefault(c, []).append(idx)
        if not shared:
            for indices in selected_indices.values():
                indices.sort()
        # num_det > keep_top_k, so exactly keep_top_k detections remain.
        num_det = keep_top_k

    return selected_indices, num_det
153 154


J
jerrywgz 已提交
155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175
def lod_multiclass_nms(boxes, scores, background, score_threshold,
                       nms_threshold, nms_top_k, keep_top_k, box_lod,
                       normalized):
    """Reference multiclass NMS for LoD inputs with per-class boxes.

    ``box_lod[0]`` lists how many consecutive rows of ``boxes``/``scores``
    belong to each image. Rows are indexed ``boxes[i, c, :]`` and
    ``scores[i][c]``. Images without detections contribute no lod entry.

    Returns:
        ``(det_outs, lod)``. ``det_outs`` is a flat list of
        ``[label, score, xmin, ymin, xmax, ymax]`` rows. ``lod`` holds the
        detection count of every image that had detections, or ``[1]`` if
        none did.
    """
    det_outs = []
    lod = []
    start = 0
    for length in box_lod[0]:
        end = start + length
        image_boxes = boxes[start:end]
        image_scores = scores[start:end]
        start = end

        kept, kept_num = multiclass_nms(
            image_boxes,
            image_scores,
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
            normalized,
            shared=False)
        if kept_num == 0:
            continue

        lod.append(kept_num)
        for label, indices in kept.items():
            for idx in indices:
                xmin, ymin, xmax, ymax = image_boxes[idx, label, :]
                det_outs.append(
                    [label, image_scores[idx][label], xmin, ymin, xmax, ymax])

    if not lod:
        lod.append(1)

    return det_outs, lod


def batched_multiclass_nms(boxes,
                           scores,
                           background,
                           score_threshold,
                           nms_threshold,
                           nms_top_k,
                           keep_top_k,
                           normalized=True):
    """Reference multiclass NMS over a batch where classes share boxes.

    ``scores[n]`` is indexed ``[class][box]`` and ``boxes[n][box]`` gives
    the four coordinates of a box of image ``n``. Images without detections
    contribute no lod entry.

    Returns:
        ``(det_outs, lod)``. ``det_outs`` is a flat list of
        ``[label, score, xmin, ymin, xmax, ymax]`` rows, grouped by image
        and stably ordered by label within each image. ``lod`` holds the
        detection count of every image that had detections, or ``[1]`` if
        none did.
    """
    det_outs = []
    lod = []
    for n in range(scores.shape[0]):
        image_boxes = boxes[n]
        image_scores = scores[n]
        kept, kept_num = multiclass_nms(
            image_boxes,
            image_scores,
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
            normalized,
            shared=True)
        if kept_num == 0:
            continue

        lod.append(kept_num)
        image_dets = []
        for label, indices in kept.items():
            for idx in indices:
                xmin, ymin, xmax, ymax = image_boxes[idx][:]
                image_dets.append(
                    [label, image_scores[label][idx], xmin, ymin, xmax, ymax])
        # Stable sort on the label only; rows of one class keep the order
        # produced by multiclass_nms.
        image_dets.sort(key=lambda det: det[0])
        det_outs.extend(image_dets)

    if not lod:
        lod.append(1)
    return det_outs, lod


class TestMulticlassNMSOp(OpTest):
    """Check the multiclass_nms op on a batch whose classes share boxes.

    The expected output comes from the NumPy reference
    ``batched_multiclass_nms``.
    """

    def set_argument(self):
        # Hook for subclasses: sets the score threshold read by setUp.
        self.score_threshold = 0.01

    def setUp(self):
        self.set_argument()
        N = 7
        M = 1200
        C = 21
        BOX_SIZE = 4
        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold

        scores = np.random.random((N * M, C)).astype('float32')

        def softmax(x):
            # Shift by the max (clipped below at -64) before exponentiating.
            shiftx = x - np.max(x).clip(-64.)
            exps = np.exp(shiftx)
            return exps / np.sum(exps)

        scores = np.apply_along_axis(softmax, 1, scores)
        scores = np.reshape(scores, (N, M, C))
        # Scores are fed class-major: (N, C, M).
        scores = np.transpose(scores, (0, 2, 1))

        boxes = np.random.random((N, M, BOX_SIZE)).astype('float32')
        # xmin/ymin in [0, 0.5) and xmax/ymax in [0.5, 1.0): well-formed boxes.
        boxes[:, :, 0:2] = boxes[:, :, 0:2] * 0.5
        boxes[:, :, 2:4] = boxes[:, :, 2:4] * 0.5 + 0.5

        nmsed_outs, lod = batched_multiclass_nms(boxes, scores, background,
                                                 score_threshold, nms_threshold,
                                                 nms_top_k, keep_top_k)
        # With no detections, the expected output is a single -1 entry.
        nmsed_outs = [-1] if not nmsed_outs else nmsed_outs
        nmsed_outs = np.array(nmsed_outs).astype('float32')

        self.op_type = 'multiclass_nms'
        self.inputs = {'BBoxes': boxes, 'Scores': scores}
        self.outputs = {'Out': (nmsed_outs, [lod])}
        self.attrs = {
            # Use the same label that the reference implementation skipped.
            'background_label': background,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
            'normalized': True,
        }

    def test_check_output(self):
        self.check_output()


284 285 286
class TestMulticlassNMSOpNoOutput(TestMulticlassNMSOp):
    """Same setup as TestMulticlassNMSOp, but no box passes the score filter."""

    def set_argument(self):
        # The softmax scores built in setUp never exceed 1.0, so a threshold
        # of 2.0 rejects every box and exercises the empty-output path.
        # Real models use a threshold in (0.0, 1.0).
        self.score_threshold = 2.0


J
jerrywgz 已提交
291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347
class TestMulticlassNMSLoDInput(OpTest):
    """Check the multiclass_nms op on LoD inputs with per-class boxes.

    The expected output comes from the NumPy reference
    ``lod_multiclass_nms`` with unnormalized (pixel) coordinates.
    """

    def set_argument(self):
        # Hook for subclasses: sets the score threshold read by setUp.
        self.score_threshold = 0.01

    def setUp(self):
        self.set_argument()
        M = 1200
        C = 21
        BOX_SIZE = 4
        box_lod = [[1200]]
        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold
        normalized = False

        scores = np.random.random((M, C)).astype('float32')

        def softmax(x):
            # Shift by the max (clipped below at -64) before exponentiating.
            shiftx = x - np.max(x).clip(-64.)
            exps = np.exp(shiftx)
            return exps / np.sum(exps)

        scores = np.apply_along_axis(softmax, 1, scores)

        # One box per (prior, class): shape (M, C, 4).
        boxes = np.random.random((M, C, BOX_SIZE)).astype('float32')
        # xmin/ymin in [0, 10) and xmax/ymax in [10, 20): well-formed boxes.
        boxes[:, :, 0] = boxes[:, :, 0] * 10
        boxes[:, :, 1] = boxes[:, :, 1] * 10
        boxes[:, :, 2] = boxes[:, :, 2] * 10 + 10
        boxes[:, :, 3] = boxes[:, :, 3] * 10 + 10

        nmsed_outs, lod = lod_multiclass_nms(
            boxes, scores, background, score_threshold, nms_threshold,
            nms_top_k, keep_top_k, box_lod, normalized)
        # With no detections, the expected output is a single -1 entry.
        nmsed_outs = [-1] if not nmsed_outs else nmsed_outs
        nmsed_outs = np.array(nmsed_outs).astype('float32')
        self.op_type = 'multiclass_nms'
        self.inputs = {
            'BBoxes': (boxes, box_lod),
            'Scores': (scores, box_lod),
        }
        self.outputs = {'Out': (nmsed_outs, [lod])}
        self.attrs = {
            # Use the same label that the reference implementation skipped.
            'background_label': background,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
            'normalized': normalized,
        }

    def test_check_output(self):
        self.check_output()


348 349 350 351 352 353
class TestIOU(unittest.TestCase):
    """Unit tests for the NumPy reference ``iou`` helper."""

    def test_iou(self):
        # Normalized coordinates: widths/heights are plain differences.
        box1 = np.array([4.0, 3.0, 7.0, 5.0]).astype('float32')
        box2 = np.array([3.0, 4.0, 6.0, 8.0]).astype('float32')

        # areas 3*2=6 and 3*4=12, intersection 2*1=2 -> 2 / (6 + 12 - 2)
        expt_output = np.array([2.0 / 16.0]).astype('float32')
        calc_output = np.array([iou(box1, box2, True)]).astype('float32')
        self.assertTrue(np.allclose(calc_output, expt_output))

    def test_iou_unnormalized(self):
        # Unnormalized coordinates: iou adds 1 to every width and height.
        box1 = np.array([4.0, 3.0, 7.0, 5.0]).astype('float32')
        box2 = np.array([3.0, 4.0, 6.0, 8.0]).astype('float32')

        # areas 4*3=12 and 4*5=20, intersection 3*2=6 -> 6 / (12 + 20 - 6)
        expt_output = np.array([6.0 / 26.0]).astype('float32')
        calc_output = np.array([iou(box1, box2, False)]).astype('float32')
        self.assertTrue(np.allclose(calc_output, expt_output))

    def test_iou_degenerate_and_disjoint(self):
        # Two zero-area boxes return 0.0 instead of dividing by zero.
        point1 = np.array([1.0, 1.0, 1.0, 1.0]).astype('float32')
        point2 = np.array([2.0, 2.0, 2.0, 2.0]).astype('float32')
        self.assertEqual(iou(point1, point2, True), 0.0)

        # Non-overlapping boxes have zero intersection.
        box1 = np.array([0.0, 0.0, 1.0, 1.0]).astype('float32')
        box2 = np.array([2.0, 2.0, 3.0, 3.0]).astype('float32')
        self.assertEqual(iou(box1, box2, True), 0.0)


if __name__ == "__main__":
    # Run the test cases defined in this module when executed as a script.
    unittest.main()