post_process.py 27.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Q
qingqing01 已提交
15 16 17 18 19
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
M
Manuel Garcia 已提交
20
from ppdet.modeling.bbox_utils import nonempty_bbox, rbox2poly
F
FlyingQianMM 已提交
21
from ppdet.modeling.layers import TTFBox
22
from .transformers import bbox_cxcywh_to_xyxy
W
wangguanzhong 已提交
23 24 25 26
try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence
Q
qingqing01 已提交
27

28
# Public post-processing components exported by this module.
__all__ = [
    'BBoxPostProcess',
    'MaskPostProcess',
    'FCOSPostProcess',
    'S2ANetBBoxPostProcess',
    'JDEBBoxPostProcess',
    'CenterNetPostProcess',
    'DETRBBoxPostProcess',
    'SparsePostProcess',
]
F
Feng Ni 已提交
33

Q
qingqing01 已提交
34 35

@register
class BBoxPostProcess(nn.Layer):
    """
    Decode box-head outputs, optionally run NMS, and map the final boxes
    back onto the origin image.

    Args:
        num_classes (int): Number of classes, forwarded to the NMS call.
        decode (object): Injected decoder called as
            decode(head_out, rois, im_shape, scale_factor).
        nms (object|None): Injected NMS; when None the decoder output is
            returned as the final prediction.
    """
    __shared__ = ['num_classes']
    __inject__ = ['decode', 'nms']

    def __init__(self, num_classes=80, decode=None, nms=None):
        super(BBoxPostProcess, self).__init__()
        self.num_classes = num_classes
        self.decode = decode
        self.nms = nms

    def forward(self, head_out, rois, im_shape, scale_factor):
        """
        Decode the box head output and apply NMS when it is configured.

        Args:
            head_out (tuple): bbox_pred and cls_prob of bbox_head output.
            rois (tuple): roi and rois_num of rpn_head output.
            im_shape (Tensor): The shape of the input image.
            scale_factor (Tensor): The scale factor of the input image.

        Returns:
            bbox_pred (Tensor): Predictions with shape [N, 6], each row being
                [label, score, x1, y1, x2, y2] in input-image coordinates.
            bbox_num (Tensor): The number of predicted boxes per image.
        """
        if self.nms is None:
            # Without NMS the decoder already yields (bbox_pred, bbox_num).
            bbox_pred, bbox_num = self.decode(head_out, rois, im_shape,
                                              scale_factor)
            return bbox_pred, bbox_num

        bboxes, score = self.decode(head_out, rois, im_shape, scale_factor)
        bbox_pred, bbox_num, _ = self.nms(bboxes, score, self.num_classes)
        return bbox_pred, bbox_num

    def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
        """
        Rescale, clip and filter the NMS output to produce the final result.

        Images with zero detections receive one placeholder row whose label
        is -1, so every image contributes at least one row. Boxes are
        processed image by image according to ``bbox_num``. As a side effect
        the per-box origin shapes are stored on ``self.origin_shape_list``.

        Args:
            bboxes (Tensor): NMS output with shape [N, 6] as
                [label, score, x1, y1, x2, y2].
            bbox_num (Tensor): The number of boxes of each image.
            im_shape (Tensor): Input image shape, one [h, w] row per image.
            scale_factor (Tensor): Resize factors, one [scale_y, scale_x]
                row per image.

        Returns:
            pred_result (Tensor): Rows of [label, score, x1, y1, x2, y2] in
                origin-image coordinates; boxes that are empty after clipping
                get label -1.
        """
        # Placeholder used for images that ended up with no detections.
        fake_bboxes = paddle.to_tensor(
            np.array(
                [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32'))
        fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))

        padded_bboxes, padded_nums = [], []
        offset = 0
        for idx in range(bbox_num.shape[0]):
            count = bbox_num[idx]
            if count == 0:
                padded_bboxes.append(fake_bboxes)
                padded_nums.append(fake_bbox_num)
                continue
            padded_bboxes.append(bboxes[offset:offset + count, :])
            padded_nums.append(count)
            offset += count
        bboxes = paddle.concat(padded_bboxes)
        bbox_num = paddle.concat(padded_nums)

        # Undo the resize: origin size = round(input size / scale).
        origin_shape = paddle.floor(im_shape / scale_factor + 0.5)

        shape_rows, scale_rows = [], []
        for idx in range(bbox_num.shape[0]):
            rows = bbox_num[idx]
            shape_rows.append(
                paddle.expand(origin_shape[idx:idx + 1, :], [rows, 2]))
            # scale_factor rows are laid out as [scale_y, scale_x].
            sy, sx = scale_factor[idx][0], scale_factor[idx][1]
            scale_rows.append(
                paddle.expand(paddle.concat([sx, sy, sx, sy]), [rows, 4]))
        self.origin_shape_list = paddle.concat(shape_rows)
        scale_factor_list = paddle.concat(scale_rows)

        pred_label = bboxes[:, 0:1]
        pred_score = bboxes[:, 1:2]
        scaled_bbox = bboxes[:, 2:] / scale_factor_list

        # Clip every coordinate to [0, origin size]; x uses width, y height.
        origin_h = self.origin_shape_list[:, 0]
        origin_w = self.origin_shape_list[:, 1]
        zeros = paddle.zeros_like(origin_h)
        limits = [origin_w, origin_h, origin_w, origin_h]
        clipped = [
            paddle.maximum(paddle.minimum(scaled_bbox[:, col], limits[col]),
                           zeros) for col in range(4)
        ]
        pred_bbox = paddle.stack(clipped, axis=-1)

        # Boxes that collapsed to zero area are kept but relabelled as -1.
        keep_mask = paddle.unsqueeze(
            nonempty_bbox(pred_bbox, return_mask=True), [1])
        pred_label = paddle.where(keep_mask, pred_label,
                                  paddle.ones_like(pred_label) * -1)
        return paddle.concat([pred_label, pred_score, pred_bbox], axis=1)

    def get_origin_shape(self):
        """Return the per-box origin shapes computed by the last get_pred."""
        return self.origin_shape_list

Q
qingqing01 已提交
155 156 157

@register
class MaskPostProcess(object):
    """
    refer to:
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/mask_ops.py

    Paste per-box mask predictions onto the origin image and binarize them.

    Args:
        binary_thresh (float): Pasted values >= this threshold become 1.
    """

    def __init__(self, binary_thresh=0.5):
        super(MaskPostProcess, self).__init__()
        self.binary_thresh = binary_thresh

    def paste_mask(self, masks, boxes, im_h, im_w):
        """
        Paste a mask prediction onto a canvas of the origin image size.

        Args:
            masks (Tensor): A single mask; two leading axes are added before
                sampling.
            boxes (Tensor): Boxes with 4 columns as (x0, y0, x1, y1).
            im_h, im_w: Height and width of the output canvas.

        Returns:
            Tensor: Sampled (soft) mask values with shape [N, im_h, im_w]
                (N is 1 at the call site in __call__).
        """
        x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1)
        masks = paddle.unsqueeze(masks, [0, 1])

        # Map every pixel centre into the box's normalised [-1, 1] frame,
        # which is the coordinate system grid_sample samples in.
        grid_y = (paddle.arange(0, im_h, dtype='float32') + 0.5 - y0) / (
            y1 - y0) * 2 - 1
        grid_x = (paddle.arange(0, im_w, dtype='float32') + 0.5 - x0) / (
            x1 - x0) * 2 - 1
        grid_y = paddle.unsqueeze(grid_y, [2])  # [N, H, 1]
        grid_x = paddle.unsqueeze(grid_x, [1])  # [N, 1, W]

        target_shape = [boxes.shape[0], grid_y.shape[1], grid_x.shape[2]]
        grid = paddle.stack(
            [
                paddle.expand(grid_x, target_shape),
                paddle.expand(grid_y, target_shape)
            ],
            axis=3)
        sampled = F.grid_sample(masks, grid, align_corners=False)
        return sampled[:, 0]

    def __call__(self, mask_out, bboxes, bbox_num, origin_shape):
        """
        Decode the mask_out and paste each mask onto the origin image.

        Args:
            mask_out (Tensor): mask_head output with shape [N, 28, 28].
            bboxes (Tensor): The output bboxes with shape [N, 6] after decode
                and NMS, including labels, scores and bboxes.
            bbox_num (Tensor): The number of prediction boxes of each batch.
            origin_shape (Tensor): The origin image shape for each box, with
                shape [N, 2] and each row [h, w].
        Returns:
            pred_result (Tensor): int32 0/1 masks with shape [N, h, w]; all
                zeros when the only box is the label -1 placeholder.
        """
        num_mask = mask_out.shape[0]
        origin_shape = paddle.cast(origin_shape, 'int32')
        # TODO: support bs > 1 and mask output dtype is bool
        empty_result = paddle.zeros(
            [num_mask, origin_shape[0][0], origin_shape[0][1]], dtype='int32')
        # One box with label -1 is the "no detection" placeholder.
        if bbox_num == 1 and bboxes[0][0] == -1:
            return empty_result

        # TODO: optimize chunk paste
        binary_masks = []
        for idx in range(bboxes.shape[0]):
            canvas_h, canvas_w = origin_shape[idx][0], origin_shape[idx][1]
            soft_mask = self.paste_mask(mask_out[idx], bboxes[idx:idx + 1, 2:],
                                        canvas_h, canvas_w)
            binary_masks.append(
                paddle.cast(soft_mask >= self.binary_thresh, 'int32'))
        return paddle.concat(binary_masks)
F
Feng Ni 已提交
225 226 227 228 229 230 231 232 233 234 235 236


@register
class FCOSPostProcess(object):
    """
    Decode FCOS head outputs and run NMS on them.

    Args:
        decode (object): Injected decoder for FCOS head outputs.
        nms (object): Injected NMS applied to the decoded boxes/scores.
    """
    __inject__ = ['decode', 'nms']

    def __init__(self, decode=None, nms=None):
        super(FCOSPostProcess, self).__init__()
        self.decode = decode
        self.nms = nms

    def __call__(self, fcos_head_outs, scale_factor):
        """
        Decode the bbox and do NMS in FCOS.

        Args:
            fcos_head_outs (tuple): (locations, cls_logits, bboxes_reg,
                centerness) produced by the FCOS head.
            scale_factor (Tensor): Resize factors forwarded to the decoder.

        Returns:
            tuple: (bbox_pred, bbox_num) from the NMS call.
        """
        locations, cls_logits, bboxes_reg, centerness = fcos_head_outs
        decoded_boxes, decoded_scores = self.decode(
            locations, cls_logits, bboxes_reg, centerness, scale_factor)
        bbox_pred, bbox_num, _ = self.nms(decoded_boxes, decoded_scores)
        return bbox_pred, bbox_num
C
cnn 已提交
245 246 247


@register
class S2ANetBBoxPostProcess(nn.Layer):
    """
    Post-process for S2ANet rotated boxes: convert them to 4-point polygons,
    run NMS, then rescale and clip the polygons to the origin image.

    Args:
        num_classes (int): Number of classes, forwarded to the NMS call.
        nms_pre (int): Stored as an attribute; not read inside this class.
        min_bbox_size (int): Stored as an attribute; not read inside this
            class.
        nms (object): Injected NMS operating on polygons.
    """
    __shared__ = ['num_classes']
    __inject__ = ['nms']

    def __init__(self, num_classes=15, nms_pre=2000, min_bbox_size=0, nms=None):
        super(S2ANetBBoxPostProcess, self).__init__()
        self.num_classes = num_classes
        self.nms_pre = nms_pre
        self.min_bbox_size = min_bbox_size
        self.nms = nms
        # Not read by this class (get_pred uses a local list); kept for
        # compatibility.
        self.origin_shape_list = []
        # Placeholder row returned when NMS keeps nothing:
        # [label=-1, score, 8 polygon coordinates].
        self.fake_pred_cls_score_bbox = paddle.to_tensor(
            np.array(
                [[-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
                dtype='float32'))
        self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))

    def forward(self, pred_scores, pred_bboxes):
        """
        Run rotated-box NMS on one image's predictions.

        Args:
            pred_scores (Tensor): Class scores with shape [N, M].
            pred_bboxes (Tensor): Rotated boxes with shape [N, 5] as
                (xc, yc, w, h, a).

        Returns:
            pred_cls_score_bbox (Tensor): NMS output reshaped to [-1, 10]
                (label, score, 8 polygon coordinates); a single placeholder
                row with label -1 when nothing survives.
            bbox_num (Tensor): The number of kept boxes.
        """
        polys = paddle.unsqueeze(rbox2poly(pred_bboxes), axis=0)
        # [N, M] -> [1, M, N]: scores are handed to NMS class-major.
        scores = paddle.unsqueeze(paddle.transpose(pred_scores, [1, 0]), axis=0)

        pred_cls_score_bbox, bbox_num, _ = self.nms(polys, scores,
                                                    self.num_classes)

        # Prevent empty bbox_pred from decode or NMS.
        # Bboxes and score before NMS may be empty due to the score threshold.
        nms_out_shape = pred_cls_score_bbox.shape
        if nms_out_shape[0] <= 0 or nms_out_shape[1] <= 1:
            pred_cls_score_bbox = self.fake_pred_cls_score_bbox
            bbox_num = self.fake_bbox_num

        return paddle.reshape(pred_cls_score_bbox, [-1, 10]), bbox_num

    def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
        """
        Rescale and clip the NMS polygons to get the final prediction.

        Args:
            bboxes (Tensor): NMS output with shape [N, 10] as
                [label, score, x1, y1, x2, y2, x3, y3, x4, y4].
            bbox_num (Tensor): The number of boxes of each image.
            im_shape (Tensor): Input image shape, one [h, w] row per image.
            scale_factor (Tensor): Resize factors, one [scale_y, scale_x]
                row per image.
        Returns:
            bbox_pred (Tensor): Predictions with shape [N, 10] (label, score
                and 8 polygon coordinates) in origin-image coordinates,
                clipped to [0, size - 1].
        """
        origin_shape = paddle.floor(im_shape / scale_factor + 0.5)

        shape_rows, scale_rows = [], []
        for idx in range(bbox_num.shape[0]):
            rows = bbox_num[idx]
            shape_rows.append(
                paddle.expand(origin_shape[idx:idx + 1, :], [rows, 2]))
            scale_y, scale_x = scale_factor[idx][0], scale_factor[idx][1]
            # One (scale_x, scale_y) pair for each of the four vertices.
            scale_rows.append(
                paddle.expand(
                    paddle.concat([scale_x, scale_y] * 4), [rows, 8]))
        origin_shape_list = paddle.concat(shape_rows)
        scale_factor_list = paddle.concat(scale_rows)

        pred_label_score = bboxes[:, 0:2]
        scaled_bbox = bboxes[:, 2:].reshape([-1, 8]) / scale_factor_list

        # Even columns are x (clipped by width), odd columns are y (height).
        max_x = origin_shape_list[:, 1] - 1
        max_y = origin_shape_list[:, 0] - 1
        zeros = paddle.zeros_like(origin_shape_list[:, 0])
        clipped = [
            paddle.maximum(
                paddle.minimum(scaled_bbox[:, col], max_y
                               if col % 2 else max_x), zeros)
            for col in range(8)
        ]
        pred_bbox = paddle.stack(clipped, axis=-1)
        return paddle.concat([pred_label_score, pred_bbox], axis=1)
348 349 350


@register
class JDEBBoxPostProcess(nn.Layer):
    """
    Decode JDE head outputs and run NMS, substituting placeholder tensors
    when decoding or NMS keeps nothing.

    Args:
        num_classes (int): Number of classes, forwarded to the NMS call.
        decode (object): Injected decoder called as decode(head_out, anchors).
        nms (object): Injected NMS.
        return_idx (bool): When True, an empty NMS result also replaces the
            NMS keep index with a placeholder.
    """
    __shared__ = ['num_classes']
    __inject__ = ['decode', 'nms']

    def __init__(self, num_classes=1, decode=None, nms=None, return_idx=True):
        super(JDEBBoxPostProcess, self).__init__()
        self.num_classes = num_classes
        self.decode = decode
        self.nms = nms
        self.return_idx = return_idx

        # Placeholders used when NMS keeps nothing (label -1 row).
        self.fake_bbox_pred = paddle.to_tensor(
            np.array(
                [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32'))
        self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
        self.fake_nms_keep_idx = paddle.to_tensor(
            np.array(
                [[0]], dtype='int32'))

        # Placeholders used when the decoder selects no boxes.
        self.fake_yolo_boxes_out = paddle.to_tensor(
            np.array(
                [[[0.0, 0.0, 0.0, 0.0]]], dtype='float32'))
        self.fake_yolo_scores_out = paddle.to_tensor(
            np.array(
                [[[0.0]]], dtype='float32'))
        self.fake_boxes_idx = paddle.to_tensor(np.array([[0]], dtype='int64'))

    def forward(self, head_out, anchors):
        """
        Decode the bbox and do NMS for JDE model.

        Args:
            head_out (list): Bbox_pred and cls_prob of bbox_head output.
            anchors (list): Anchors of JDE model.

        Returns:
            boxes_idx (Tensor): The index of kept bboxes after decode 'JDEBox'.
            bbox_pred (Tensor): The output is the prediction with shape [N, 6]
                including labels, scores and bboxes.
            bbox_num (Tensor): The number of prediction of each batch with shape [N].
            nms_keep_idx (Tensor): The index of kept bboxes after NMS. When
                return_idx is False and NMS keeps nothing, this is the raw
                NMS output rather than a placeholder.
        """
        boxes_idx, yolo_boxes_scores = self.decode(head_out, anchors)

        if len(boxes_idx) == 0:
            boxes_idx = self.fake_boxes_idx
            yolo_boxes_out = self.fake_yolo_boxes_out
            yolo_scores_out = self.fake_yolo_scores_out
        else:
            yolo_boxes = paddle.gather_nd(yolo_boxes_scores, boxes_idx)
            # TODO: only support bs=1 now
            yolo_boxes_out = paddle.reshape(
                yolo_boxes[:, :4], shape=[1, len(boxes_idx), 4])
            yolo_scores_out = paddle.reshape(
                yolo_boxes[:, 4:5], shape=[1, 1, len(boxes_idx)])
            # Drop the leading (batch) index column.
            boxes_idx = boxes_idx[:, 1:]

        bbox_pred, bbox_num, nms_keep_idx = self.nms(
            yolo_boxes_out, yolo_scores_out, self.num_classes)
        if bbox_pred.shape[0] == 0:
            bbox_pred = self.fake_bbox_pred
            bbox_num = self.fake_bbox_num
            if self.return_idx:
                nms_keep_idx = self.fake_nms_keep_idx
        # The return_idx=False path used to put the NMS keep index (a reused
        # `_` placeholder) in the boxes_idx slot; boxes_idx is returned now.
        return boxes_idx, bbox_pred, bbox_num, nms_keep_idx
F
FlyingQianMM 已提交
423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441


@register
class CenterNetPostProcess(TTFBox):
    """
    Postprocess the model outputs to get final prediction:
        1. Do NMS for heatmap to get top `max_per_img` bboxes.
        2. Decode bboxes using center offset and box size.
        3. Rescale decoded bboxes reference to the origin image shape.

    Args:
        max_per_img(int): the maximum number of predicted objects in a image,
            500 by default.
        down_ratio(int): the down ratio from images to heatmap, 4 by default.
        regress_ltrb (bool): whether to regress left/top/right/bottom or
            width/height for a box, true by default.
        for_mot (bool): whether return other features used in tracking model.
    """

    __shared__ = ['down_ratio', 'for_mot']

    def __init__(self,
                 max_per_img=500,
                 down_ratio=4,
                 regress_ltrb=True,
                 for_mot=False):
        # NOTE: super(TTFBox, self) starts the MRO lookup *after* TTFBox, so
        # TTFBox.__init__ is not run; the attributes used here are set below.
        super(TTFBox, self).__init__()
        self.max_per_img = max_per_img
        self.down_ratio = down_ratio
        self.regress_ltrb = regress_ltrb
        self.for_mot = for_mot

    def __call__(self, hm, wh, reg, im_shape, scale_factor):
        """
        Decode CenterNet maps into boxes in origin-image coordinates.

        Args:
            hm (Tensor): Heatmap in NCHW layout.
            wh (Tensor): Size map in NCHW layout; read as (l, t, r, b) when
                regress_ltrb is True, otherwise as (w, h).
            reg (Tensor): Center-offset map in NCHW layout; its first two
                channels are the x and y offsets.
            im_shape (Tensor): Input image shape; row 0 is used as [h, w].
            scale_factor (Tensor): Resize factors as [scale_y, scale_x] rows.

        Returns:
            If for_mot: (results as [x1, y1, x2, y2, score, class] rows,
            inds, topk_clses). Otherwise: (results as
            [class, score, x1, y1, x2, y2] rows, number of rows, topk_clses).
        """
        heat = self._simple_nms(hm)
        scores, inds, topk_clses, ys, xs = self._topk(heat)
        scores = scores.unsqueeze(1)
        clses = topk_clses.unsqueeze(1)

        # Flatten NCHW maps into per-location rows and pick the top-k ones.
        # Like TTFBox, batch size is 1.
        # TODO: support batch size > 1
        reg_t = paddle.transpose(reg, [0, 2, 3, 1])
        reg = paddle.gather(paddle.reshape(reg_t, [-1, reg_t.shape[-1]]), inds)
        wh_t = paddle.transpose(wh, [0, 2, 3, 1])
        wh = paddle.gather(paddle.reshape(wh_t, [-1, wh_t.shape[-1]]), inds)

        # Sub-pixel centres on the heatmap grid.
        center_x = paddle.cast(xs, 'float32') + reg[:, 0:1]
        center_y = paddle.cast(ys, 'float32') + reg[:, 1:2]

        if self.regress_ltrb:
            left, top = wh[:, 0:1], wh[:, 1:2]
            right, bottom = wh[:, 2:3], wh[:, 3:4]
        else:
            half_w = wh[:, 0:1] / 2
            half_h = wh[:, 1:2] / 2
            left, top, right, bottom = half_w, half_h, half_w, half_h

        # Heatmap -> input-image coordinates, then remove centred padding.
        _, _, feat_h, feat_w = hm.shape[:]
        padw = (feat_w * self.down_ratio - im_shape[0, 1]) / 2
        padh = (feat_h * self.down_ratio - im_shape[0, 0]) / 2
        x1 = (center_x - left) * self.down_ratio - padw
        y1 = (center_y - top) * self.down_ratio - padh
        x2 = (center_x + right) * self.down_ratio - padw
        y2 = (center_y + bottom) * self.down_ratio - padh

        bboxes = paddle.concat([x1, y1, x2, y2], axis=1)

        # Undo the resize; scale_factor columns are [scale_y, scale_x].
        scale_y = scale_factor[:, 0:1]
        scale_x = scale_factor[:, 1:2]
        scale_expand = paddle.concat(
            [scale_x, scale_y, scale_x, scale_y], axis=1)
        scale_expand = paddle.expand(scale_expand, shape=bboxes.shape[:])
        bboxes = paddle.divide(bboxes, scale_expand)

        if not self.for_mot:
            results = paddle.concat([clses, scores, bboxes], axis=1)
            return results, paddle.shape(results)[0:1], topk_clses
        results = paddle.concat([bboxes, scores, clses], axis=1)
        return results, inds, topk_clses
513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555


@register
class DETRBBoxPostProcess(object):
    """
    Turn DETR-style query predictions into [label, score, x1, y1, x2, y2]
    rows scaled to the origin image.

    Args:
        num_classes (int): Number of classes.
        num_top_queries (int): Number of highest-scoring predictions kept
            per image.
        use_focal_loss (bool): Use sigmoid scores (focal-loss heads) instead
            of softmax with the last (background) class dropped.
    """
    __shared__ = ['num_classes', 'use_focal_loss']
    __inject__ = []

    def __init__(self,
                 num_classes=80,
                 num_top_queries=100,
                 use_focal_loss=False):
        super(DETRBBoxPostProcess, self).__init__()
        self.num_classes = num_classes
        self.num_top_queries = num_top_queries
        self.use_focal_loss = use_focal_loss

    def __call__(self, head_out, im_shape, scale_factor):
        """
        Decode the bbox.

        Args:
            head_out (tuple): bbox_pred, cls_logit and masks of bbox_head output.
            im_shape (Tensor): The shape of the input image.
            scale_factor (Tensor): The scale factor of the input image.
        Returns:
            bbox_pred (Tensor): The output prediction with shape [N, 6], including
                labels, scores and bboxes. The size of bboxes are corresponding
                to the origin image.
            bbox_num (Tensor): The number of prediction boxes of each batch with
                shape [bs], and is N.
        """
        bboxes, logits, masks = head_out

        # Normalised cxcywh -> xyxy, shape [bs, num_queries, 4].
        bbox_pred = bbox_cxcywh_to_xyxy(bboxes)
        origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
        img_h, img_w = origin_shape.unbind(1)
        # One [w, h, w, h] row per image, shaped [bs, 1, 4] so it broadcasts
        # over the query axis. unsqueeze(0) would give [1, bs, 4], which only
        # lines up with bbox_pred when bs == 1.
        origin_shape = paddle.stack(
            [img_w, img_h, img_w, img_h], axis=-1).unsqueeze(1)
        bbox_pred = bbox_pred * origin_shape

        scores = F.sigmoid(logits) if self.use_focal_loss else F.softmax(
            logits)[:, :, :-1]

        if not self.use_focal_loss:
            scores, labels = scores.max(-1), scores.argmax(-1)
            if scores.shape[1] > self.num_top_queries:
                scores, index = paddle.topk(
                    scores, self.num_top_queries, axis=-1)
                labels = paddle.stack(
                    [paddle.gather(l, i) for l, i in zip(labels, index)])
                bbox_pred = paddle.stack(
                    [paddle.gather(b, i) for b, i in zip(bbox_pred, index)])
        else:
            # Top-k over the flattened (query, class) grid; split the flat
            # index back into class id and query id.
            scores, index = paddle.topk(
                scores.reshape([logits.shape[0], -1]),
                self.num_top_queries,
                axis=-1)
            labels = index % logits.shape[2]
            index = index // logits.shape[2]
            bbox_pred = paddle.stack(
                [paddle.gather(b, i) for b, i in zip(bbox_pred, index)])

        bbox_pred = paddle.concat(
            [
                labels.unsqueeze(-1).astype('float32'), scores.unsqueeze(-1),
                bbox_pred
            ],
            axis=-1)
        # Every image contributes the same number of rows.
        bbox_num = paddle.to_tensor(
            bbox_pred.shape[1], dtype='int32').tile([bbox_pred.shape[0]])
        bbox_pred = bbox_pred.reshape([-1, 6])
        return bbox_pred, bbox_num
F
FL77N 已提交
585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671


@register
class SparsePostProcess(object):
    """
    Post-process for Sparse R-CNN style heads: pick the top-scoring
    (proposal, class) pairs per image, clip and rescale the boxes, and drop
    boxes no larger than 1 pixel in either dimension.

    Args:
        num_proposals (int): Number of proposals per image; also the number
            of (proposal, class) pairs kept per image.
        num_classes (int): Number of classes.
    """
    __shared__ = ['num_classes']

    def __init__(self, num_proposals, num_classes=80):
        super(SparsePostProcess, self).__init__()
        self.num_classes = num_classes
        self.num_proposals = num_proposals

    def __call__(self, box_cls, box_pred, scale_factor_wh, img_whwh):
        """
        Arguments:
            box_cls (Tensor): tensor of shape (batch_size, num_proposals, K).
                Classification logits; a sigmoid is applied here.
            box_pred (Tensor): tensors of shape (batch_size, num_proposals, 4).
                Boxes as (x1, y1, x2, y2): columns 0/2 are clipped by the
                image width and box width is computed as x2 - x1.
            scale_factor_wh (Tensor): tensors of shape [batch_size, 2], the
                (w, h) resize factors of each image.
            img_whwh (Tensor): tensors of shape [batch_size, 4]; the first two
                columns are used as the clipping (w, h).
        Returns:
            bbox_pred (Tensor): tensors of shape [num_boxes, 6] Each row has 6 values:
            [label, confidence, xmin, ymin, xmax, ymax]. An image with no
            surviving box contributes one all-zero row.
            bbox_num (Tensor): tensors of shape [batch_size] the number of rows
            contributed by each image.
        """
        assert len(box_cls) == len(scale_factor_wh) == len(img_whwh)

        img_wh = img_whwh[:, :2]
        scores = F.sigmoid(box_cls)
        # Class id of every entry of a flattened [num_proposals, num_classes]
        # grid (entry k has class k % num_classes).
        labels = paddle.arange(0, self.num_classes).unsqueeze(0).tile(
            [self.num_proposals, 1]).flatten(start_axis=0, stop_axis=1)

        # Phase 1: per-image top-k over all (proposal, class) pairs.
        classes_all, scores_all, boxes_all = [], [], []
        for img_scores, img_boxes in zip(scores, box_pred):
            top_scores, top_idx = img_scores.flatten(0, 1).topk(
                self.num_proposals, sorted=False)
            # Repeat each proposal box once per class so its rows line up
            # with the flattened score grid.
            repeated_boxes = img_boxes.reshape([-1, 1, 4]).tile(
                [1, self.num_classes, 1]).reshape([-1, 4])
            classes_all.append(paddle.gather(labels, top_idx, axis=0))
            scores_all.append(top_scores)
            boxes_all.append(paddle.gather(repeated_boxes, top_idx, axis=0))

        # Phase 2: clip, rescale and filter the selected boxes.
        bbox_num = paddle.zeros([len(scale_factor_wh)], dtype="int32")
        boxes_final = []
        for i in range(len(scale_factor_wh)):
            classes = classes_all[i]
            boxes = boxes_all[i]
            scores = scores_all[i]

            boxes[:, 0::2] = paddle.clip(
                boxes[:, 0::2], min=0, max=img_wh[i][0]) / scale_factor_wh[i][0]
            boxes[:, 1::2] = paddle.clip(
                boxes[:, 1::2], min=0, max=img_wh[i][1]) / scale_factor_wh[i][1]
            boxes_w = (boxes[:, 2] - boxes[:, 0]).numpy()
            boxes_h = (boxes[:, 3] - boxes[:, 1]).numpy()
            keep = (boxes_w > 1.) & (boxes_h > 1.)

            if not keep.any():
                bboxes = paddle.zeros([1, 6]).astype("float32")
            else:
                kept_classes = paddle.to_tensor(classes.numpy()[keep]).astype(
                    "float32").unsqueeze(-1)
                kept_scores = paddle.to_tensor(scores.numpy()[keep]).astype(
                    "float32").unsqueeze(-1)
                kept_boxes = paddle.to_tensor(boxes.numpy()[keep]).astype(
                    "float32")
                bboxes = paddle.concat(
                    [kept_classes, kept_scores, kept_boxes], axis=-1)

            boxes_final.append(bboxes)
            bbox_num[i] = bboxes.shape[0]

        bbox_pred = paddle.concat(boxes_final)
        return bbox_pred, bbox_num
M
Mark Ma 已提交
672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727


def nms(dets, thresh):
    """Apply classic DPM-style greedy NMS on a score-first detection array.

    Boxes are visited from highest to lowest score. Every still-active
    lower-scoring box whose IoU with the current box is ``>= thresh`` is
    suppressed.

    Args:
        dets (np.ndarray): Detections with shape [N, 5]; each row is
            ``[score, x1, y1, x2, y2]``.
        thresh (float): IoU threshold for suppression (inclusive).

    Returns:
        np.ndarray: The surviving rows of ``dets`` in their original row
        order (not re-sorted by score). An empty input is returned as an
        empty array with the same number of columns.
    """
    if dets.shape[0] == 0:
        return dets[[], :]
    scores = dets[:, 0]
    x1 = dets[:, 1]
    y1 = dets[:, 2]
    x2 = dets[:, 3]
    y2 = dets[:, 4]

    # Pixel-inclusive (+1) width/height convention.
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # Detection indices from highest to lowest score.
    order = scores.argsort()[::-1]

    # `np.int` was removed in NumPy 1.24; a boolean mask is all we need.
    suppressed = np.zeros(dets.shape[0], dtype=bool)

    for pos, i in enumerate(order):
        if suppressed[i]:
            continue
        # Compare box i with every lower-scoring box at once. Re-marking a
        # box that is already suppressed is harmless, so no filtering needed.
        rest = order[pos + 1:]
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[rest] - inter)
        suppressed[rest[ovr >= thresh]] = True

    keep = np.where(~suppressed)[0]
    return dets[keep, :]