post_process.py 6.4 KB
Newer Older
F
FDInSky 已提交
1 2 3 4 5
import six
import os
import numpy as np
from numba import jit
from .bbox import delta2bbox, clip_bbox, expand_bbox, nms
6 7
import pycocotools.mask as mask_util
import cv2
F
FDInSky 已提交
8 9 10


def bbox_post_process(bboxes,
                      bbox_prob,
                      bbox_deltas,
                      im_shape,
                      scale_factor,
                      keep_top_k=100,
                      score_thresh=0.05,
                      nms_thresh=0.5,
                      class_nums=81,
                      bbox_reg_weights=None,
                      with_background=True):
    """Decode, clip and per-class NMS box predictions for a batch of images.

    Args:
        bboxes: tuple ``(bbox, bbox_num)``. ``bbox`` stacks the proposal boxes
            of every image along axis 0 and ``bbox_num[i]`` is the number of
            rows belonging to image ``i``.
        bbox_prob: per-proposal class scores, indexed ``[row, class]``.
        bbox_deltas: per-proposal regression deltas, sliced as
            ``[rows, :, :]`` and flattened to ``(box_num, -1)`` per image
            before being passed to ``delta2bbox``.
        im_shape: per-image shape; ``im_shape[i][:2] / scale_factor[i]`` is
            the window the decoded boxes are clipped to.
        scale_factor: per-image factor the proposals are divided by.
        keep_top_k: maximum number of detections kept per image over all
            classes (ties at the threshold score may keep a few more).
        score_thresh: a box enters NMS for class ``j`` only if its class-``j``
            score is strictly greater than this.
        nms_thresh: threshold forwarded to ``nms``.
        class_nums: number of score columns / classes to consider.
        bbox_reg_weights: weights forwarded to ``delta2bbox``; defaults to
            ``[0.1, 0.1, 0.2, 0.2]``.
        with_background: when true, class 0 is skipped.

    Returns:
        ``(new_bbox, new_bbox_num)``: ``new_bbox`` is a float32 array with
        one row per kept detection, ``[label, score, c0, c1, c2, c3]`` where
        ``c0..c3`` are the class-specific decoded and clipped box coordinates.
        ``new_bbox_num`` is an int32 array with the number of rows per image.
    """
    if bbox_reg_weights is None:
        # Fresh list per call instead of a shared mutable default argument.
        bbox_reg_weights = [0.1, 0.1, 0.2, 0.2]
    bbox, bbox_num = bboxes
    if len(bbox_num) == 0:
        # np.vstack([]) raises on an empty batch; return empty results with
        # the same column layout and dtypes as the normal path.
        return (np.zeros((0, 6), dtype=np.float32),
                np.zeros((0, ), dtype='int32'))

    # Column index of the first class that is not background.
    first_cls = int(with_background)
    new_bbox = []
    new_bbox_num = []
    st_num = 0
    for i, box_num in enumerate(bbox_num):
        end_num = st_num + box_num
        # Rows of image i; proposals are mapped back by the scale factor.
        boxes = bbox[st_num:end_num, :] / scale_factor[i]
        bbox_delta = np.reshape(bbox_deltas[st_num:end_num, :, :],
                                (box_num, -1))
        scores_n = bbox_prob[st_num:end_num, :]
        st_num = end_num

        # step1: decode the regression deltas into class-specific boxes.
        boxes = delta2bbox(bbox_delta, boxes, bbox_reg_weights)
        # step2: clip to the rescaled image extent.
        boxes = clip_bbox(boxes, im_shape[i][:2] / scale_factor[i])

        # step3: per-class score thresholding followed by NMS.
        cls_dets = []
        for j in range(first_cls, class_nums):
            inds = np.where(scores_n[:, j] > score_thresh)[0]
            # Class j owns box columns [4j, 4j + 4).
            dets_j = np.hstack((scores_n[inds, j][:, np.newaxis],
                                boxes[inds, j * 4:(j + 1) * 4])).astype(
                                    np.float32, copy=False)
            keep = nms(dets_j, nms_thresh)
            nms_dets = dets_j[keep, :]
            # Prepend the class label so each row is [label, score, box].
            labels = np.full((nms_dets.shape[0], 1), j, dtype=np.float32)
            cls_dets.append(np.hstack((labels, nms_dets)))

        # Limit to keep_top_k detections **over all classes** (column 1 is
        # the score).
        image_scores = np.hstack([dets[:, 1] for dets in cls_dets])
        if len(image_scores) > keep_top_k:
            image_thresh = np.sort(image_scores)[-keep_top_k]
            cls_dets = [dets[dets[:, 1] >= image_thresh] for dets in cls_dets]

        new_bbox_n = np.vstack(cls_dets)
        new_bbox.append(new_bbox_n)
        new_bbox_num.append(len(new_bbox_n))
    return np.vstack(new_bbox), np.array(new_bbox_num).astype('int32')
F
FDInSky 已提交
73 74 75


def _paste_mask(mask, ref_box, im_h, im_w):
    """Paste a box-sized binary ``mask`` into a zeroed ``(im_h, im_w)`` canvas.

    ``ref_box`` is ``[x0, y0, x1, y1]`` with inclusive end coordinates, and
    ``mask`` is expected to cover that box (rows along y, columns along x).
    Parts of the box outside the image are cropped away. A box that does not
    overlap the image yields an all-zero canvas.
    """
    x0, y0, x1, y1 = (int(v) for v in ref_box[:4])
    im_mask = np.zeros((im_h, im_w), dtype=np.uint8)
    x_0, x_1 = max(x0, 0), min(x1 + 1, im_w)
    y_0, y_1 = max(y0, 0), min(y1 + 1, im_h)
    # Without this guard, a box lying outside the image gives negative slice
    # bounds. NumPy reads those as offsets from the end, and the paste then
    # fails with a shape/broadcast error.
    if x_0 < x_1 and y_0 < y_1:
        im_mask[y_0:y_1, x_0:x_1] = mask[y_0 - y0:y_1 - y0, x_0 - x0:x_1 - x0]
    return im_mask


# NOTE: not decorated with numba ``@jit``. The body works on dicts, OpenCV and
# pycocotools objects, which numba cannot compile in nopython mode.
def mask_post_process(det_res,
                      im_shape,
                      scale_factor,
                      resolution=14,
                      binary_thresh=0.5):
    """Turn per-box mask predictions into full-image RLE segmentations.

    Args:
        det_res: dict with
            ``'bbox'``: detections, column 0 is the class label, column 1 the
            score, and columns ``2:`` the box;
            ``'bbox_num'``: number of detections per image;
            ``'mask'``: mask logits/probabilities indexed
            ``[row, class, :, :]``, each ``resolution x resolution``.
        im_shape: per-image shape; rows/cols are ``im_shape[i][0]`` and
            ``im_shape[i][1]`` divided by ``scale_factor[i, 0]``. The same
            factor is used for both axes.
        scale_factor: per-image scale factors (only column 0 is read).
        resolution: side length ``M`` of the predicted masks.
        binary_thresh: a resized mask pixel is foreground if strictly greater.

    Returns:
        Object array of shape ``(N, 3)`` with rows
        ``[rle, label, score]``, where ``rle`` is the pycocotools-encoded
        full-image mask. If ``det_res['mask']`` is empty it is returned as is.
    """
    bbox = det_res['bbox']
    bbox_num = det_res['bbox_num']
    masks = det_res['mask']
    if masks.shape[0] == 0:
        return masks
    M = resolution
    # The mask is padded by one pixel on each side, so the box is expanded
    # by the same ratio to keep mask and box aligned.
    scale = (M + 2.0) / M
    boxes = bbox[:, 2:]
    labels = bbox[:, 0]

    # The zero border stays intact; only the interior is rewritten per box.
    padded_mask = np.zeros((M + 2, M + 2), dtype=np.float32)
    rles = []
    st_num = 0
    for i, length in enumerate(bbox_num):
        end_num = st_num + length
        boxes_n = expand_bbox(boxes[st_num:end_num], scale).astype(np.int32)
        labels_n = labels[st_num:end_num]
        masks_n = masks[st_num:end_num]
        st_num = end_num

        im_h = int(round(im_shape[i][0] / scale_factor[i, 0]))
        im_w = int(round(im_shape[i][1] / scale_factor[i, 0]))
        for j, ref_box in enumerate(boxes_n):
            class_id = int(labels_n[j])
            padded_mask[1:-1, 1:-1] = masks_n[j, class_id, :, :]

            # Plain ints for cv2's dsize; at least one pixel in each axis.
            w = max(int(ref_box[2] - ref_box[0] + 1), 1)
            h = max(int(ref_box[3] - ref_box[1] + 1), 1)
            mask = cv2.resize(padded_mask, (w, h))
            mask = np.array(mask > binary_thresh, dtype=np.uint8)

            im_mask = _paste_mask(mask, ref_box, im_h, im_w)
            rle = mask_util.encode(
                np.array(
                    im_mask[:, :, np.newaxis], order='F'))[0]
            rles.append(rle)

    segms = np.empty((len(rles), 1), dtype=object)
    for k, rle in enumerate(rles):
        segms[k, 0] = rle
    bboxes = np.hstack([segms, bbox])
    return bboxes[:, :3]


# NOTE: not decorated with numba ``@jit``. Building dicts and looking up a
# Python mapping cannot be compiled in nopython mode.
def get_det_res(bboxes, bbox_nums, image_id, num_id_to_cat_id_map):
    """Convert stacked detections into COCO-style box result dicts.

    Args:
        bboxes: detections of all images stacked along axis 0, each row
            ``[num_id, score, xmin, ymin, xmax, ymax]``.
        bbox_nums: number of consecutive rows of ``bboxes`` per image.
        image_id: per-image ids; ``image_id[i][0]`` is used for image ``i``.
        num_id_to_cat_id_map: maps the row's ``num_id`` to a category id.

    Returns:
        List of dicts with keys ``'image_id'``, ``'category_id'``, ``'bbox'``
        (``[xmin, ymin, w, h]`` with inclusive-pixel ``w``/``h``, i.e.
        ``xmax - xmin + 1``) and ``'score'``.
    """
    det_res = []
    k = 0
    for i, det_nums in enumerate(bbox_nums):
        cur_image_id = int(image_id[i][0])
        for _ in range(det_nums):
            num_id, score, xmin, ymin, xmax, ymax = bboxes[k].tolist()
            k += 1
            category_id = num_id_to_cat_id_map[num_id]
            w = xmax - xmin + 1
            h = ymax - ymin + 1
            det_res.append({
                'image_id': cur_image_id,
                'category_id': category_id,
                'bbox': [xmin, ymin, w, h],
                'score': score
            })
    return det_res


# NOTE: not decorated with numba ``@jit``. Dicts and RLE objects cannot be
# compiled in nopython mode.
def get_seg_res(masks, mask_nums, image_id, num_id_to_cat_id_map):
    """Convert stacked mask results into COCO-style segmentation dicts.

    Args:
        masks: rows ``[segmentation, num_id, score]`` for all images stacked
            along axis 0; ``segmentation`` is an RLE dict.
        mask_nums: number of consecutive rows of ``masks`` per image.
        image_id: per-image ids; ``image_id[i][0]`` is used for image ``i``.
        num_id_to_cat_id_map: maps the row's ``num_id`` to a category id.

    Returns:
        List of dicts with keys ``'image_id'``, ``'category_id'``,
        ``'segmentation'`` and ``'score'``. A bytes ``'counts'`` entry is
        decoded to ``str`` (utf8) in a copy of the RLE. The caller's dicts are
        not modified, so calling this twice on the same ``masks`` is safe.
    """
    seg_res = []
    k = 0
    for i, det_nums in enumerate(mask_nums):
        cur_image_id = int(image_id[i][0])
        for _ in range(det_nums):
            sg, num_id, score = masks[k].tolist()
            k += 1
            cat_id = num_id_to_cat_id_map[num_id]
            if 'counts' in sg and isinstance(sg['counts'], bytes):
                # Decode on a copy. Mutating the input in place made a second
                # call fail, because str has no .decode().
                sg = dict(sg, counts=sg['counts'].decode("utf8"))
            seg_res.append({
                'image_id': cur_image_id,
                'category_id': cat_id,
                'segmentation': sg,
                'score': score
            })
    return seg_res