post_process.py 7.1 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
6
from ppdet.modeling.bbox_utils import nonempty_bbox
Q
qingqing01 已提交
7 8 9 10 11
from . import ops


@register
class BBoxPostProcess(object):
    """Decode head outputs, optionally run NMS, and map the resulting boxes
    back to the original image.

    Args:
        num_classes (int): Number of object classes, forwarded to ``nms``.
        decode (callable): Injected box decoder. When ``nms`` is set it is
            expected to return ``(bboxes, score)``; otherwise it is expected
            to return ``(bbox_pred, bbox_num)`` directly.
        nms (callable|None): Injected NMS op returning
            ``(bbox_pred, bbox_num, _)``. ``None`` skips NMS.
    """
    __shared__ = ['num_classes']
    __inject__ = ['decode', 'nms']

    def __init__(self, num_classes=80, decode=None, nms=None):
        super(BBoxPostProcess, self).__init__()
        self.num_classes = num_classes
        self.decode = decode
        self.nms = nms

    def __call__(self, head_out, rois, im_shape, scale_factor):
        """
        Decode the bbox and do NMS if needed.

        Returns:
            bbox_pred(Tensor): The output is the prediction with shape [N, 6]
                               including labels, scores and bboxes. The size of
                               bboxes are corresponding to the input image and
                               the bboxes may be used in other branch.
            bbox_num(Tensor): 1-D tensor with the number of predictions of
                              each image in the batch; its values sum to N.
                              When there is no detection at all, a single
                              placeholder row with label -1 is returned and
                              counted for the first image.
        """
        if self.nms is not None:
            bboxes, score = self.decode(head_out, rois, im_shape, scale_factor)
            bbox_pred, bbox_num, _ = self.nms(bboxes, score, self.num_classes)
        else:
            bbox_pred, bbox_num = self.decode(head_out, rois, im_shape,
                                              scale_factor)
        if bbox_pred.shape[0] == 0:
            # No detections: emit one placeholder row (label -1) so that
            # downstream ops never receive an empty tensor.
            bbox_pred = paddle.to_tensor(
                np.array(
                    [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32'))
            # bbox_num must account for the placeholder row. Otherwise
            # get_pred expands the per-image shapes/scales to 0 rows and they
            # no longer line up with the 1-row bbox_pred.
            fake_num = np.zeros([max(bbox_num.shape[0], 1)], dtype='int32')
            fake_num[0] = 1
            bbox_num = paddle.cast(paddle.to_tensor(fake_num), bbox_num.dtype)
        return bbox_pred, bbox_num

    def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
        """
        Rescale, clip and filter the bbox from the output of NMS to
        get final prediction.

        Args:
            bboxes(Tensor): The output of __call__ with shape [N, 6]
                            (label, score, x1, y1, x2, y2).
            bbox_num(Tensor): Number of boxes of each image, as returned by
                              __call__; entry i covers the next bbox_num[i]
                              rows of ``bboxes``.
            im_shape(Tensor): Per-image input shape; column 0 is treated as
                              height and column 1 as width.
            scale_factor(Tensor): Per-image resize factor (scale_y, scale_x).

        Returns:
            bbox_pred(Tensor): The output is the prediction with shape [N, 6]
                               including labels, scores and bboxes. The size of
                               bboxes are corresponding to the original image.
                               Boxes that become empty after clipping get
                               label -1.

        Side effect:
            Stores the per-box original image shape in
            ``self.origin_shape_list`` (read by ``get_origin_shape``).
        """
        # Undo the resize and round to the nearest integer pixel size.
        origin_shape = paddle.floor(im_shape / scale_factor + 0.5)

        origin_shape_list = []
        scale_factor_list = []
        # scale_factor: scale_y, scale_x
        for i in range(bbox_num.shape[0]):
            # Repeat this image's shape / scale once per box of the image.
            expand_shape = paddle.expand(origin_shape[i:i + 1, :],
                                         [bbox_num[i], 2])
            scale_y, scale_x = scale_factor[i][0], scale_factor[i][1]
            scale = paddle.concat([scale_x, scale_y, scale_x, scale_y])
            expand_scale = paddle.expand(scale, [bbox_num[i], 4])
            # TODO: Because paddle.expand transform error when dygraph
            # to static, use reshape to avoid mistakes.
            expand_scale = paddle.reshape(expand_scale, [bbox_num[i], 4])
            origin_shape_list.append(expand_shape)
            scale_factor_list.append(expand_scale)

        self.origin_shape_list = paddle.concat(origin_shape_list)
        scale_factor_list = paddle.concat(scale_factor_list)

        # bboxes: [N, 6], label, score, bbox
        pred_label = bboxes[:, 0:1]
        pred_score = bboxes[:, 1:2]
        pred_bbox = bboxes[:, 2:]
        # rescale bbox to original image
        scaled_bbox = pred_bbox / scale_factor_list
        origin_h = self.origin_shape_list[:, 0]
        origin_w = self.origin_shape_list[:, 1]
        zeros = paddle.zeros_like(origin_h)
        # clip bbox to [0, original_size]
        x1 = paddle.maximum(paddle.minimum(scaled_bbox[:, 0], origin_w), zeros)
        y1 = paddle.maximum(paddle.minimum(scaled_bbox[:, 1], origin_h), zeros)
        x2 = paddle.maximum(paddle.minimum(scaled_bbox[:, 2], origin_w), zeros)
        y2 = paddle.maximum(paddle.minimum(scaled_bbox[:, 3], origin_h), zeros)
        pred_bbox = paddle.stack([x1, y1, x2, y2], axis=-1)
        # filter empty bbox: keep the row but mark its label as -1
        keep_mask = nonempty_bbox(pred_bbox, return_mask=True)
        keep_mask = paddle.unsqueeze(keep_mask, [1])
        pred_label = paddle.where(keep_mask, pred_label,
                                  paddle.ones_like(pred_label) * -1)
        pred_result = paddle.concat([pred_label, pred_score, pred_bbox], axis=1)
        return pred_result

    def get_origin_shape(self):
        """Return the per-box original image shape computed by the last
        ``get_pred`` call. The attribute does not exist until get_pred has
        run."""
        return self.origin_shape_list

Q
qingqing01 已提交
104 105 106

@register
class MaskPostProcess(object):
    """Paste per-box mask predictions onto the original image and binarize
    them.

    Args:
        binary_thresh (float): Pasted mask values >= this threshold become 1,
            all others become 0.
    """

    def __init__(self, binary_thresh=0.5):
        super(MaskPostProcess, self).__init__()
        self.binary_thresh = binary_thresh

    def paste_mask(self, masks, boxes, im_h, im_w):
        """Resample a mask into an (im_h, im_w) canvas at a box location.

        Args:
            masks(Tensor): A single 2-D mask; it is unsqueezed to
                           [1, 1, h, w] before sampling.
            boxes(Tensor): Box coordinates split along axis 1 into
                           x0, y0, x1, y1. The caller in this class passes
                           one box of shape [1, 4].
            im_h, im_w: Height and width of the output canvas.

        Returns:
            Tensor: Sampled mask values of shape [N, im_h, im_w] with
            N = boxes.shape[0].
        """
        # paste each mask on image
        x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1)
        masks = paddle.unsqueeze(masks, [0, 1])
        # Pixel centers of the canvas, mapped into the box's normalized
        # [-1, 1] frame used by grid_sample; pixels outside the box land
        # outside [-1, 1].
        img_y = paddle.arange(0, im_h, dtype='float32') + 0.5
        img_x = paddle.arange(0, im_w, dtype='float32') + 0.5
        img_y = (img_y - y0) / (y1 - y0) * 2 - 1
        img_x = (img_x - x0) / (x1 - x0) * 2 - 1
        img_x = paddle.unsqueeze(img_x, [1])
        img_y = paddle.unsqueeze(img_y, [2])
        N = boxes.shape[0]

        gx = paddle.expand(img_x, [N, img_y.shape[1], img_x.shape[2]])
        gy = paddle.expand(img_y, [N, img_y.shape[1], img_x.shape[2]])
        # TODO: Because paddle.expand transform error when dygraph
        # to static, use reshape to avoid mistakes.
        gx = paddle.reshape(gx, [N, img_y.shape[1], img_x.shape[2]])
        gy = paddle.reshape(gy, [N, img_y.shape[1], img_x.shape[2]])
        grid = paddle.stack([gx, gy], axis=3)
        img_masks = F.grid_sample(masks, grid, align_corners=False)
        return img_masks[:, 0]

    def __call__(self, mask_out, bboxes, bbox_num, origin_shape):
        """
        Paste the mask prediction to the original image.

        Args:
            mask_out(Tensor): Per-box mask predictions; mask_out[i] belongs
                              to box i.
            bboxes(Tensor): Boxes of shape [N, 6]; columns 2: hold
                            (x0, y0, x1, y1).
            bbox_num(Tensor): Not used here; kept for interface compatibility.
            origin_shape(Tensor): Original image (h, w) for every box, indexed
                                  by box i.

        Returns:
            Tensor: int32 binary masks, one [h, w] mask per box.
        """
        origin_shape = paddle.cast(origin_shape, 'int32')
        # TODO: support bs > 1 and mask output dtype is bool
        if bboxes.shape[0] == 0:
            # The all-zero result can be large ([num_mask, h, w]), so only
            # build it on the path that actually returns it.
            num_mask = mask_out.shape[0]
            return paddle.zeros(
                [num_mask, origin_shape[0][0], origin_shape[0][1]],
                dtype='int32')

        # TODO: optimize chunk paste
        pred_result = []
        for i in range(bboxes.shape[0]):
            im_h, im_w = origin_shape[i][0], origin_shape[i][1]
            pred_mask = self.paste_mask(mask_out[i], bboxes[i:i + 1, 2:], im_h,
                                        im_w)
            pred_mask = pred_mask >= self.binary_thresh
            pred_mask = paddle.cast(pred_mask, 'int32')
            pred_result.append(pred_mask)
        return paddle.concat(pred_result)
F
Feng Ni 已提交
156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172


@register
class FCOSPostProcess(object):
    """Convert FCOS head outputs into final detections.

    The injected ``decode`` turns head outputs into boxes and scores. The
    injected ``nms`` then filters those boxes.
    """
    __inject__ = ['decode', 'nms']

    def __init__(self, decode=None, nms=None):
        super(FCOSPostProcess, self).__init__()
        self.nms = nms
        self.decode = decode

    def __call__(self, fcos_head_outs, scale_factor):
        """Decode the head outputs and apply NMS.

        Args:
            fcos_head_outs: 4-sequence unpacked as (locations, cls_logits,
                bboxes_reg, centerness).
            scale_factor(Tensor): Forwarded unchanged to ``decode``.

        Returns:
            tuple: The first two outputs of ``nms`` (bbox_pred, bbox_num);
            its third output is discarded.
        """
        (points, logits, box_deltas, ctr_scores) = fcos_head_outs
        decoded_boxes, decoded_scores = self.decode(
            points, logits, box_deltas, ctr_scores, scale_factor)
        detections, det_counts, _unused = self.nms(decoded_boxes,
                                                   decoded_scores)
        return detections, det_counts