# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ppdet.modeling.bbox_utils import nonempty_bbox
from . import ops
try:
    from collections.abc import Sequence
except Exception:
    from collections import Sequence


@register
class BBoxPostProcess(object):
    """Decode bbox-head outputs, optionally apply NMS, and map the resulting
    boxes back onto the original image.

    Args:
        num_classes (int): Number of object classes; forwarded to ``nms``.
        decode (object): Injected decoder. When ``nms`` is set it must return
            ``(bboxes, score)``; otherwise it must directly return
            ``(bbox_pred, bbox_num)``.
        nms (object): Injected NMS callable, or None to skip NMS.
    """
    __shared__ = ['num_classes']
    __inject__ = ['decode', 'nms']

    def __init__(self, num_classes=80, decode=None, nms=None):
        super(BBoxPostProcess, self).__init__()
        self.num_classes = num_classes
        self.decode = decode
        self.nms = nms

    def __call__(self, head_out, rois, im_shape, scale_factor):
        """
        Decode the bbox and do NMS if needed.

        Args:
            head_out: Output of the bbox head, forwarded unchanged to
                ``decode``.
            rois: Forwarded unchanged to ``decode``.
            im_shape (Tensor): Forwarded unchanged to ``decode``.
            scale_factor (Tensor): Forwarded unchanged to ``decode``.

        Returns:
            bbox_pred(Tensor): The output is the prediction with shape [N, 6]
                               including labels, scores and bboxes. The size of
                               bboxes are corresponding to the input image and
                               the bboxes may be used in other branch.
            bbox_num(Tensor): 1-D tensor with the number of predictions for
                              each image in the batch (one entry per image).
                              If nothing is predicted, a single placeholder
                              row [-1, 0, 0, 0, 0, 0] is returned together
                              with bbox_num == [1].
        """
        if self.nms is not None:
            bboxes, score = self.decode(head_out, rois, im_shape, scale_factor)
            bbox_pred, bbox_num, _ = self.nms(bboxes, score, self.num_classes)
        else:
            bbox_pred, bbox_num = self.decode(head_out, rois, im_shape,
                                              scale_factor)
        if bbox_pred.shape[0] == 0:
            # Presumably avoids handing an empty tensor to downstream
            # consumers; label -1 is the same "invalid" marker get_pred uses
            # for boxes it filters out.
            # NOTE(review): bbox_num is [1] regardless of batch size, so this
            # fallback assumes a batch of one image — confirm.
            bbox_pred = paddle.to_tensor(
                np.array(
                    [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32'))
            bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
        return bbox_pred, bbox_num

    def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
        """
        Rescale, clip and filter the bbox from the output of NMS to
        get final prediction.

        Args:
            bboxes(Tensor): The output of __call__ with shape [N, 6]; each row
                is [label, score, x1, y1, x2, y2].
            bbox_num(Tensor): Number of boxes per image. The rows of
                ``bboxes`` are assumed to be grouped image by image in this
                order.
            im_shape(Tensor): Per-image [h, w] of the resized (network input)
                image.
            scale_factor(Tensor): Per-image [scale_y, scale_x] that maps the
                original image size to ``im_shape``.

        Returns:
            bbox_pred(Tensor): The output is the prediction with shape [N, 6]
                               including labels, scores and bboxes. The size of
                               bboxes are corresponding to the original image.
                               Boxes that are empty after clipping keep their
                               row but get label -1.

        Side effect:
            Stores the per-box original image shape ([N, 2] as [h, w]) in
            ``self.origin_shape_list``; ``get_origin_shape`` returns it.
        """

        # Undo the resize (rounded to the nearest integer) to recover the
        # original image size.
        origin_shape = paddle.floor(im_shape / scale_factor + 0.5)

        origin_shape_list = []
        scale_factor_list = []
        # scale_factor: scale_y, scale_x
        for i in range(bbox_num.shape[0]):
            # Repeat image i's shape / scale once per box of that image so
            # they line up row-for-row with `bboxes`.
            expand_shape = paddle.expand(origin_shape[i:i + 1, :],
                                         [bbox_num[i], 2])
            scale_y, scale_x = scale_factor[i][0], scale_factor[i][1]
            scale = paddle.concat([scale_x, scale_y, scale_x, scale_y])
            expand_scale = paddle.expand(scale, [bbox_num[i], 4])
            # TODO: Because paddle.expand transform error when dygraph
            # to static, use reshape to avoid mistakes.
            expand_scale = paddle.reshape(expand_scale, [bbox_num[i], 4])
            origin_shape_list.append(expand_shape)
            scale_factor_list.append(expand_scale)

        self.origin_shape_list = paddle.concat(origin_shape_list)
        scale_factor_list = paddle.concat(scale_factor_list)

        # bboxes: [N, 6], label, score, bbox
        pred_label = bboxes[:, 0:1]
        pred_score = bboxes[:, 1:2]
        pred_bbox = bboxes[:, 2:]
        # rescale bbox to original image
        scaled_bbox = pred_bbox / scale_factor_list
        origin_h = self.origin_shape_list[:, 0]
        origin_w = self.origin_shape_list[:, 1]
        zeros = paddle.zeros_like(origin_h)
        # clip bbox to [0, original_size]
        x1 = paddle.maximum(paddle.minimum(scaled_bbox[:, 0], origin_w), zeros)
        y1 = paddle.maximum(paddle.minimum(scaled_bbox[:, 1], origin_h), zeros)
        x2 = paddle.maximum(paddle.minimum(scaled_bbox[:, 2], origin_w), zeros)
        y2 = paddle.maximum(paddle.minimum(scaled_bbox[:, 3], origin_h), zeros)
        pred_bbox = paddle.stack([x1, y1, x2, y2], axis=-1)
        # filter empty bbox: keep the row, but relabel it as -1
        keep_mask = nonempty_bbox(pred_bbox, return_mask=True)
        keep_mask = paddle.unsqueeze(keep_mask, [1])
        pred_label = paddle.where(keep_mask, pred_label,
                                  paddle.ones_like(pred_label) * -1)
        pred_result = paddle.concat([pred_label, pred_score, pred_bbox], axis=1)
        return pred_result

    def get_origin_shape(self):
        """Return the per-box original image shapes cached by ``get_pred``."""
        return self.origin_shape_list


@register
class MaskPostProcess(object):
    """Paste per-box mask predictions onto the original image and binarize
    them.

    Args:
        binary_thresh (float): Pasted mask values ``>= binary_thresh`` become
            1, all others 0.
    """

    def __init__(self, binary_thresh=0.5):
        super(MaskPostProcess, self).__init__()
        self.binary_thresh = binary_thresh

    def paste_mask(self, masks, boxes, im_h, im_w):
        """Resample a single mask into an ``im_h`` x ``im_w`` image.

        Every image pixel center is mapped into the box's normalized
        [-1, 1] frame, and the mask is sampled there with
        ``F.grid_sample``.

        Args:
            masks (Tensor): One 2-D mask; it is unsqueezed to [1, 1, H, W].
            boxes (Tensor): [N, 4] boxes as [x0, y0, x1, y1] in image
                coordinates. Callers in this file pass N == 1.
            im_h: Output height.
            im_w: Output width.

        Returns:
            Tensor: The sampled mask values with the channel axis dropped.
        """
        # paste each mask on image
        x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1)
        masks = paddle.unsqueeze(masks, [0, 1])
        # +0.5 samples at pixel centers rather than pixel corners.
        img_y = paddle.arange(0, im_h, dtype='float32') + 0.5
        img_x = paddle.arange(0, im_w, dtype='float32') + 0.5
        # Map box edges to -1 / +1 (grid_sample's normalized coordinates).
        # NOTE(review): a zero-width/height box divides by zero here.
        img_y = (img_y - y0) / (y1 - y0) * 2 - 1
        img_x = (img_x - x0) / (x1 - x0) * 2 - 1
        img_x = paddle.unsqueeze(img_x, [1])
        img_y = paddle.unsqueeze(img_y, [2])
        N = boxes.shape[0]

        gx = paddle.expand(img_x, [N, img_y.shape[1], img_x.shape[2]])
        gy = paddle.expand(img_y, [N, img_y.shape[1], img_x.shape[2]])
        # TODO: Because paddle.expand transform error when dygraph
        # to static, use reshape to avoid mistakes.
        gx = paddle.reshape(gx, [N, img_y.shape[1], img_x.shape[2]])
        gy = paddle.reshape(gy, [N, img_y.shape[1], img_x.shape[2]])
        grid = paddle.stack([gx, gy], axis=3)
        img_masks = F.grid_sample(masks, grid, align_corners=False)
        return img_masks[:, 0]

    def __call__(self, mask_out, bboxes, bbox_num, origin_shape):
        """
        Paste the mask prediction to the original image.

        Args:
            mask_out (Tensor): Per-box mask predictions; ``mask_out[i]`` is
                pasted using box ``bboxes[i]``.
            bboxes (Tensor): [N, 6] rows of [label, score, x0, y0, x1, y1].
            bbox_num (Tensor): Currently unused (see the bs > 1 TODO).
            origin_shape (Tensor): Original image [h, w], one row per box.

        Returns:
            Tensor: int32 0/1 masks, one per box, concatenated along axis 0.
            If ``bboxes`` is empty, an all-zero int32 tensor of shape
            [mask_out.shape[0], h, w] using the first row of
            ``origin_shape``.
        """
        origin_shape = paddle.cast(origin_shape, 'int32')
        # TODO: support bs > 1 and mask output dtype is bool
        if bboxes.shape[0] == 0:
            # Only allocate the full-image zero canvas when it is actually
            # returned; it can be large (num_mask x H x W).
            num_mask = mask_out.shape[0]
            return paddle.zeros(
                [num_mask, origin_shape[0][0], origin_shape[0][1]],
                dtype='int32')

        # TODO: optimize chunk paste
        pred_result = []
        for i in range(bboxes.shape[0]):
            im_h, im_w = origin_shape[i][0], origin_shape[i][1]
            pred_mask = self.paste_mask(mask_out[i], bboxes[i:i + 1, 2:], im_h,
                                        im_w)
            pred_mask = pred_mask >= self.binary_thresh
            pred_mask = paddle.cast(pred_mask, 'int32')
            pred_result.append(pred_mask)
        pred_result = paddle.concat(pred_result)
        return pred_result


@register
class FCOSPostProcess(object):
    """Decode FCOS head outputs and run NMS on the decoded boxes.

    Args:
        decode (object): Injected callable that turns the FCOS head outputs
            (plus ``scale_factor``) into ``(boxes, scores)``.
        nms (object): Injected callable applied to the decoded boxes and
            scores. It must return three values; the first two are returned.
    """
    __inject__ = ['decode', 'nms']

    def __init__(self, decode=None, nms=None):
        super(FCOSPostProcess, self).__init__()
        self.decode, self.nms = decode, nms

    def __call__(self, fcos_head_outs, scale_factor):
        """Return ``(bbox_pred, bbox_num)`` for one batch of FCOS outputs.

        Args:
            fcos_head_outs: Exactly four items, unpacked as
                (locations, cls_logits, bboxes_reg, centerness).
            scale_factor (Tensor): Forwarded unchanged to ``decode``.
        """
        locations, cls_logits, bboxes_reg, centerness = fcos_head_outs
        boxes, scores = self.decode(locations, cls_logits, bboxes_reg,
                                    centerness, scale_factor)
        pred, num, _unused = self.nms(boxes, scores)
        return pred, num