#  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import print_function
import unittest
import numpy as np
import copy
from op_test import OpTest


def iou(box_a, box_b):
    """Compute the intersection-over-union overlap between box_a and box_b.

    Each box is ``[x1, y1, x2, y2]``. The two corners may be given in either
    order, because they are normalized with min/max before any area is
    computed.

    Return:
        inter_area / union_area, or 0.0 when both boxes have zero area (the
        union would otherwise be zero).
    """
    xmin_a = min(box_a[0], box_a[2])
    ymin_a = min(box_a[1], box_a[3])
    xmax_a = max(box_a[0], box_a[2])
    ymax_a = max(box_a[1], box_a[3])

    xmin_b = min(box_b[0], box_b[2])
    ymin_b = min(box_b[1], box_b[3])
    xmax_b = max(box_b[0], box_b[2])
    ymax_b = max(box_b[1], box_b[3])

    area_a = (ymax_a - ymin_a) * (xmax_a - xmin_a)
    area_b = (ymax_b - ymin_b) * (xmax_b - xmin_b)
    # Both boxes degenerate: avoid dividing by a zero union.
    if area_a <= 0 and area_b <= 0:
        return 0.0

    # Corners of the intersection rectangle (may be empty).
    xa = max(xmin_a, xmin_b)
    ya = max(ymin_a, ymin_b)
    xb = min(xmax_a, xmax_b)
    yb = min(ymax_a, ymax_b)

    inter_area = max(xb - xa, 0.0) * max(yb - ya, 0.0)

    return inter_area / (area_a + area_b - inter_area)


def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0):
    """Apply greedy non-maximum suppression to the boxes of a single class.

    Args:
        boxes: The location preds for the img, indexed by prior,
            Shape: [num_priors, 4].
        scores: The class pred scores for the img, Shape: [num_priors].
        score_threshold: (float) Boxes whose score is not strictly greater
            than this value are discarded before suppression.
        nms_threshold: (float) A candidate is suppressed when its IoU with
            any already-kept box exceeds this value.
        top_k: (int) The maximum number of highest-scoring candidates to
            consider; a value of -1 (or less) disables the limit.
        eta: (float) The parameter for adaptive NMS. When ``eta < 1`` the
            threshold is multiplied by ``eta`` after each kept box, as long
            as it is still above 0.5.

    Return:
        The indices of the kept boxes with respect to num_priors, in
        descending score order.
    """
    # flatten() always returns a copy, so the caller's array is untouched.
    all_scores = np.asarray(scores).flatten()
    # Prior indices of the candidates that survive the score filter.
    candidate_indices = np.argwhere(all_scores > score_threshold).flatten()
    candidate_scores = all_scores[candidate_indices]

    # Stable descending sort so equal scores keep their prior order.
    order = np.argsort(-candidate_scores, axis=0, kind='mergesort')
    # `order` holds positions in the filtered array; map them back to prior
    # indices before using them to index `boxes` or returning them.
    sorted_indices = candidate_indices[order]
    if top_k > -1 and top_k < sorted_indices.shape[0]:
        sorted_indices = sorted_indices[:top_k]

    selected_indices = []
    adaptive_threshold = nms_threshold
    for idx in sorted_indices:
        # all() stops at the first kept box that overlaps too much.
        keep = all(
            iou(boxes[idx], boxes[kept_idx]) <= adaptive_threshold
            for kept_idx in selected_indices)
        if keep:
            selected_indices.append(idx)
            if eta < 1 and adaptive_threshold > 0.5:
                adaptive_threshold *= eta
    return selected_indices


def multiclass_nms(boxes, scores, background, score_threshold, nms_threshold,
                   nms_top_k, keep_top_k):
    """Run per-class NMS for one image and cap the total number of detections.

    Args:
        boxes: The boxes of one image, indexed by prior (forwarded to nms).
        scores: Per-class scores, Shape: [class_num, num_priors].
        background: (int) Label of the background class; it is skipped.
        score_threshold, nms_threshold, nms_top_k: Forwarded to nms.
        keep_top_k: (int) Maximum number of detections kept across all
            classes; a value of -1 (or less) disables the cap.

    Return:
        A tuple ``(selected_indices, num_det)``. ``selected_indices`` maps a
        class label to the list of kept prior indices, with keys in ascending
        label order. ``num_det`` is the total number of kept detections. When
        the cap applies, each list is in descending score order.
    """
    class_num = scores.shape[0]

    selected_indices = {}
    num_det = 0
    for c in range(class_num):
        if c == background:
            continue
        indices = nms(boxes, scores[c], score_threshold, nms_threshold,
                      nms_top_k)
        selected_indices[c] = indices
        num_det += len(indices)

    if keep_top_k > -1 and num_det > keep_top_k:
        score_index = [(scores[c][idx], c, idx)
                       for c, indices in selected_indices.items()
                       for idx in indices]
        # sorted() is stable (also with reverse=True), so ties keep their
        # class/prior order.
        sorted_score_index = sorted(
            score_index, key=lambda tup: tup[0], reverse=True)[:keep_top_k]

        grouped = {}
        for _, c, idx in sorted_score_index:
            grouped.setdefault(c, []).append(idx)
        # Grouping above inserts labels in score order; re-key by ascending
        # label so the result is ordered like the untruncated path.
        selected_indices = {c: grouped[c] for c in sorted(grouped)}
        num_det = keep_top_k

    return selected_indices, num_det


def batched_multiclass_nms(boxes, scores, background, score_threshold,
                           nms_threshold, nms_top_k, keep_top_k):
    """NumPy reference for multiclass NMS over a batch of images.

    Args:
        boxes: Boxes per image, Shape: [N, num_priors, 4].
        scores: Scores per image, Shape: [N, class_num, num_priors].
        background, score_threshold, nms_threshold, nms_top_k, keep_top_k:
            Forwarded to multiclass_nms for every image.

    Return:
        A tuple ``(det_outs, lod)``. ``det_outs`` is a flat list of rows
        ``[label, score, xmin, ymin, xmax, ymax]`` for the whole batch; within
        each image, rows are grouped by ascending label. ``lod[n]`` is the
        detection count multiclass_nms reported for image ``n``.
    """
    batch_size = scores.shape[0]

    det_outs = []
    lod = []
    for n in range(batch_size):
        nmsed_outs, nmsed_num = multiclass_nms(boxes[n], scores[n], background,
                                               score_threshold, nms_threshold,
                                               nms_top_k, keep_top_k)
        lod.append(nmsed_num)
        if nmsed_num == 0:
            continue

        # Iterate labels explicitly in ascending order. After keep_top_k
        # truncation, multiclass_nms may have inserted its keys in score
        # order, and dicts preserve insertion order.
        for c in sorted(nmsed_outs):
            for idx in nmsed_outs[c]:
                xmin, ymin, xmax, ymax = boxes[n][idx]
                det_outs.append([c, scores[n][c][idx], xmin, ymin, xmax, ymax])

    return det_outs, lod


class TestMulticlassNMSOp(OpTest):
    """Check the multiclass_nms operator against the NumPy reference above."""

    def set_argument(self):
        """Hook for subclasses to override the test's score threshold."""
        self.score_threshold = 0.01

    def setUp(self):
        self.set_argument()
        N = 7  # batch size
        M = 1200  # priors (boxes) per image
        C = 21  # classes, including the background class
        BOX_SIZE = 4

        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold

        scores = np.random.random((N * M, C)).astype('float32')

        def softmax(x):
            shiftx = x - np.max(x).clip(-64.)
            exps = np.exp(shiftx)
            return exps / np.sum(exps)

        # Normalize each prior's class scores, then lay them out as [N, C, M].
        scores = np.apply_along_axis(softmax, 1, scores)
        scores = np.reshape(scores, (N, M, C))
        scores = np.transpose(scores, (0, 2, 1))

        # xmin/ymin fall in [0, 0.5) and xmax/ymax in [0.5, 1), so every box
        # has its min corner before its max corner.
        boxes = np.random.random((N, M, BOX_SIZE)).astype('float32')
        boxes[:, :, 0:2] = boxes[:, :, 0:2] * 0.5
        boxes[:, :, 2:4] = boxes[:, :, 2:4] * 0.5 + 0.5

        nmsed_outs, lod = batched_multiclass_nms(boxes, scores, background,
                                                 score_threshold, nms_threshold,
                                                 nms_top_k, keep_top_k)
        # NOTE(review): [-1] appears to be the expected placeholder output
        # when no box survives; confirm against the operator implementation.
        nmsed_outs = [-1] if not nmsed_outs else nmsed_outs
        nmsed_outs = np.array(nmsed_outs).astype('float32')

        self.op_type = 'multiclass_nms'
        self.inputs = {'BBoxes': boxes, 'Scores': scores}
        self.outputs = {'Out': (nmsed_outs, [lod])}
        self.attrs = {
            # Same label the reference output was computed with.
            'background_label': background,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
        }

    def test_check_output(self):
        self.check_output()


class TestMulticlassNMSOpNoOutput(TestMulticlassNMSOp):
    """Same operator check, but with a threshold that no score can pass."""

    def set_argument(self):
        # The parent builds scores with a softmax, so none of them exceeds
        # 1.0; a threshold of 2.0 therefore exercises the empty-output case.
        # Realistic settings use 0.0 < score_threshold < 1.0.
        self.score_threshold = 2.0


class TestIOU(unittest.TestCase):
    """Hand-computed check of the iou helper."""

    def test_iou(self):
        # Intersection is 2 x 1 = 2; union is 6 + 12 - 2 = 16.
        first = np.array([4.0, 3.0, 7.0, 5.0]).astype('float32')
        second = np.array([3.0, 4.0, 6.0, 8.0]).astype('float32')

        actual = np.array([iou(first, second)]).astype('float32')
        expected = np.array([2.0 / 16.0]).astype('float32')
        self.assertTrue(np.allclose(actual, expected))


if __name__ == "__main__":
    # Run every test case in this module when executed as a script.
    unittest.main()