#  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import unittest

import numpy as np
from eager_op_test import OpTest

import paddle
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid import _non_static_mode, in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper


def multiclass_nms3(
    bboxes,
    scores,
    rois_num=None,
    score_threshold=0.3,
    nms_top_k=1000,
    keep_top_k=100,
    nms_threshold=0.3,
    normalized=True,
    nms_eta=1.0,
    background_label=-1,
    return_index=True,
    return_rois_num=True,
    name=None,
):
    """Run the ``multiclass_nms3`` operator in the current Paddle mode.

    Serves as ``python_api`` for ``TestMulticlassNMS3Op``. Dispatches to
    ``_C_ops`` (eager mode), ``_legacy_C_ops`` (legacy dygraph) or appends
    the op to the static program.

    Args:
        bboxes: fed to the op's ``BBoxes`` input.
        scores: fed to the op's ``Scores`` input.
        rois_num: optional ``RoisNum`` input; in static mode it is only
            wired up when not ``None``.
        score_threshold, nms_top_k, keep_top_k, nms_threshold, normalized,
            nms_eta, background_label: forwarded unchanged as op attributes.
        return_index: if False, ``None`` is returned in place of the index.
        return_rois_num: if False, ``None`` is returned in place of the
            rois-num output (and in static mode ``NmsRoisNum`` is not
            requested from the op).
        name: not used directly; captured by ``LayerHelper`` via
            ``locals()``.

    Returns:
        ``(output, index, nms_rois_num)`` in every execution mode, matching
        the op's ``Out`` / ``Index`` / ``NmsRoisNum`` output order. (The
        static branch used to return ``(output, nms_rois_num, index)``,
        which disagreed with the dygraph branches.)
    """
    # Created first so that ``locals()`` contains only the arguments.
    helper = LayerHelper('multiclass_nms3', **locals())

    if in_dygraph_mode():
        # Eager path: attributes passed positionally, in signature order.
        attrs = (
            score_threshold,
            nms_top_k,
            keep_top_k,
            nms_threshold,
            normalized,
            nms_eta,
            background_label,
        )
        output, index, nms_rois_num = _C_ops.multiclass_nms3(
            bboxes, scores, rois_num, *attrs
        )
    elif _non_static_mode():
        # Legacy dygraph path: attributes passed as flattened name/value pairs.
        attrs = (
            'background_label',
            background_label,
            'score_threshold',
            score_threshold,
            'nms_top_k',
            nms_top_k,
            'nms_threshold',
            nms_threshold,
            'keep_top_k',
            keep_top_k,
            'nms_eta',
            nms_eta,
            'normalized',
            normalized,
        )
        output, index, nms_rois_num = _legacy_C_ops.multiclass_nms3(
            bboxes, scores, rois_num, *attrs
        )
    else:
        output = helper.create_variable_for_type_inference(dtype=bboxes.dtype)
        index = helper.create_variable_for_type_inference(dtype='int32')

        inputs = {'BBoxes': bboxes, 'Scores': scores}
        outputs = {'Out': output, 'Index': index}

        if rois_num is not None:
            inputs['RoisNum'] = rois_num

        nms_rois_num = None
        if return_rois_num:
            nms_rois_num = helper.create_variable_for_type_inference(
                dtype='int32'
            )
            outputs['NmsRoisNum'] = nms_rois_num

        helper.append_op(
            type="multiclass_nms3",
            inputs=inputs,
            attrs={
                'background_label': background_label,
                'score_threshold': score_threshold,
                'nms_top_k': nms_top_k,
                'nms_threshold': nms_threshold,
                'keep_top_k': keep_top_k,
                'nms_eta': nms_eta,
                'normalized': normalized,
            },
            outputs=outputs,
        )
        output.stop_gradient = True
        index.stop_gradient = True

    # Apply the optional-output switches uniformly for every mode.
    if not return_index:
        index = None
    if not return_rois_num:
        nms_rois_num = None
    return output, index, nms_rois_num


def softmax(x):
    """Return the softmax of ``x`` taken over all of its elements.

    Values are shifted by the maximum and the shifted values are clipped at
    -64 from below, so a later ``log`` of the result never sees an exact
    zero (which would produce ``-inf``).
    """
    shifted = np.clip(x - np.max(x), -64.0, None)
    weights = np.exp(shifted)
    total = weights.sum()
    return weights / total


def iou(box_a, box_b, norm):
    """Return the intersection-over-union of two boxes.

    Each box is indexed as ``[x1, y1, x2, y2]``; corners may come in either
    order because they are re-sorted per axis. When ``norm`` is False the
    boxes are treated as pixel-inclusive, so one is added to every extent.
    Returns 0.0 when both boxes have non-positive area.
    """
    # Boolean pad: contributes 1 to each extent only in unnormalized mode.
    pad = not norm

    def _extent(box):
        return (
            min(box[0], box[2]),
            min(box[1], box[3]),
            max(box[0], box[2]),
            max(box[1], box[3]),
        )

    ax0, ay0, ax1, ay1 = _extent(box_a)
    bx0, by0, bx1, by1 = _extent(box_b)

    area_a = (ay1 - ay0 + pad) * (ax1 - ax0 + pad)
    area_b = (by1 - by0 + pad) * (bx1 - bx0 + pad)
    if area_a <= 0 and area_b <= 0:
        return 0.0

    inter_w = max(min(ax1, bx1) - max(ax0, bx0) + pad, 0.0)
    inter_h = max(min(ay1, by1) - max(ay0, by0) + pad, 0.0)
    inter_area = inter_w * inter_h

    return inter_area / (area_a + area_b - inter_area)


def nms(
    boxes,
    scores,
    score_threshold,
    nms_threshold,
    top_k=200,
    normalized=True,
    eta=1.0,
):
    """Greedy single-class non-maximum suppression (reference version).

    Args:
        boxes: per-prior boxes; ``boxes[i]`` is passed to ``iou``.
        scores: per-prior scores (an array; it is flattened first).
        score_threshold: only priors scoring strictly above this are
            candidates.
        nms_threshold: a candidate is kept only if its IoU with every
            already-kept box is ``<=`` the current threshold.
        top_k: at most this many top-scoring candidates are examined;
            a negative value disables the cap.
        normalized: forwarded to ``iou``.
        eta: adaptive-NMS factor; after each kept box the threshold is
            multiplied by ``eta`` while ``eta < 1`` and threshold > 0.5.

    Returns:
        list of kept prior indices, highest score first.
    """
    flat_scores = scores.flatten()
    candidates = np.flatnonzero(flat_scores > score_threshold)

    # Mergesort is stable, so equal scores keep ascending-index order.
    order = np.argsort(-flat_scores[candidates], axis=0, kind='mergesort')
    ranked = candidates[order]
    if -1 < top_k < ranked.shape[0]:
        ranked = ranked[:top_k]

    kept = []
    threshold = nms_threshold
    for idx in ranked:
        # all() stops at the first kept box that overlaps too much.
        keep = all(
            iou(boxes[idx], boxes[prev], normalized) <= threshold
            for prev in kept
        )
        if not keep:
            continue
        kept.append(idx)
        if eta < 1 and threshold > 0.5:
            threshold *= eta
    return kept


def multiclass_nms(
    boxes,
    scores,
    background,
    score_threshold,
    nms_threshold,
    nms_top_k,
    keep_top_k,
    normalized,
    shared,
):
    """Reference multi-class NMS for a single image.

    Args:
        boxes: with ``shared`` one box set is used for every class
            (``boxes[i]``); otherwise boxes are per class
            (``boxes[:, c, :]``).
        scores: indexed ``scores[c][i]`` when ``shared``, else
            ``scores[i][c]``.
        background: class id that is skipped.
        score_threshold, nms_threshold, nms_top_k, normalized: forwarded
            to ``nms``.
        keep_top_k: cap on total detections across all classes; a negative
            value disables it.
        shared: selects the box/score layout described above.

    Returns:
        ``(selected_indices, num_det)``: a dict mapping class id to its list
        of kept box indices, and the total number of detections.
    """
    class_num = scores.shape[0] if shared else scores.shape[1]

    selected_indices = {}
    num_det = 0
    for c in range(class_num):
        if c == background:
            continue
        if shared:
            indices = nms(
                boxes,
                scores[c],
                score_threshold,
                nms_threshold,
                nms_top_k,
                normalized,
            )
        else:
            indices = nms(
                boxes[:, c, :],
                scores[:, c],
                score_threshold,
                nms_threshold,
                nms_top_k,
                normalized,
            )
        selected_indices[c] = indices
        num_det += len(indices)

    if keep_top_k > -1 and num_det > keep_top_k:
        score_index = [
            (scores[c][idx] if shared else scores[idx][c], c, idx)
            for c, indices in selected_indices.items()
            for idx in indices
        ]
        # sorted() is stable even with reverse=True, so equal scores keep
        # their class-then-rank order.
        sorted_score_index = sorted(
            score_index, key=lambda tup: tup[0], reverse=True
        )[:keep_top_k]

        # Regroup by class; dict order follows first appearance by score.
        selected_indices = {}
        for _, c, idx in sorted_score_index:
            selected_indices.setdefault(c, []).append(idx)
        if not shared:
            for kept in selected_indices.values():
                kept.sort()
        num_det = keep_top_k

    return selected_indices, num_det


def lod_multiclass_nms(
    boxes,
    scores,
    background,
    score_threshold,
    nms_threshold,
    nms_top_k,
    keep_top_k,
    box_lod,
    normalized,
):
    """Reference NMS for LoD (variable-length) input.

    ``box_lod[0]`` gives how many consecutive rows of ``boxes``/``scores``
    belong to each image. Each image is handed to ``multiclass_nms`` with
    the per-class (non-shared) layout.

    Returns:
        ``(det_outs, lod)``. ``det_outs`` is a flat list of rows
        ``[class, score, xmin, ymin, xmax, ymax, flat_index]``, grouped by
        image and sorted by class within an image, where ``flat_index`` is
        ``(row_offset + idx) * num_class + c``. ``lod`` holds the detection
        count per image (0 for images with no rows).
    """
    num_class = boxes.shape[1]
    det_outs = []
    lod = []
    start = 0
    for count in box_lod[0]:
        if count == 0:
            lod.append(0)
            continue
        offset = start
        end = start + count
        img_boxes = boxes[start:end]
        img_scores = scores[start:end]
        start = end

        nmsed_outs, nmsed_num = multiclass_nms(
            img_boxes,
            img_scores,
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
            normalized,
            shared=False,
        )
        lod.append(nmsed_num)
        if nmsed_num == 0:
            continue

        rows = []
        for c, indices in nmsed_outs.items():
            for idx in indices:
                xmin, ymin, xmax, ymax = img_boxes[idx, c, :]
                flat_index = offset * num_class + idx * num_class + c
                rows.append(
                    [c, img_scores[idx][c], xmin, ymin, xmax, ymax, flat_index]
                )
        det_outs.extend(sorted(rows, key=lambda row: row[0]))

    return det_outs, lod


def batched_multiclass_nms(
    boxes,
    scores,
    background,
    score_threshold,
    nms_threshold,
    nms_top_k,
    keep_top_k,
    normalized=True,
):
    """Reference NMS for a dense batch where boxes are shared by classes.

    ``scores`` is indexed ``scores[n][c][i]`` and ``boxes`` as
    ``boxes[n][i]``; each image is handed to ``multiclass_nms`` with
    ``shared=True``.

    Returns:
        ``(det_outs, lod)``. ``det_outs`` is a flat list of rows
        ``[class, score, xmin, ymin, xmax, ymax, idx + n * num_boxes]``,
        grouped by image and sorted by class within an image. ``lod``
        holds the detection count per image.
    """
    batch_size = scores.shape[0]
    num_boxes = scores.shape[2]
    det_outs = []
    lod = []
    for n in range(batch_size):
        img_boxes = boxes[n]
        img_scores = scores[n]
        nmsed_outs, nmsed_num = multiclass_nms(
            img_boxes,
            img_scores,
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
            normalized,
            shared=True,
        )
        lod.append(nmsed_num)
        if nmsed_num == 0:
            continue

        rows = []
        for c, indices in nmsed_outs.items():
            for idx in indices:
                xmin, ymin, xmax, ymax = img_boxes[idx][:]
                rows.append(
                    [
                        c,
                        img_scores[c][idx],
                        xmin,
                        ymin,
                        xmax,
                        ymax,
                        idx + n * num_boxes,
                    ]
                )
        det_outs.extend(sorted(rows, key=lambda row: row[0]))

    return det_outs, lod


class TestMulticlassNMSOp(OpTest):
    """Check ``multiclass_nms`` on a random batch with boxes shared by classes."""

    def set_argument(self):
        # Subclasses override this to change the score threshold.
        self.score_threshold = 0.01

    def setUp(self):
        self.set_argument()
        batch_size, box_count, class_count, box_dim = 7, 1200, 21, 4
        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold

        # Softmax each row of random class values, then lay out as [N, C, M].
        raw_scores = np.random.random(
            (batch_size * box_count, class_count)
        ).astype('float32')
        probs = np.apply_along_axis(softmax, 1, raw_scores)
        scores = probs.reshape((batch_size, box_count, class_count)).transpose(
            (0, 2, 1)
        )

        # (x1, y1) scaled into [0, 0.5), (x2, y2) shifted into [0.5, 1).
        boxes = np.random.random((batch_size, box_count, box_dim)).astype(
            'float32'
        )
        boxes[:, :, 0:2] = boxes[:, :, 0:2] * 0.5
        boxes[:, :, 2:4] = boxes[:, :, 2:4] * 0.5 + 0.5

        det_outs, lod = batched_multiclass_nms(
            boxes,
            scores,
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
        )
        if not det_outs:
            # Expected result when nothing survives: one -1 row with LoD [1].
            det_outs, lod = [[-1, 0]], [1]
        # Drop the trailing flat-index column; only 'Out' is checked here.
        nmsed_outs = np.array(det_outs)[:, :-1].astype('float32')

        self.op_type = 'multiclass_nms'
        self.inputs = {'BBoxes': boxes, 'Scores': scores}
        self.outputs = {'Out': (nmsed_outs, [lod])}
        self.attrs = {
            'background_label': 0,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
            'normalized': True,
        }

    def test_check_output(self):
        self.check_output()


class TestMulticlassNMSOpNoOutput(TestMulticlassNMSOp):
    """Same batch as the parent case, but no score can pass the threshold."""

    def set_argument(self):
        # Softmax scores never exceed 1, so 2.0 filters out every box and
        # exercises the empty-output path. Real thresholds lie in (0, 1).
        self.score_threshold = 2.0


class TestMulticlassNMSLoDInput(OpTest):
    """Check ``multiclass_nms`` on LoD input with per-class boxes."""

    def set_argument(self):
        # Subclasses override this to change the score threshold.
        self.score_threshold = 0.01

    def setUp(self):
        self.set_argument()
        box_count, class_count, box_dim = 1200, 21, 4
        box_lod = [[1200]]
        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold
        normalized = False

        raw_scores = np.random.random((box_count, class_count)).astype(
            'float32'
        )
        scores = np.apply_along_axis(softmax, 1, raw_scores)

        # Pixel-style coordinates: (x1, y1) scaled by 10, (x2, y2) by 10
        # and shifted up by 10.
        boxes = np.random.random((box_count, class_count, box_dim)).astype(
            'float32'
        )
        boxes[:, :, 0:2] = boxes[:, :, 0:2] * 10
        boxes[:, :, 2:4] = boxes[:, :, 2:4] * 10 + 10

        det_outs, lod = lod_multiclass_nms(
            boxes,
            scores,
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
            box_lod,
            normalized,
        )
        det_array = np.array(det_outs).astype('float32')
        if len(det_array):
            # Drop the trailing flat-index column; only 'Out' is checked here.
            nmsed_outs = det_array[:, :-1].astype('float32')
        else:
            nmsed_outs = det_array

        self.op_type = 'multiclass_nms'
        self.inputs = {
            'BBoxes': (boxes, box_lod),
            'Scores': (scores, box_lod),
        }
        self.outputs = {'Out': (nmsed_outs, [lod])}
        self.attrs = {
            'background_label': 0,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
            'normalized': normalized,
        }

    def test_check_output(self):
        self.check_output()


class TestMulticlassNMSNoBox(TestMulticlassNMSLoDInput):
    """LoD case where the first and last images contain no boxes."""

    def setUp(self):
        self.set_argument()
        box_count, class_count, box_dim = 1200, 21, 4
        box_lod = [[0, 1200, 0]]
        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold
        normalized = False

        raw_scores = np.random.random((box_count, class_count)).astype(
            'float32'
        )
        scores = np.apply_along_axis(softmax, 1, raw_scores)

        # Pixel-style coordinates: (x1, y1) scaled by 10, (x2, y2) by 10
        # and shifted up by 10.
        boxes = np.random.random((box_count, class_count, box_dim)).astype(
            'float32'
        )
        boxes[:, :, 0:2] = boxes[:, :, 0:2] * 10
        boxes[:, :, 2:4] = boxes[:, :, 2:4] * 10 + 10

        det_outs, lod = lod_multiclass_nms(
            boxes,
            scores,
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
            box_lod,
            normalized,
        )
        det_array = np.array(det_outs).astype('float32')
        if len(det_array):
            nmsed_outs = det_array[:, :-1].astype('float32')
        else:
            nmsed_outs = det_array

        self.op_type = 'multiclass_nms'
        self.inputs = {
            'BBoxes': (boxes, box_lod),
            'Scores': (scores, box_lod),
        }
        self.outputs = {'Out': (nmsed_outs, [lod])}
        self.attrs = {
            'background_label': 0,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
            'normalized': normalized,
        }


class TestIOU(unittest.TestCase):
    """Sanity check for the reference ``iou`` helper."""

    def test_iou(self):
        first = np.array([4.0, 3.0, 7.0, 5.0]).astype('float32')
        second = np.array([3.0, 4.0, 6.0, 8.0]).astype('float32')

        # Overlap is [4, 6] x [4, 5] (area 2); union is 6 + 12 - 2 = 16.
        expected = np.array([2.0 / 16.0]).astype('float32')
        actual = np.array([iou(first, second, True)]).astype('float32')
        np.testing.assert_allclose(actual, expected, rtol=1e-05)


class TestMulticlassNMS2Op(TestMulticlassNMSOp):
    """Check ``multiclass_nms2`` (adds the ``Index`` output) on a dense batch."""

    def setUp(self):
        self.set_argument()
        batch_size, box_count, class_count, box_dim = 7, 1200, 21, 4
        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold

        raw_scores = np.random.random(
            (batch_size * box_count, class_count)
        ).astype('float32')
        probs = np.apply_along_axis(softmax, 1, raw_scores)
        scores = probs.reshape((batch_size, box_count, class_count)).transpose(
            (0, 2, 1)
        )

        boxes = np.random.random((batch_size, box_count, box_dim)).astype(
            'float32'
        )
        boxes[:, :, 0:2] = boxes[:, :, 0:2] * 0.5
        boxes[:, :, 2:4] = boxes[:, :, 2:4] * 0.5 + 0.5

        det_outs, lod = batched_multiclass_nms(
            boxes,
            scores,
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
        )
        det_array = np.array(det_outs)
        if len(det_array):
            nmsed_outs = det_array[:, :-1].astype('float32')
            # The last column holds the flat box index.
            index_outs = det_array[:, -1:].astype('int')
        else:
            # Empty result: [0, 6] rows of (label, score, 4 coordinates).
            nmsed_outs = np.array([], dtype=np.float32).reshape(
                [0, box_dim + 2]
            )
            index_outs = det_array

        self.op_type = 'multiclass_nms2'
        self.inputs = {'BBoxes': boxes, 'Scores': scores}
        self.outputs = {
            'Out': (nmsed_outs, [lod]),
            'Index': (index_outs, [lod]),
        }
        self.attrs = {
            'background_label': 0,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
            'normalized': True,
        }

    def test_check_output(self):
        self.check_output()


class TestMulticlassNMS2OpNoOutput(TestMulticlassNMS2Op):
    """``multiclass_nms2`` dense case where every box is filtered out."""

    def set_argument(self):
        # Softmax scores never exceed 1, so 2.0 yields no detections.
        # Real thresholds lie in (0, 1).
        self.score_threshold = 2.0


class TestMulticlassNMS2LoDInput(TestMulticlassNMSLoDInput):
    """Check ``multiclass_nms2`` on LoD input, including the ``Index`` output."""

    def setUp(self):
        self.set_argument()
        M = 1200
        C = 21
        BOX_SIZE = 4
        box_lod = [[1200]]
        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold
        normalized = False

        scores = np.random.random((M, C)).astype('float32')

        scores = np.apply_along_axis(softmax, 1, scores)

        boxes = np.random.random((M, C, BOX_SIZE)).astype('float32')
        boxes[:, :, 0] = boxes[:, :, 0] * 10
        boxes[:, :, 1] = boxes[:, :, 1] * 10
        boxes[:, :, 2] = boxes[:, :, 2] * 10 + 10
        boxes[:, :, 3] = boxes[:, :, 3] * 10 + 10

        det_outs, lod = lod_multiclass_nms(
            boxes,
            scores,
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
            box_lod,
            normalized,
        )

        det_outs = np.array(det_outs)
        nmsed_outs = (
            det_outs[:, :-1].astype('float32')
            if len(det_outs)
            else np.array([], dtype=np.float32).reshape([0, BOX_SIZE + 2])
        )
        # The last column of each reference row is the flat box index.
        index_outs = (
            det_outs[:, -1:].astype('int') if len(det_outs) else det_outs
        )
        self.op_type = 'multiclass_nms2'
        self.inputs = {
            'BBoxes': (boxes, box_lod),
            'Scores': (scores, box_lod),
        }
        self.outputs = {
            'Out': (nmsed_outs, [lod]),
            'Index': (index_outs, [lod]),
        }
        self.attrs = {
            'background_label': 0,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
            'normalized': normalized,
        }

    # Defined as a method. As a module-level `def test_check_output(self)`,
    # pytest would collect it as a free test and fail on the `self` fixture.
    def test_check_output(self):
        self.check_output()


class TestMulticlassNMS2LoDNoOutput(TestMulticlassNMS2LoDInput):
    """``multiclass_nms2`` LoD case where every box is filtered out."""

    def set_argument(self):
        # Softmax scores never exceed 1, so 2.0 yields no detections.
        # Real thresholds lie in (0, 1).
        self.score_threshold = 2.0


class TestMulticlassNMS3Op(TestMulticlassNMS2Op):
    """Check ``multiclass_nms3``, which reports per-image counts in ``NmsRoisNum``."""

    def setUp(self):
        self.python_api = multiclass_nms3
        self.set_argument()
        batch_size, box_count, class_count, box_dim = 7, 1200, 21, 4
        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold

        raw_scores = np.random.random(
            (batch_size * box_count, class_count)
        ).astype('float32')
        probs = np.apply_along_axis(softmax, 1, raw_scores)
        scores = probs.reshape((batch_size, box_count, class_count)).transpose(
            (0, 2, 1)
        )

        boxes = np.random.random((batch_size, box_count, box_dim)).astype(
            'float32'
        )
        boxes[:, :, 0:2] = boxes[:, :, 0:2] * 0.5
        boxes[:, :, 2:4] = boxes[:, :, 2:4] * 0.5 + 0.5

        det_outs, lod = batched_multiclass_nms(
            boxes,
            scores,
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
        )
        det_array = np.array(det_outs)
        if len(det_array):
            nmsed_outs = det_array[:, :-1].astype('float32')
            index_outs = det_array[:, -1:].astype('int')
        else:
            nmsed_outs = np.array([], dtype=np.float32).reshape(
                [0, box_dim + 2]
            )
            index_outs = det_array

        self.op_type = 'multiclass_nms3'
        self.inputs = {'BBoxes': boxes, 'Scores': scores}
        self.outputs = {
            'Out': nmsed_outs,
            'Index': index_outs,
            # Per-image detection counts replace the LoD of earlier versions.
            'NmsRoisNum': np.array(lod).astype('int32'),
        }
        self.attrs = {
            'background_label': 0,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
            'normalized': True,
        }

    def test_check_output(self):
        self.check_output()


class TestMulticlassNMS3OpNoOutput(TestMulticlassNMS3Op):
    """``multiclass_nms3`` case where every box is filtered out."""

    def set_argument(self):
        # Softmax scores never exceed 1, so 2.0 yields no detections.
        # Real thresholds lie in (0, 1).
        self.score_threshold = 2.0


def _main():
    """Switch Paddle to static mode, then run this module's unittests."""
    paddle.enable_static()
    unittest.main()


if __name__ == '__main__':
    _main()