#  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import unittest
import numpy as np
import copy
from op_test import OpTest


def iou(box_a, box_b):
    """Return the intersection-over-union (IoU) of two axis-aligned boxes.

    Each box is given as ``[x1, y1, x2, y2]``. The two corners may appear in
    either order because they are normalized with ``min``/``max`` first.

    Args:
        box_a: First box, indexable with four coordinates.
        box_b: Second box, indexable with four coordinates.

    Returns:
        The ratio ``intersection / union``; 0.0 when both boxes have zero
        area (which would otherwise divide by zero).
    """
    xmin_a = min(box_a[0], box_a[2])
    ymin_a = min(box_a[1], box_a[3])
    xmax_a = max(box_a[0], box_a[2])
    ymax_a = max(box_a[1], box_a[3])

    xmin_b = min(box_b[0], box_b[2])
    ymin_b = min(box_b[1], box_b[3])
    xmax_b = max(box_b[0], box_b[2])
    ymax_b = max(box_b[1], box_b[3])

    area_a = (ymax_a - ymin_a) * (xmax_a - xmin_a)
    area_b = (ymax_b - ymin_b) * (xmax_b - xmin_b)
    if area_a <= 0 and area_b <= 0:
        return 0.0

    # Corners of the overlap rectangle; it is empty when xa > xb or ya > yb.
    xa = max(xmin_a, xmin_b)
    ya = max(ymin_a, ymin_b)
    xb = min(xmax_a, xmax_b)
    yb = min(ymax_a, ymax_b)

    inter_area = max(xb - xa, 0.0) * max(yb - ya, 0.0)

    return inter_area / (area_a + area_b - inter_area)


def nms(boxes, scores, score_threshold, nms_threshold, top_k=200, eta=1.0):
    """Apply greedy non-maximum suppression to the boxes of a single class.

    Args:
        boxes: Box coordinates for one image, indexed as ``boxes[i]``.
        scores: Scores for one class, one per box (flattened before use).
        score_threshold: (float) Boxes scoring at or below this are dropped.
        nms_threshold: (float) A candidate is suppressed when its IoU with an
            already-kept box exceeds this value.
        top_k: (int) Maximum number of candidates considered after sorting by
            score; -1 disables the limit.
        eta: (float) Adaptive-NMS factor. When < 1, the overlap threshold is
            multiplied by ``eta`` after each kept box while it is above 0.5.

    Return:
        List of indices into ``boxes`` of the kept boxes, in descending score
        order.
    """
    # flatten() returns a copy, so the caller's array is never modified.
    all_scores = np.asarray(scores).flatten()
    # Indices into ``boxes`` of the boxes that clear the score threshold.
    candidate_indices = np.argwhere(all_scores > score_threshold).flatten()
    candidate_scores = all_scores[candidate_indices]

    # Stable descending sort. argsort yields positions within the *filtered*
    # arrays, so map them back to indices into ``boxes`` before use.
    order = np.argsort(-candidate_scores, axis=0, kind='mergesort')
    sorted_indices = candidate_indices[order]

    if top_k > -1 and top_k < sorted_indices.shape[0]:
        sorted_indices = sorted_indices[:top_k]

    selected_indices = []
    adaptive_threshold = nms_threshold
    for idx in sorted_indices:
        # all() short-circuits on the first kept box that overlaps too much.
        keep = all(
            iou(boxes[idx], boxes[kept_idx]) <= adaptive_threshold
            for kept_idx in selected_indices)
        if keep:
            selected_indices.append(idx)
            if eta < 1 and adaptive_threshold > 0.5:
                adaptive_threshold *= eta
    return selected_indices


def multiclass_nms(boxes, scores, background, score_threshold, nms_threshold,
                   nms_top_k, keep_top_k):
    """Reference multi-class NMS for a single image.

    Args:
        boxes: Box coordinates for one image, indexed as ``boxes[prior]``.
        scores: Per-class scores indexed as ``scores[class][prior]``.
        background: Class label that is skipped entirely.
        score_threshold: Forwarded to ``nms``.
        nms_threshold: Forwarded to ``nms``.
        nms_top_k: Forwarded to ``nms`` as ``top_k``.
        keep_top_k: Maximum number of detections kept across all classes;
            -1 disables the limit.

    Returns:
        ``(selected_indices, num_det)``. ``selected_indices`` maps each class
        label to its kept prior indices, with keys inserted in ascending label
        order. ``num_det`` is the total number of kept detections.
    """
    class_num = scores.shape[0]

    selected_indices = {}
    num_det = 0
    for c in range(class_num):
        if c == background:
            continue
        indices = nms(boxes, scores[c], score_threshold, nms_threshold,
                      nms_top_k)
        selected_indices[c] = indices
        num_det += len(indices)

    if keep_top_k > -1 and num_det > keep_top_k:
        # (score, class, index) for every detection that survived per-class NMS.
        score_index = [(scores[c][idx], c, idx)
                       for c in sorted(selected_indices)
                       for idx in selected_indices[c]]
        # sorted() is stable, so ties keep their class / NMS order.
        sorted_score_index = sorted(
            score_index, key=lambda tup: tup[0], reverse=True)[:keep_top_k]

        # Regroup by class, inserting labels in ascending order so that dict
        # iteration follows label order. Within a class, indices stay in
        # descending score order.
        kept_classes = sorted({c for _, c, _ in sorted_score_index})
        selected_indices = {c: [] for c in kept_classes}
        for _, c, idx in sorted_score_index:
            selected_indices[c].append(idx)
        num_det = keep_top_k

    return selected_indices, num_det


def batched_multiclass_nms(boxes, scores, background, score_threshold,
                           nms_threshold, nms_top_k, keep_top_k):
    """Run ``multiclass_nms`` on every image of a batch and flatten the result.

    Args:
        boxes: Batched boxes indexed as ``boxes[image][prior]``.
        scores: Batched scores indexed as ``scores[image][class][prior]``.
        background, score_threshold, nms_threshold, nms_top_k, keep_top_k:
            Forwarded unchanged to ``multiclass_nms``.

    Returns:
        ``(det_outs, lod)``. ``det_outs`` is a list of rows
        ``[label, score, xmin, ymin, xmax, ymax]``. ``lod`` holds cumulative
        per-image detection counts, starting at 0.
    """
    batch_size = scores.shape[0]

    det_outs = []
    lod = [0]
    for n in range(batch_size):
        nmsed_outs, nmsed_num = multiclass_nms(boxes[n], scores[n], background,
                                               score_threshold, nms_threshold,
                                               nms_top_k, keep_top_k)
        lod.append(lod[-1] + nmsed_num)
        if nmsed_num == 0:
            continue

        # Emit rows grouped by ascending class label. Dict iteration order is
        # not guaranteed to be label order, and ``iteritems`` is Python 2 only.
        for c in sorted(nmsed_outs):
            for idx in nmsed_outs[c]:
                xmin, ymin, xmax, ymax = boxes[n][idx][:]
                det_outs.append([c, scores[n][c][idx], xmin, ymin, xmax, ymax])

    return det_outs, lod


class TestMulticlassNMSOp(OpTest):
    """Checks the ``multiclass_nms`` op against ``batched_multiclass_nms``."""

    def set_argument(self):
        """Set overridable test parameters (subclasses change the threshold)."""
        self.score_threshold = 0.01

    def setUp(self):
        """Build random inputs, compute the expected output, configure the op."""
        self.set_argument()
        N = 7
        M = 1200
        C = 21
        BOX_SIZE = 4

        background = 0
        nms_threshold = 0.3
        nms_top_k = 400
        keep_top_k = 200
        score_threshold = self.score_threshold

        scores = np.random.random((N * M, C)).astype('float32')

        def softmax(x):
            # Subtract the max for numerical stability, then clip the shifted
            # values so exp() stays far from underflow.
            shiftx = (x - np.max(x)).clip(-64.)
            exps = np.exp(shiftx)
            return exps / np.sum(exps)

        scores = np.apply_along_axis(softmax, 1, scores)
        scores = np.reshape(scores, (N, M, C))
        # Reorder to (N, C, M): scores[image][class][prior].
        scores = np.transpose(scores, (0, 2, 1))

        # Keep (x1, y1) in [0, 0.5) and (x2, y2) in [0.5, 1) so each box is
        # well formed.
        boxes = np.random.random((N, M, BOX_SIZE)).astype('float32')
        boxes[:, :, 0:2] = boxes[:, :, 0:2] * 0.5
        boxes[:, :, 2:4] = boxes[:, :, 2:4] * 0.5 + 0.5

        nmsed_outs, lod = batched_multiclass_nms(boxes, scores, background,
                                                 score_threshold, nms_threshold,
                                                 nms_top_k, keep_top_k)
        # With no detections at all, the expected output is the single value -1.
        nmsed_outs = [-1] if not nmsed_outs else nmsed_outs
        nmsed_outs = np.array(nmsed_outs).astype('float32')

        self.op_type = 'multiclass_nms'
        self.inputs = {'BBoxes': boxes, 'Scores': scores}
        self.outputs = {'Out': (nmsed_outs, [lod])}
        self.attrs = {
            'background_label': background,
            'nms_threshold': nms_threshold,
            'nms_top_k': nms_top_k,
            'keep_top_k': keep_top_k,
            'score_threshold': score_threshold,
            'nms_eta': 1.0,
        }

    def test_check_output(self):
        self.check_output()


class TestMulticlassNMSOpNoOutput(TestMulticlassNMSOp):
    """Variant whose score threshold filters out every box."""

    def set_argument(self):
        # A threshold of 2.0 is unreachable, so this exercises the case where
        # the op produces no detections. Real uses keep 0.0 < threshold < 1.0.
        self.score_threshold = 2.0


class TestIOU(unittest.TestCase):
    """Sanity check for the ``iou`` reference helper."""

    def test_iou(self):
        # The overlap is the 2x1 rectangle [4, 6] x [4, 5]; the union is
        # 6 + 12 - 2 = 16, so the expected ratio is 2 / 16.
        first = np.array([4.0, 3.0, 7.0, 5.0]).astype('float32')
        second = np.array([3.0, 4.0, 6.0, 8.0]).astype('float32')

        expected = np.array([2.0 / 16.0]).astype('float32')
        actual = np.array([iou(first, second)]).astype('float32')
        np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-8)


if __name__ == "__main__":
    # Discover and run every TestCase defined in this module.
    unittest.main()