#  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import unittest

import numpy as np
from eager_op_test import OpTest

import paddle
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid import _non_static_mode, core, in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper


def multiclass_nms3(
    bboxes,
    scores,
    rois_num=None,
    score_threshold=0.3,
    nms_top_k=1000,
    keep_top_k=100,
    nms_threshold=0.3,
    normalized=True,
    nms_eta=1.0,
    background_label=-1,
    return_index=True,
    return_rois_num=True,
    name=None,
):
    """Python entry point for the ``multiclass_nms3`` operator.

    Calls ``_C_ops.multiclass_nms3`` when ``in_dygraph_mode()`` is true,
    ``_legacy_C_ops.multiclass_nms3`` when ``_non_static_mode()`` is true, and
    otherwise appends a ``multiclass_nms3`` op through ``LayerHelper``.

    Args:
        bboxes: Passed to the op as its ``BBoxes`` input.
        scores: Passed to the op as its ``Scores`` input.
        rois_num: Optional ``RoisNum`` input. In the static branch it is only
            added to the op inputs when it is not ``None``.
        score_threshold, nms_top_k, keep_top_k, nms_threshold, normalized,
        nms_eta, background_label: Forwarded unchanged as op attributes.
        return_index: When false, ``None`` is returned in place of the index
            output.
        return_rois_num: Only consulted in the static branch, where it decides
            whether the ``NmsRoisNum`` output is created; otherwise ``None``
            is returned in its place.
        name: Not used directly; it reaches ``LayerHelper`` via ``locals()``.

    Returns:
        A 3-tuple. Both dynamic branches return
        ``(output, index, nms_rois_num)``; the static branch returns
        ``(output, nms_rois_num, index)``.
    """
    # Must stay the first statement: at this point ``locals()`` contains only
    # the call arguments, which is what gets forwarded to LayerHelper.
    helper = LayerHelper('multiclass_nms3', **locals())

    if in_dygraph_mode():
        # Positional attributes, in the order used by the _C_ops call.
        attrs = (
            score_threshold,
            nms_top_k,
            keep_top_k,
            nms_threshold,
            normalized,
            nms_eta,
            background_label,
        )
        output, index, nms_rois_num = _C_ops.multiclass_nms3(
            bboxes, scores, rois_num, *attrs
        )
        if not return_index:
            index = None
        return output, index, nms_rois_num
    elif _non_static_mode():
        # The legacy call takes attributes as alternating name/value pairs.
        attrs = (
            'background_label',
            background_label,
            'score_threshold',
            score_threshold,
            'nms_top_k',
            nms_top_k,
            'nms_threshold',
            nms_threshold,
            'keep_top_k',
            keep_top_k,
            'nms_eta',
            nms_eta,
            'normalized',
            normalized,
        )
        output, index, nms_rois_num = _legacy_C_ops.multiclass_nms3(
            bboxes, scores, rois_num, *attrs
        )
        if not return_index:
            index = None
        return output, index, nms_rois_num

    else:
        output = helper.create_variable_for_type_inference(dtype=bboxes.dtype)
        index = helper.create_variable_for_type_inference(dtype='int32')

        inputs = {'BBoxes': bboxes, 'Scores': scores}
        outputs = {'Out': output, 'Index': index}

        if rois_num is not None:
            inputs['RoisNum'] = rois_num

        if return_rois_num:
            nms_rois_num = helper.create_variable_for_type_inference(
                dtype='int32'
            )
            outputs['NmsRoisNum'] = nms_rois_num

        helper.append_op(
            type="multiclass_nms3",
            inputs=inputs,
            attrs={
                'background_label': background_label,
                'score_threshold': score_threshold,
                'nms_top_k': nms_top_k,
                'nms_threshold': nms_threshold,
                'keep_top_k': keep_top_k,
                'nms_eta': nms_eta,
                'normalized': normalized,
            },
            outputs=outputs,
        )
        output.stop_gradient = True
        index.stop_gradient = True
        if not return_index:
            index = None
        if not return_rois_num:
            nms_rois_num = None

        # NOTE(review): this order differs from the two dynamic branches
        # above (index and nms_rois_num are swapped). Left unchanged so
        # existing static-graph callers keep working; confirm the intent.
        return output, nms_rois_num, index


def softmax(x):
    """Return the softmax of ``x`` computed over all of its elements.

    Args:
        x: numpy array of scores.

    Returns:
        Array of the same shape as ``x`` whose elements sum to 1.
    """
    # clip to shiftx, otherwise, when calc loss with
    # log(exp(shiftx)), may get log(0)=INF
    shiftx = (x - np.max(x)).clip(-64.0)
    exps = np.exp(shiftx)
    return exps / np.sum(exps)


def iou(box_a, box_b, norm):
    """Apply intersection-over-union overlap between box_a and box_b.

    Args:
        box_a: Sequence of four coordinates; indices 0/2 are x values and
            indices 1/3 are y values. The two corners may be given in either
            order.
        box_b: Same layout as ``box_a``.
        norm: When false, 1 is added to every width and height.

    Returns:
        The overlap ratio, or ``0.0`` when both boxes have non-positive area.
    """
    # Kept as a bool (True == 1 in arithmetic) exactly like the inline
    # ``(not norm)`` terms, so the numeric types of the results do not change.
    pixel = not norm

    def _extent(box):
        # Normalise the corner order to (xmin, ymin, xmax, ymax).
        return (
            min(box[0], box[2]),
            min(box[1], box[3]),
            max(box[0], box[2]),
            max(box[1], box[3]),
        )

    xmin_a, ymin_a, xmax_a, ymax_a = _extent(box_a)
    xmin_b, ymin_b, xmax_b, ymax_b = _extent(box_b)

    area_a = (ymax_a - ymin_a + pixel) * (xmax_a - xmin_a + pixel)
    area_b = (ymax_b - ymin_b + pixel) * (xmax_b - xmin_b + pixel)
    if area_a <= 0 and area_b <= 0:
        return 0.0

    # Corners of the intersection rectangle; it is empty when xb < xa or
    # yb < ya, which the max(..., 0.0) below turns into a zero extent.
    xa = max(xmin_a, xmin_b)
    ya = max(ymin_a, ymin_b)
    xb = min(xmax_a, xmax_b)
    yb = min(ymax_a, ymax_b)

    inter_area = max(xb - xa + pixel, 0.0) * max(yb - ya + pixel, 0.0)

    return inter_area / (area_a + area_b - inter_area)


def nms(
    boxes,
    scores,
    score_threshold,
    nms_threshold,
    top_k=200,
    normalized=True,
    eta=1.0,
):
    """Apply non-maximum suppression at test time to avoid detecting too many
    overlapping bounding boxes for a given object.
    Args:
        boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
        scores: (tensor) The class predscores for the img, Shape:[num_priors].
        score_threshold: (float) The confidence thresh for filtering low
            confidence boxes.
        nms_threshold: (float) The overlap thresh for suppressing unnecessary
            boxes.
        top_k: (int) The maximum number of box preds to consider. A value
            below 0 disables the limit.
        normalized: (bool) Passed to ``iou`` as its ``norm`` argument.
        eta: (float) The parameter for adaptive NMS.
    Return:
        The indices of the kept boxes with respect to num_priors.
    """
    # flatten() always returns a new array, so the caller's scores are never
    # aliased and no explicit deep copy is needed.
    all_scores = np.asarray(scores).flatten()
    candidate_indices = np.argwhere(all_scores > score_threshold).flatten()
    candidate_scores = all_scores[candidate_indices]

    # Descending by score; mergesort is stable, so ties keep index order.
    order = np.argsort(-candidate_scores, axis=0, kind='mergesort')
    sorted_indices = candidate_indices[order]
    if -1 < top_k < sorted_indices.shape[0]:
        sorted_indices = sorted_indices[:top_k]

    selected_indices = []
    adaptive_threshold = nms_threshold
    for idx in sorted_indices:
        # A candidate survives only if it overlaps no already-kept box by
        # more than the current threshold; all() stops at the first failure.
        keep = all(
            iou(boxes[idx], boxes[kept_idx], normalized) <= adaptive_threshold
            for kept_idx in selected_indices
        )
        if keep:
            selected_indices.append(idx)
            # Adaptive NMS: shrink the threshold after each kept box while it
            # is still above 0.5.
            if eta < 1 and adaptive_threshold > 0.5:
                adaptive_threshold *= eta
    return selected_indices


def multiclass_nms(
    boxes,
    scores,
    background,
    score_threshold,
    nms_threshold,
    nms_top_k,
    keep_top_k,
    normalized,
    shared,
):
    """Run ``nms`` for every non-background class and cap the total.

    Args:
        boxes: Indexed as ``boxes[idx]`` when ``shared`` is true and as
            ``boxes[:, c, :]`` otherwise.
        scores: Indexed as ``scores[c][idx]`` when ``shared`` is true and as
            ``scores[idx][c]`` otherwise.
        background: Class index that is skipped.
        score_threshold: Forwarded to ``nms``.
        nms_threshold: Forwarded to ``nms``.
        nms_top_k: Forwarded to ``nms`` as ``top_k``.
        keep_top_k: Maximum number of detections kept over all classes; a
            value below 0 disables the cap.
        normalized: Forwarded to ``nms``.
        shared: Selects which of the two layouts above is used.

    Returns:
        ``(selected_indices, num_det)`` where ``selected_indices`` maps a
        class index to the list of kept box indices and ``num_det`` is the
        total number of kept boxes.
    """
    class_num = scores.shape[0] if shared else scores.shape[1]

    selected_indices = {}
    num_det = 0
    for c in range(class_num):
        if c == background:
            continue
        if shared:
            class_boxes, class_scores = boxes, scores[c]
        else:
            class_boxes, class_scores = boxes[:, c, :], scores[:, c]
        indices = nms(
            class_boxes,
            class_scores,
            score_threshold,
            nms_threshold,
            nms_top_k,
            normalized,
        )
        selected_indices[c] = indices
        num_det += len(indices)

    if keep_top_k > -1 and num_det > keep_top_k:
        score_index = []
        for c, indices in selected_indices.items():
            for idx in indices:
                score = scores[c][idx] if shared else scores[idx][c]
                score_index.append((score, c, idx))

        # sorted() is stable, so equal scores keep their class/box order.
        sorted_score_index = sorted(
            score_index, key=lambda tup: tup[0], reverse=True
        )[:keep_top_k]

        # Regroup the survivors by class; classes appear in the order of
        # their best-scoring surviving box.
        selected_indices = {}
        for _, c, idx in sorted_score_index:
            selected_indices.setdefault(c, []).append(idx)
        if not shared:
            for kept in selected_indices.values():
                kept.sort()
        num_det = keep_top_k

    return selected_indices, num_det


def lod_multiclass_nms(
    boxes,
    scores,
    background,
    score_threshold,
    nms_threshold,
    nms_top_k,
    keep_top_k,
    box_lod,
    normalized,
):
    """Reference multiclass NMS for inputs split into segments by ``box_lod``.

    Args:
        boxes: Array indexed as ``[box, class, coord]``; its second dimension
            is taken as the number of classes.
        scores: Array indexed as ``[box, class]``.
        background, score_threshold, nms_threshold, nms_top_k, keep_top_k,
        normalized: Forwarded to ``multiclass_nms``.
        box_lod: ``box_lod[0]`` lists the number of boxes in each segment.

    Returns:
        ``(det_outs, lod)``. ``det_outs`` is a list of rows
        ``[class, score, xmin, ymin, xmax, ymax, flat_index]`` and ``lod``
        holds the number of detections per segment.
    """
    num_class = boxes.shape[1]
    det_outs = []
    lod = []
    head = 0
    for box_count in box_lod[0]:
        if box_count == 0:
            lod.append(0)
            continue
        box = boxes[head : head + box_count]
        score = scores[head : head + box_count]
        # Index of this segment's first box in the unsplit input.
        offset = head
        head += box_count
        nmsed_outs, nmsed_num = multiclass_nms(
            box,
            score,
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
            normalized,
            shared=False,
        )
        lod.append(nmsed_num)

        if nmsed_num == 0:
            continue
        tmp_det_out = []
        for c, indices in nmsed_outs.items():
            for idx in indices:
                xmin, ymin, xmax, ymax = box[idx, c, :]
                tmp_det_out.append(
                    [
                        c,
                        score[idx][c],
                        xmin,
                        ymin,
                        xmax,
                        ymax,
                        # Position of (box, class) in the flattened
                        # [box, class] layout of the whole input.
                        offset * num_class + idx * num_class + c,
                    ]
                )
        # Stable sort by class label.
        det_outs.extend(sorted(tmp_det_out, key=lambda tup: tup[0]))

    return det_outs, lod


def batched_multiclass_nms(
    boxes,
    scores,
    background,
    score_threshold,
    nms_threshold,
    nms_top_k,
    keep_top_k,
    normalized=True,
    gpu_logic=False,
):
    """Reference multiclass NMS for batched inputs.

    Args:
        boxes: Array indexed as ``[batch, box, coord]``.
        scores: Array indexed as ``[batch, class, box]``.
        background, score_threshold, nms_threshold, nms_top_k, keep_top_k,
        normalized: Forwarded to ``multiclass_nms``.
        gpu_logic: When truthy, each image's detections are ordered by
            descending score; otherwise by ascending class label.

    Returns:
        ``(det_outs, lod)``. ``det_outs`` is a list of rows
        ``[class, score, xmin, ymin, xmax, ymax, flat_index]`` and ``lod``
        holds the number of detections per batch item.
    """
    batch_size = scores.shape[0]
    num_boxes = scores.shape[2]
    det_outs = []
    lod = []
    for n in range(batch_size):
        nmsed_outs, nmsed_num = multiclass_nms(
            boxes[n],
            scores[n],
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
            normalized,
            shared=True,
        )
        lod.append(nmsed_num)

        if nmsed_num == 0:
            continue
        tmp_det_out = []
        for c, indices in nmsed_outs.items():
            for idx in indices:
                xmin, ymin, xmax, ymax = boxes[n][idx]
                tmp_det_out.append(
                    [
                        c,
                        scores[n][c][idx],
                        xmin,
                        ymin,
                        xmax,
                        ymax,
                        # Box position across the whole batch.
                        idx + n * num_boxes,
                    ]
                )
        if gpu_logic:
            # Column 1 is the score.
            sorted_det_out = sorted(
                tmp_det_out, key=lambda tup: tup[1], reverse=True
            )
        else:
            # Column 0 is the class label.
            sorted_det_out = sorted(tmp_det_out, key=lambda tup: tup[0])
        det_outs.extend(sorted_det_out)
    return det_outs, lod


class TestMulticlassNMSOp(OpTest):
    """Checks the ``multiclass_nms`` op on batched inputs against the numpy
    reference ``batched_multiclass_nms``."""

    def set_argument(self):
        # Hook for subclasses to change the score threshold.
        self.score_threshold = 0.01

    def setUp(self):
        self.set_argument()
        N = 7
        M = 1200
        C = 21
        BOX_SIZE = 4
        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold

        scores = np.random.random((N * M, C)).astype('float32')

        # Per-box softmax over classes, then rearranged to (N, C, M).
        scores = np.apply_along_axis(softmax, 1, scores)
        scores = np.reshape(scores, (N, M, C))
        scores = np.transpose(scores, (0, 2, 1))

        # First corner in [0, 0.5), second corner in [0.5, 1.0).
        boxes = np.random.random((N, M, BOX_SIZE)).astype('float32')
        boxes[:, :, 0:2] = boxes[:, :, 0:2] * 0.5
        boxes[:, :, 2:4] = boxes[:, :, 2:4] * 0.5 + 0.5

        det_outs, lod = batched_multiclass_nms(
            boxes,
            scores,
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
        )
        # With no detections the expected output is a single placeholder row
        # ([-1] once the trailing index column is dropped) with a LoD of [1].
        lod = [1] if not det_outs else lod
        det_outs = [[-1, 0]] if not det_outs else det_outs
        det_outs = np.array(det_outs)
        # The last column is the box index, which this op does not output.
        nmsed_outs = det_outs[:, :-1].astype('float32')

        self.op_type = 'multiclass_nms'
        self.inputs = {'BBoxes': boxes, 'Scores': scores}
        self.outputs = {'Out': (nmsed_outs, [lod])}
        self.attrs = {
            'background_label': background,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
            'normalized': True,
        }

    def test_check_output(self):
        self.check_output()


class TestMulticlassNMSOpNoOutput(TestMulticlassNMSOp):
    """Same as ``TestMulticlassNMSOp`` with a threshold that rejects every
    box."""

    def set_argument(self):
        # Here set 2.0 to test the case there is no outputs.
        # In practical use, 0.0 < score_threshold < 1.0
        self.score_threshold = 2.0


class TestMulticlassNMSLoDInput(OpTest):
    """Checks the ``multiclass_nms`` op on LoD inputs against the numpy
    reference ``lod_multiclass_nms``."""

    def set_argument(self):
        # Hook for subclasses to change the score threshold.
        self.score_threshold = 0.01

    def setUp(self):
        self.set_argument()
        M = 1200
        C = 21
        BOX_SIZE = 4
        # A single segment holding all M boxes.
        box_lod = [[1200]]
        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold
        normalized = False

        scores = np.random.random((M, C)).astype('float32')

        scores = np.apply_along_axis(softmax, 1, scores)

        # First corner in [0, 10), second corner in [10, 20).
        boxes = np.random.random((M, C, BOX_SIZE)).astype('float32')
        boxes[:, :, 0] = boxes[:, :, 0] * 10
        boxes[:, :, 1] = boxes[:, :, 1] * 10
        boxes[:, :, 2] = boxes[:, :, 2] * 10 + 10
        boxes[:, :, 3] = boxes[:, :, 3] * 10 + 10

        det_outs, lod = lod_multiclass_nms(
            boxes,
            scores,
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
            box_lod,
            normalized,
        )
        det_outs = np.array(det_outs).astype('float32')
        # Drop the trailing index column, which this op does not output.
        nmsed_outs = (
            det_outs[:, :-1].astype('float32') if len(det_outs) else det_outs
        )
        self.op_type = 'multiclass_nms'
        self.inputs = {
            'BBoxes': (boxes, box_lod),
            'Scores': (scores, box_lod),
        }
        self.outputs = {'Out': (nmsed_outs, [lod])}
        self.attrs = {
            'background_label': background,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
            'normalized': normalized,
        }

    def test_check_output(self):
        self.check_output()


class TestMulticlassNMSNoBox(TestMulticlassNMSLoDInput):
    """LoD-input case in which the first and last segments contain no
    boxes."""

    def setUp(self):
        self.set_argument()
        M = 1200
        C = 21
        BOX_SIZE = 4
        # Three segments; only the middle one holds boxes.
        box_lod = [[0, 1200, 0]]
        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold
        normalized = False

        scores = np.random.random((M, C)).astype('float32')

        scores = np.apply_along_axis(softmax, 1, scores)

        # First corner in [0, 10), second corner in [10, 20).
        boxes = np.random.random((M, C, BOX_SIZE)).astype('float32')
        boxes[:, :, 0] = boxes[:, :, 0] * 10
        boxes[:, :, 1] = boxes[:, :, 1] * 10
        boxes[:, :, 2] = boxes[:, :, 2] * 10 + 10
        boxes[:, :, 3] = boxes[:, :, 3] * 10 + 10

        det_outs, lod = lod_multiclass_nms(
            boxes,
            scores,
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
            box_lod,
            normalized,
        )
        det_outs = np.array(det_outs).astype('float32')
        # Drop the trailing index column, which this op does not output.
        nmsed_outs = (
            det_outs[:, :-1].astype('float32') if len(det_outs) else det_outs
        )
        self.op_type = 'multiclass_nms'
        self.inputs = {
            'BBoxes': (boxes, box_lod),
            'Scores': (scores, box_lod),
        }
        self.outputs = {'Out': (nmsed_outs, [lod])}
        self.attrs = {
            'background_label': background,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
            'normalized': normalized,
        }


class TestIOU(unittest.TestCase):
    """Sanity check for the numpy reference ``iou`` helper."""

    def test_iou(self):
        box1 = np.array([4.0, 3.0, 7.0, 5.0]).astype('float32')
        box2 = np.array([3.0, 4.0, 6.0, 8.0]).astype('float32')

        # Areas are 6 and 12 and the intersection is 2: 2 / (6 + 12 - 2).
        expt_output = np.array([2.0 / 16.0]).astype('float32')
        calc_output = np.array([iou(box1, box2, True)]).astype('float32')
        np.testing.assert_allclose(calc_output, expt_output, rtol=1e-05)


class TestMulticlassNMS2Op(TestMulticlassNMSOp):
    """Checks the ``multiclass_nms2`` op, which also outputs ``Index``,
    on batched inputs."""

    def setUp(self):
        self.set_argument()
        N = 7
        M = 1200
        C = 21
        BOX_SIZE = 4
        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold

        scores = np.random.random((N * M, C)).astype('float32')

        # Per-box softmax over classes, then rearranged to (N, C, M).
        scores = np.apply_along_axis(softmax, 1, scores)
        scores = np.reshape(scores, (N, M, C))
        scores = np.transpose(scores, (0, 2, 1))

        # First corner in [0, 0.5), second corner in [0.5, 1.0).
        boxes = np.random.random((N, M, BOX_SIZE)).astype('float32')
        boxes[:, :, 0:2] = boxes[:, :, 0:2] * 0.5
        boxes[:, :, 2:4] = boxes[:, :, 2:4] * 0.5 + 0.5

        det_outs, lod = batched_multiclass_nms(
            boxes,
            scores,
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
        )
        det_outs = np.array(det_outs)

        # All columns but the last form 'Out'; the last column is 'Index'.
        nmsed_outs = (
            det_outs[:, :-1].astype('float32')
            if len(det_outs)
            else np.array([], dtype=np.float32).reshape([0, BOX_SIZE + 2])
        )
        index_outs = (
            det_outs[:, -1:].astype('int') if len(det_outs) else det_outs
        )
        self.op_type = 'multiclass_nms2'
        self.inputs = {'BBoxes': boxes, 'Scores': scores}
        self.outputs = {
            'Out': (nmsed_outs, [lod]),
            'Index': (index_outs, [lod]),
        }
        self.attrs = {
            'background_label': background,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
            'normalized': True,
        }

    def test_check_output(self):
        self.check_output()


class TestMulticlassNMS2OpNoOutput(TestMulticlassNMS2Op):
    """``TestMulticlassNMS2Op`` with a threshold chosen to yield no output."""

    def set_argument(self):
        # A threshold of 2.0 exercises the no-output case; realistic values
        # satisfy 0.0 < score_threshold < 1.0.
        self.score_threshold = 2.0


class TestMulticlassNMS2LoDInput(TestMulticlassNMSLoDInput):
    """Checks the ``multiclass_nms2`` op, which also outputs ``Index``,
    on LoD inputs."""

    def setUp(self):
        self.set_argument()
        M = 1200
        C = 21
        BOX_SIZE = 4
        # A single segment holding all M boxes.
        box_lod = [[1200]]
        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold
        normalized = False

        scores = np.random.random((M, C)).astype('float32')

        scores = np.apply_along_axis(softmax, 1, scores)

        # First corner in [0, 10), second corner in [10, 20).
        boxes = np.random.random((M, C, BOX_SIZE)).astype('float32')
        boxes[:, :, 0] = boxes[:, :, 0] * 10
        boxes[:, :, 1] = boxes[:, :, 1] * 10
        boxes[:, :, 2] = boxes[:, :, 2] * 10 + 10
        boxes[:, :, 3] = boxes[:, :, 3] * 10 + 10

        det_outs, lod = lod_multiclass_nms(
            boxes,
            scores,
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
            box_lod,
            normalized,
        )

        det_outs = np.array(det_outs)
        # All columns but the last form 'Out'; the last column is 'Index'.
        nmsed_outs = (
            det_outs[:, :-1].astype('float32')
            if len(det_outs)
            else np.array([], dtype=np.float32).reshape([0, BOX_SIZE + 2])
        )
        index_outs = (
            det_outs[:, -1:].astype('int') if len(det_outs) else det_outs
        )
        self.op_type = 'multiclass_nms2'
        self.inputs = {
            'BBoxes': (boxes, box_lod),
            'Scores': (scores, box_lod),
        }
        self.outputs = {
            'Out': (nmsed_outs, [lod]),
            'Index': (index_outs, [lod]),
        }
        self.attrs = {
            'background_label': background,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
            'normalized': normalized,
        }

    # Defined as a method of this class; previously this function sat at
    # module level, where it took a ``self`` argument but belonged to no class.
    def test_check_output(self):
        self.check_output()


class TestMulticlassNMS2LoDNoOutput(TestMulticlassNMS2LoDInput):
    """``TestMulticlassNMS2LoDInput`` with a threshold chosen to yield no
    output."""

    def set_argument(self):
        # A threshold of 2.0 exercises the no-output case; realistic values
        # satisfy 0.0 < score_threshold < 1.0.
        self.score_threshold = 2.0


class TestMulticlassNMS3Op(TestMulticlassNMS2Op):
    """Checks the ``multiclass_nms3`` op on CPU.

    Subclasses may set ``self.keep_top_k`` and ``self.gpu_logic`` in
    ``set_argument`` to override the defaults used in ``setUp``.
    """

    def setUp(self):
        self.python_api = multiclass_nms3
        self.set_argument()
        N = 7
        M = 1200
        C = 21
        BOX_SIZE = 4
        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = getattr(self, 'keep_top_k', 200)
        score_threshold = self.score_threshold

        scores = np.random.random((N * M, C)).astype('float32')

        # Per-box softmax over classes, then rearranged to (N, C, M).
        scores = np.apply_along_axis(softmax, 1, scores)
        scores = np.reshape(scores, (N, M, C))
        scores = np.transpose(scores, (0, 2, 1))

        # First corner in [0, 0.5), second corner in [0.5, 1.0).
        boxes = np.random.random((N, M, BOX_SIZE)).astype('float32')
        boxes[:, :, 0:2] = boxes[:, :, 0:2] * 0.5
        boxes[:, :, 2:4] = boxes[:, :, 2:4] * 0.5 + 0.5

        det_outs, lod = batched_multiclass_nms(
            boxes,
            scores,
            background,
            score_threshold,
            nms_threshold,
            nms_top_k,
            keep_top_k,
            # Defaults to the reference's own default when a subclass has not
            # opted in.
            gpu_logic=getattr(self, 'gpu_logic', False),
        )
        det_outs = np.array(det_outs)

        # All columns but the last form 'Out'; the last column is 'Index'.
        nmsed_outs = (
            det_outs[:, :-1].astype('float32')
            if len(det_outs)
            else np.array([], dtype=np.float32).reshape([0, BOX_SIZE + 2])
        )
        index_outs = (
            det_outs[:, -1:].astype('int') if len(det_outs) else det_outs
        )
        self.op_type = 'multiclass_nms3'
        self.inputs = {'BBoxes': boxes, 'Scores': scores}
        self.outputs = {
            'Out': nmsed_outs,
            'Index': index_outs,
            # Per-image detection counts are a separate output for this op.
            'NmsRoisNum': np.array(lod).astype('int32'),
        }
        self.attrs = {
            'background_label': background,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
            'normalized': True,
        }

    def test_check_output(self):
        place = paddle.CPUPlace()
        self.check_output_with_place(place)


class TestMulticlassNMS3OpNoOutput(TestMulticlassNMS3Op):
    """``TestMulticlassNMS3Op`` with a threshold chosen to yield no output."""

    def set_argument(self):
        # A threshold of 2.0 exercises the no-output case; realistic values
        # satisfy 0.0 < score_threshold < 1.0.
        self.score_threshold = 2.0


@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestMulticlassNMS3OpGPU(TestMulticlassNMS3Op):
    """Checks the ``multiclass_nms3`` op on a CUDA device.

    Derives from ``TestMulticlassNMS3Op`` because that is the ``setUp`` which
    reads ``self.gpu_logic`` / ``self.keep_top_k`` and sets
    ``op_type = 'multiclass_nms3'``. With ``TestMulticlassNMS2Op`` as the
    base, those attributes were never read and the ``multiclass_nms2`` op was
    exercised instead.
    NOTE(review): not run on a GPU as part of this change — confirm the
    suite passes with the corrected base class.
    """

    def test_check_output(self):
        place = paddle.CUDAPlace(0)
        self.check_output_with_place(place)

    def set_argument(self):
        self.score_threshold = 0.01
        # Makes the numpy reference order detections by descending score.
        self.gpu_logic = True


@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestMulticlassNMS3OpGPULessOutput(TestMulticlassNMS3OpGPU):
    """GPU case using a higher score threshold than its parent class."""

    def set_argument(self):
        self.gpu_logic = True
        # 0.08 is chosen so that the number of output boxes stays below
        # keep_top_k.
        self.score_threshold = 0.08


@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestMulticlassNMS3OpGPUNoOutput(TestMulticlassNMS3OpGPU):
    """GPU case with a threshold chosen to yield no output."""

    def set_argument(self):
        self.gpu_logic = True
        # A threshold of 2.0 exercises the no-output case; realistic values
        # satisfy 0.0 < score_threshold < 1.0.
        self.score_threshold = 2.0


@unittest.skipIf(
    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestMulticlassNMS3OpGPUFallback(TestMulticlassNMS3OpGPU):
    """GPU case configured with a negative ``keep_top_k``."""

    def set_argument(self):
        self.gpu_logic = True
        self.score_threshold = 0.01
        # A keep_top_k below 0 is meant to make the op fall back to the CPU
        # kernel.
        self.keep_top_k = -1


if __name__ == '__main__':
    # Static graph mode is switched on before any test case runs.
    paddle.enable_static()
    unittest.main()