test_multiclass_nms_op.py 11.8 KB
Newer Older
1
#  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
14 15

from __future__ import print_function
16 17 18
import unittest
import numpy as np
import copy
19
from op_test import OpTest
20 21


J
jerrywgz 已提交
22
def iou(box_a, box_b, norm):
    """Return the intersection-over-union overlap of ``box_a`` and ``box_b``.

    Each box is indexed as ``[x1, y1, x2, y2]``. The two corners may be given
    in either order, because the min/max along each axis is taken first.

    Args:
        box_a: First box (indexable, at least 4 coordinates).
        box_b: Second box, same layout as ``box_a``.
        norm: If truthy, coordinates are treated as normalized and a side
            length is ``max - min``. If falsy, they are treated as pixel
            indices and each side length is ``max - min + 1``.

    Returns:
        The IoU ratio, or 0.0 when both boxes have non-positive area.
    """
    # A bool (not 1.0) so adding it does not promote float32 coordinates.
    pad = not norm

    xmin_a, xmax_a = min(box_a[0], box_a[2]), max(box_a[0], box_a[2])
    ymin_a, ymax_a = min(box_a[1], box_a[3]), max(box_a[1], box_a[3])
    xmin_b, xmax_b = min(box_b[0], box_b[2]), max(box_b[0], box_b[2])
    ymin_b, ymax_b = min(box_b[1], box_b[3]), max(box_b[1], box_b[3])

    area_a = (ymax_a - ymin_a + pad) * (xmax_a - xmin_a + pad)
    area_b = (ymax_b - ymin_b + pad) * (xmax_b - xmin_b + pad)
    # Two degenerate boxes: avoid a 0 / 0 union below.
    if area_a <= 0 and area_b <= 0:
        return 0.0

    # Corners of the intersection rectangle (may be empty/inverted).
    xa = max(xmin_a, xmin_b)
    ya = max(ymin_a, ymin_b)
    xb = min(xmax_a, xmax_b)
    yb = min(ymax_a, ymax_b)

    inter_area = max(xb - xa + pad, 0.0) * max(yb - ya + pad, 0.0)

    return inter_area / (area_a + area_b - inter_area)


J
jerrywgz 已提交
55 56 57 58 59 60 61
def nms(boxes,
        scores,
        score_threshold,
        nms_threshold,
        top_k=200,
        normalized=True,
        eta=1.0):
    """Apply greedy non-maximum suppression to the boxes of one class.

    Args:
        boxes: Indexable of boxes; ``boxes[i]`` is passed to ``iou``.
        scores: Per-box confidence scores (array-like); flattened before use.
        score_threshold: Boxes whose score is not strictly greater than this
            are discarded before suppression.
        nms_threshold: A candidate is suppressed when its IoU with an
            already-kept box exceeds this value.
        top_k: Only the ``top_k`` highest-scoring candidates are considered;
            ``-1`` (or any value <= -1) disables the limit.
        normalized: Forwarded to ``iou`` as its ``norm`` flag.
        eta: Adaptive-NMS factor. When ``eta < 1``, the IoU threshold is
            multiplied by ``eta`` after each kept box while it is above 0.5.

    Returns:
        List of kept indices into ``boxes``, in descending score order.
    """
    # flatten() always returns a copy, so the caller's scores are untouched.
    flat_scores = np.asarray(scores).flatten()
    candidates = np.flatnonzero(flat_scores > score_threshold)
    candidate_scores = flat_scores[candidates]

    # mergesort is stable: equal scores keep their original index order.
    order = np.argsort(-candidate_scores, axis=0, kind='mergesort')
    sorted_indices = candidates[order]
    if -1 < top_k < sorted_indices.shape[0]:
        sorted_indices = sorted_indices[:top_k]

    selected_indices = []
    adaptive_threshold = nms_threshold
    for idx in sorted_indices:
        # all() stops at the first kept box that overlaps too much.
        keep = all(
            iou(boxes[idx], boxes[kept_idx], normalized) <= adaptive_threshold
            for kept_idx in selected_indices)
        if keep:
            selected_indices.append(idx)
            if eta < 1 and adaptive_threshold > 0.5:
                adaptive_threshold *= eta
    return selected_indices


def multiclass_nms(boxes, scores, background, score_threshold, nms_threshold,
                   nms_top_k, keep_top_k, normalized, shared):
    """Run per-class NMS on one image, then apply the cross-class keep_top_k.

    Args:
        boxes: If ``shared``, a single box set used for every class
            (``boxes[i]`` is box ``i``). Otherwise ``boxes[:, c, :]`` holds
            the boxes of class ``c``.
        scores: If ``shared``, indexed as ``scores[c][i]`` (class-major).
            Otherwise indexed as ``scores[i][c]`` (box-major).
        background: Class label that is skipped entirely.
        score_threshold: Forwarded to ``nms`` for each class.
        nms_threshold: Forwarded to ``nms`` for each class.
        nms_top_k: Forwarded to ``nms`` as ``top_k`` for each class.
        keep_top_k: Maximum detections kept across all classes; ``-1``
            disables the limit.
        normalized: Forwarded to ``nms`` for each class.
        shared: Selects which of the layouts above is used.

    Returns:
        ``(selected_indices, num_det)``: a dict mapping class label to its
        list of kept box indices, and the total number of kept detections.
    """
    class_num = scores.shape[0] if shared else scores.shape[1]

    selected_indices = {}
    num_det = 0
    for c in range(class_num):
        if c == background:
            continue
        if shared:
            indices = nms(boxes, scores[c], score_threshold, nms_threshold,
                          nms_top_k, normalized)
        else:
            indices = nms(boxes[:, c, :], scores[:, c], score_threshold,
                          nms_threshold, nms_top_k, normalized)
        selected_indices[c] = indices
        num_det += len(indices)

    if keep_top_k > -1 and num_det > keep_top_k:
        score_index = []
        for c, indices in selected_indices.items():
            for idx in indices:
                score = scores[c][idx] if shared else scores[idx][c]
                score_index.append((score, c, idx))

        # Python's sort is stable (also with reverse=True), so ties keep
        # their class/NMS order.
        score_index.sort(key=lambda tup: tup[0], reverse=True)

        # Regroup the surviving detections by class label.
        selected_indices = {}
        for _, c, idx in score_index[:keep_top_k]:
            selected_indices.setdefault(c, []).append(idx)
        if not shared:
            # In the box-major layout each class reports ascending indices.
            for indices in selected_indices.values():
                indices.sort()
        num_det = keep_top_k

    return selected_indices, num_det
154 155


J
jerrywgz 已提交
156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178
def lod_multiclass_nms(boxes, scores, background, score_threshold,
                       nms_threshold, nms_top_k, keep_top_k, box_lod,
                       normalized):
    """Reference multiclass NMS for LoD (variable-length, box-major) inputs.

    Images are consecutive row ranges of ``boxes``/``scores`` whose lengths
    are listed in ``box_lod[0]``. Each image is processed by
    ``multiclass_nms`` with ``shared=False``.

    Returns:
        ``(det_outs, lod)``:

        * ``det_outs`` is a flat list of
          ``[label, score, xmin, ymin, xmax, ymax]`` rows, grouped by image
          and ordered by label within an image.
        * ``lod`` holds the number of detections of every image, including
          0 for an image with none. When no image has any detection, ``lod``
          is ``[1]``, describing a single placeholder row.
    """
    det_outs = []
    lod = []
    head = 0
    for num_boxes in box_lod[0]:
        box = boxes[head:head + num_boxes]
        score = scores[head:head + num_boxes]
        head += num_boxes
        nmsed_outs, nmsed_num = multiclass_nms(
            box,
            score,
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
            normalized,
            shared=False)
        # Record every image, even an empty one, so lod stays aligned with
        # the input batch.
        lod.append(nmsed_num)
        if nmsed_num == 0:
            continue

        tmp_det_out = []
        for c, indices in nmsed_outs.items():
            for idx in indices:
                xmin, ymin, xmax, ymax = box[idx, c, :]
                tmp_det_out.append([c, score[idx][c], xmin, ymin, xmax, ymax])
        # Stable sort by label keeps the per-class index order.
        tmp_det_out.sort(key=lambda row: row[0])
        det_outs.extend(tmp_det_out)

    if not det_outs:
        lod = [1]

    return det_outs, lod


def batched_multiclass_nms(boxes,
                           scores,
                           background,
                           score_threshold,
                           nms_threshold,
                           nms_top_k,
                           keep_top_k,
                           normalized=True):
    """Reference multiclass NMS for a dense batch with class-shared boxes.

    ``boxes[n]`` is the box set of image ``n``, shared by all classes.
    ``scores[n][c][i]`` is the score of box ``i`` for class ``c``. Each image
    is processed by ``multiclass_nms`` with ``shared=True``.

    Returns:
        ``(det_outs, lod)``:

        * ``det_outs`` is a flat list of
          ``[label, score, xmin, ymin, xmax, ymax]`` rows, grouped by image
          and ordered by label within an image.
        * ``lod`` holds the number of detections of every image, including
          0 for an image with none. When no image has any detection, ``lod``
          is ``[1]``, describing a single placeholder row.
    """
    det_outs = []
    lod = []
    for n in range(scores.shape[0]):
        nmsed_outs, nmsed_num = multiclass_nms(
            boxes[n],
            scores[n],
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
            normalized,
            shared=True)
        # Record every image, even an empty one, so lod stays aligned with
        # the batch.
        lod.append(nmsed_num)
        if nmsed_num == 0:
            continue

        tmp_det_out = []
        for c, indices in nmsed_outs.items():
            for idx in indices:
                xmin, ymin, xmax, ymax = boxes[n][idx][:]
                tmp_det_out.append(
                    [c, scores[n][c][idx], xmin, ymin, xmax, ymax])
        # Stable sort by label keeps the per-class index order.
        tmp_det_out.sort(key=lambda row: row[0])
        det_outs.extend(tmp_det_out)

    if not det_outs:
        lod = [1]

    return det_outs, lod


class TestMulticlassNMSOp(OpTest):
    """Checks the ``multiclass_nms`` op on a dense (N, C, M) score batch.

    The expected output comes from ``batched_multiclass_nms``. Subclasses
    override ``set_argument`` to change ``score_threshold``.
    """

    def set_argument(self):
        self.score_threshold = 0.01

    def setUp(self):
        self.set_argument()
        N = 7  # batch size
        M = 1200  # boxes per image
        C = 21  # number of classes, including the background label
        BOX_SIZE = 4
        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold

        scores = np.random.random((N * M, C)).astype('float32')

        def softmax(x):
            shiftx = x - np.max(x).clip(-64.)
            exps = np.exp(shiftx)
            return exps / np.sum(exps)

        # Softmax over classes for every box, then lay scores out as
        # (N, C, M).
        scores = np.apply_along_axis(softmax, 1, scores)
        scores = np.reshape(scores, (N, M, C))
        scores = np.transpose(scores, (0, 2, 1))

        # First two coordinates in [0, 0.5), last two in [0.5, 1).
        boxes = np.random.random((N, M, BOX_SIZE)).astype('float32')
        boxes[:, :, 0:2] = boxes[:, :, 0:2] * 0.5
        boxes[:, :, 2:4] = boxes[:, :, 2:4] * 0.5 + 0.5

        nmsed_outs, lod = batched_multiclass_nms(boxes, scores, background,
                                                 score_threshold, nms_threshold,
                                                 nms_top_k, keep_top_k)
        # An empty result is represented by a single -1 placeholder.
        nmsed_outs = [-1] if not nmsed_outs else nmsed_outs
        nmsed_outs = np.array(nmsed_outs).astype('float32')

        self.op_type = 'multiclass_nms'
        self.inputs = {'BBoxes': boxes, 'Scores': scores}
        self.outputs = {'Out': (nmsed_outs, [lod])}
        self.attrs = {
            # Must match the background used for the reference output above.
            'background_label': background,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
            'normalized': True,
        }

    def test_check_output(self):
        self.check_output()


288 289 290
class TestMulticlassNMSOpNoOutput(TestMulticlassNMSOp):
    """Same setup as the parent test, but no box survives the score filter."""

    def set_argument(self):
        # Softmax scores can never exceed 1.0, so a threshold of 2.0 removes
        # every box and exercises the empty-output path.
        # Real models use a threshold strictly between 0.0 and 1.0.
        self.score_threshold = 2.0


J
jerrywgz 已提交
295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351
class TestMulticlassNMSLoDInput(OpTest):
    """Checks the ``multiclass_nms`` op on LoD inputs with per-class boxes.

    The expected output comes from ``lod_multiclass_nms`` with unnormalized
    (pixel) coordinates.
    """

    def set_argument(self):
        self.score_threshold = 0.01

    def setUp(self):
        self.set_argument()
        M = 1200  # total number of boxes
        C = 21  # number of classes, including the background label
        BOX_SIZE = 4
        # One sequence holding all M boxes; kept in sync with M.
        box_lod = [[M]]
        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold
        normalized = False

        scores = np.random.random((M, C)).astype('float32')

        def softmax(x):
            shiftx = x - np.max(x).clip(-64.)
            exps = np.exp(shiftx)
            return exps / np.sum(exps)

        # Softmax over classes for every box.
        scores = np.apply_along_axis(softmax, 1, scores)

        # First two coordinates in [0, 10), last two in [10, 20).
        boxes = np.random.random((M, C, BOX_SIZE)).astype('float32')
        boxes[:, :, 0] = boxes[:, :, 0] * 10
        boxes[:, :, 1] = boxes[:, :, 1] * 10
        boxes[:, :, 2] = boxes[:, :, 2] * 10 + 10
        boxes[:, :, 3] = boxes[:, :, 3] * 10 + 10

        nmsed_outs, lod = lod_multiclass_nms(
            boxes, scores, background, score_threshold, nms_threshold,
            nms_top_k, keep_top_k, box_lod, normalized)
        # An empty result is represented by a single -1 placeholder.
        nmsed_outs = [-1] if not nmsed_outs else nmsed_outs
        nmsed_outs = np.array(nmsed_outs).astype('float32')
        self.op_type = 'multiclass_nms'
        self.inputs = {
            'BBoxes': (boxes, box_lod),
            'Scores': (scores, box_lod),
        }
        self.outputs = {'Out': (nmsed_outs, [lod])}
        self.attrs = {
            # Must match the background used for the reference output above.
            'background_label': background,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
            'normalized': normalized,
        }

    def test_check_output(self):
        self.check_output()


352 353 354 355 356 357
class TestIOU(unittest.TestCase):
    """Unit tests for the pure-Python ``iou`` reference helper."""

    def test_iou(self):
        box1 = np.array([4.0, 3.0, 7.0, 5.0]).astype('float32')
        box2 = np.array([3.0, 4.0, 6.0, 8.0]).astype('float32')

        # Normalized: areas 6 and 12, intersection 2 -> 2 / (6 + 12 - 2).
        expt_output = np.array([2.0 / 16.0]).astype('float32')
        calc_output = np.array([iou(box1, box2, True)]).astype('float32')
        self.assertTrue(np.allclose(calc_output, expt_output))

    def test_iou_unnormalized(self):
        box1 = np.array([4.0, 3.0, 7.0, 5.0]).astype('float32')
        box2 = np.array([3.0, 4.0, 6.0, 8.0]).astype('float32')

        # Pixel coordinates add 1 to every side length: areas 12 and 20,
        # intersection 6 -> 6 / (12 + 20 - 6).
        expt_output = np.array([6.0 / 26.0]).astype('float32')
        calc_output = np.array([iou(box1, box2, False)]).astype('float32')
        self.assertTrue(np.allclose(calc_output, expt_output))


if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()