# post_process.py
# The code below was committed by qingqing01.
import six
import os
import numpy as np
from numba import jit
from .bbox import delta2bbox, clip_bbox, expand_bbox, nms
import pycocotools.mask as mask_util
import cv2


def bbox_post_process(bboxes,
                      bbox_prob,
                      bbox_deltas,
                      im_shape,
                      scale_factor,
                      keep_top_k=100,
                      score_thresh=0.05,
                      nms_thresh=0.5,
                      class_nums=81,
                      bbox_reg_weights=[0.1, 0.1, 0.2, 0.2],
                      with_background=True):
    """Decode, clip and per-class NMS the proposals of every image in a batch.

    Args:
        bboxes: Pair ``(bbox, bbox_num)``; ``bbox`` holds the proposals of
            all images stacked along axis 0 and ``bbox_num`` the number of
            proposals belonging to each image.
        bbox_prob: Per-proposal class scores, indexed ``[proposal, class]``.
        bbox_deltas: Per-proposal regression deltas (indexed with three
            axes); flattened to one row per proposal before decoding.
        im_shape: Per-image shape; the first two entries are divided by the
            image's scale factor to obtain the clipping bounds.
        scale_factor: Per-image factor the proposals are divided by.
        keep_top_k: Target number of detections kept per image over all
            classes (score ties at the cut-off are all kept).
        score_thresh: Scores must be strictly greater than this to be kept.
        nms_thresh: Threshold forwarded to ``nms``.
        class_nums: Number of classes (exclusive upper bound of class ids).
        bbox_reg_weights: Weights forwarded to ``delta2bbox``.
        with_background: First class id that is processed; a truthy value
            skips class 0.

    Returns:
        Tuple ``(new_bbox, new_bbox_num)``: a float32 array whose rows are
        ``[label, score, <4 box coordinates>]`` for all images stacked, and
        an int32 array with the number of rows belonging to each image.
    """
    rois, rois_per_image = bboxes
    fg_classes = range(with_background, class_nums)
    dets_per_image = []
    det_counts = []
    offset = 0
    for img_idx, roi_count in enumerate(rois_per_image):
        stop = offset + roi_count
        window = slice(offset, stop)
        offset = stop
        scale = scale_factor[img_idx]

        # Undo the input scaling, then decode the deltas into boxes.
        scaled_rois = rois[window, :] / scale
        flat_deltas = np.reshape(bbox_deltas[window, :, :], (roi_count, -1))
        decoded = delta2bbox(flat_deltas, scaled_rois, bbox_reg_weights)

        # Clip against the (rescaled) image extent.
        decoded = clip_bbox(decoded, im_shape[img_idx][:2] / scale)

        # Per-class score filtering followed by NMS.
        probs = bbox_prob[window, :]
        class_dets = []
        for cls in fg_classes:
            picked = np.where(probs[:, cls] > score_thresh)[0]
            cls_scores = probs[picked, cls]
            cls_rois = decoded[picked, cls * 4:(cls + 1) * 4]
            scored = np.hstack((cls_scores[:, np.newaxis], cls_rois)).astype(
                np.float32, copy=False)
            kept = nms(scored, nms_thresh)
            survivors = scored[kept, :]
            # Prepend the class id as the first column.
            cls_column = np.array([cls] * len(kept))
            class_dets.append(
                np.hstack((cls_column[:, np.newaxis], survivors)).astype(
                    np.float32, copy=False))

        # Limit the number of detections over all classes of this image.
        all_scores = np.hstack([dets[:, 1] for dets in class_dets])
        if len(all_scores) > keep_top_k:
            cutoff = np.sort(all_scores)[-keep_top_k]
            class_dets = [
                dets[np.where(dets[:, 1] >= cutoff)[0], :]
                for dets in class_dets
            ]

        merged = np.vstack(class_dets)
        dets_per_image.append(merged)
        det_counts.append(len(merged))

    new_bbox = np.vstack(dets_per_image)
    new_bbox_num = np.array(det_counts).astype('int32')
    return new_bbox, new_bbox_num


def mask_post_process(det_res,
                      im_shape,
                      scale_factor,
                      resolution=14,
                      binary_thresh=0.5):
    """Paste per-detection mask predictions into full-image RLE masks.

    The function is intentionally not wrapped in numba's ``@jit``: the body
    relies on dict access, OpenCV and pycocotools, none of which numba can
    compile in nopython mode.

    Args:
        det_res: Dict with keys ``'bbox'`` (one row per detection; column 0
            is read as the class label and columns 2 onward as the box),
            ``'bbox_num'`` (detections per image) and ``'mask'`` (indexed
            as ``[detection, class, y, x]``).
        im_shape: Per-image shape; entries 0 and 1 are used as height and
            width.
        scale_factor: Indexed as ``scale_factor[i, 0]``; both the height and
            the width of image ``i`` are divided by this single value.
        resolution: Side length of the predicted mask.
        binary_thresh: Resized mask values strictly above this become 1.

    Returns:
        ``det_res['mask']`` unchanged when it holds no detections.
        Otherwise an array with one row per detection and three columns:
        the RLE dict produced by ``pycocotools.mask.encode`` followed by
        the first two columns of ``det_res['bbox']``.
    """
    bbox = det_res['bbox']
    bbox_num = det_res['bbox_num']
    masks = det_res['mask']
    if masks.shape[0] == 0:
        return masks

    M = resolution
    # The mask gets a one pixel zero border on every side, so the box is
    # grown by the matching ratio before the mask is resized onto it.
    scale = (M + 2.0) / M
    boxes = bbox[:, 2:]
    labels = bbox[:, 0]

    # Scratch buffer: the interior is overwritten for every detection and
    # the border is never written, so it can be shared by all of them.
    padded_mask = np.zeros((M + 2, M + 2), dtype=np.float32)

    segms_results = []
    st_num = 0
    for i, length in enumerate(bbox_num):
        end_num = st_num + length
        labels_n = labels[st_num:end_num]
        masks_n = masks[st_num:end_num]
        boxes_n = expand_bbox(boxes[st_num:end_num], scale).astype(np.int32)
        st_num = end_num

        im_h = int(round(im_shape[i][0] / scale_factor[i, 0]))
        im_w = int(round(im_shape[i][1] / scale_factor[i, 0]))

        cls_segms = []
        for j in range(len(boxes_n)):
            class_id = int(labels_n[j])
            padded_mask[1:-1, 1:-1] = masks_n[j, class_id, :, :]

            ref_box = boxes_n[j, :]
            # Plain ints of at least 1 so cv2.resize always gets a valid
            # target size.
            w = max(int(ref_box[2] - ref_box[0] + 1), 1)
            h = max(int(ref_box[3] - ref_box[1] + 1), 1)

            mask = cv2.resize(padded_mask, (w, h))
            mask = np.array(mask > binary_thresh, dtype=np.uint8)
            im_mask = np.zeros((im_h, im_w), dtype=np.uint8)

            # Intersect the box with the image and copy only that part.
            x_0 = max(ref_box[0], 0)
            x_1 = min(ref_box[2] + 1, im_w)
            y_0 = max(ref_box[1], 0)
            y_1 = min(ref_box[3] + 1, im_h)
            im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - ref_box[1]):(
                y_1 - ref_box[1]), (x_0 - ref_box[0]):(x_1 - ref_box[0])]

            rle = mask_util.encode(
                np.array(
                    im_mask[:, :, np.newaxis], order='F'))[0]
            cls_segms.append(rle)
        segms_results.append(np.array(cls_segms)[:, np.newaxis])

    segms_results = np.vstack(segms_results)
    bboxes = np.hstack([segms_results, bbox])
    return bboxes[:, :3]


def get_det_res(bboxes, bbox_nums, image_id, num_id_to_cat_id_map):
    """Convert stacked detections into a list of COCO-style result dicts.

    The function is intentionally not wrapped in numba's ``@jit``: it builds
    a list of heterogeneous dicts and indexes a Python mapping, which numba
    cannot compile in nopython mode.

    Args:
        bboxes: Detections of all images stacked along axis 0; every row
            unpacks to ``[num_id, score, xmin, ymin, xmax, ymax]`` and must
            provide ``tolist()``.
        bbox_nums: Number of rows of ``bboxes`` belonging to each image.
        image_id: Per-image id, read as ``image_id[i][0]``.
        num_id_to_cat_id_map: Lookup from class id to category id; it is
            indexed with an ``int``, so both dicts and sequences work.

    Returns:
        List of dicts with keys ``'image_id'``, ``'category_id'``,
        ``'bbox'`` (``[xmin, ymin, w, h]``) and ``'score'``. Rows whose
        class id is negative are skipped.
    """
    det_res = []
    k = 0
    for i, det_nums in enumerate(bbox_nums):
        cur_image_id = int(image_id[i][0])
        for _ in range(det_nums):
            num_id, score, xmin, ymin, xmax, ymax = bboxes[k].tolist()
            k += 1
            if num_id < 0:
                continue
            # The class id arrives as a float out of the detection array.
            category_id = num_id_to_cat_id_map[int(num_id)]
            # Corner coordinates are inclusive, hence the +1.
            w = xmax - xmin + 1
            h = ymax - ymin + 1
            det_res.append({
                'image_id': cur_image_id,
                'category_id': category_id,
                'bbox': [xmin, ymin, w, h],
                'score': score
            })
    return det_res


def get_seg_res(masks, mask_nums, image_id, num_id_to_cat_id_map):
    """Convert stacked mask results into a list of COCO-style result dicts.

    The function is intentionally not wrapped in numba's ``@jit``: it builds
    a list of heterogeneous dicts and indexes a Python mapping, which numba
    cannot compile in nopython mode.

    Args:
        masks: Mask results of all images stacked along axis 0; every row
            unpacks to ``[segmentation, num_id, score]`` and must provide
            ``tolist()``. ``segmentation`` is a dict.
        mask_nums: Number of rows of ``masks`` belonging to each image.
        image_id: Per-image id, read as ``image_id[i][0]``.
        num_id_to_cat_id_map: Lookup from class id to category id; it is
            indexed with an ``int``, so both dicts and sequences work.

    Returns:
        List of dicts with keys ``'image_id'``, ``'category_id'``,
        ``'segmentation'`` and ``'score'``. A bytes ``'counts'`` entry of a
        segmentation dict is decoded to text in place.
    """
    seg_res = []
    k = 0
    for i, det_nums in enumerate(mask_nums):
        cur_image_id = int(image_id[i][0])
        for _ in range(det_nums):
            sg, num_id, score = masks[k].tolist()
            k += 1
            # The class id arrives as a float out of the result array.
            cat_id = num_id_to_cat_id_map[int(num_id)]
            if 'counts' in sg:
                counts = sg['counts']
                # Decode only real byte strings: on Python 2 ``bytes`` is
                # ``str`` and stays untouched, and a value that was already
                # decoded by an earlier call is left alone.
                if isinstance(counts, bytes) and not isinstance(counts, str):
                    sg['counts'] = counts.decode("utf8")
            seg_res.append({
                'image_id': cur_image_id,
                'category_id': cat_id,
                'segmentation': sg,
                'score': score
            })
    return seg_res
# The code below was committed by Guanghua Yu.


def get_solov2_segm_res(results, image_id, num_id_to_cat_id_map):
    """Convert SOLOv2 outputs of one image into COCO-style result dicts.

    Args:
        results: Dict with keys ``'segm'`` (masks, one per entry along
            axis 0, or ``None``), ``'cate_label'`` and ``'cate_score'``.
        image_id: Image id container, read as ``image_id[0][0]``.
        num_id_to_cat_id_map: Lookup from class id to category id; it is
            indexed with ``int(cate_label) + 1``.

    Returns:
        ``None`` when ``results['segm']`` is ``None`` or empty. Otherwise a
        list of dicts with keys ``'image_id'``, ``'category_id'``,
        ``'segmentation'`` (RLE with text ``'counts'``) and ``'score'``.
    """
    clsid_labels = results['cate_label']
    clsid_scores = results['cate_score']

    # Test for a missing array before calling any method on it.
    raw_segms = results['segm']
    if raw_segms is None:
        return None
    segms = raw_segms.astype(np.uint8)
    lengths = segms.shape[0]
    if lengths == 0:
        return None

    # Imported only once there is something to encode.
    import pycocotools.mask as mask_util

    im_id = int(image_id[0][0])
    segm_res = []
    # NOTE(review): the loop stops one short of ``lengths``, so the last
    # entry is never exported. Kept as found; confirm whether the model
    # output carries a trailing placeholder or this is an off-by-one.
    for i in range(lengths - 1):
        clsid = int(clsid_labels[i]) + 1
        catid = num_id_to_cat_id_map[clsid]
        score = float(clsid_scores[i])
        mask = segms[i]
        segm = mask_util.encode(np.array(mask[:, :, np.newaxis], order='F'))[0]
        segm['counts'] = segm['counts'].decode('utf8')
        segm_res.append({
            'image_id': im_id,
            'category_id': catid,
            'segmentation': segm,
            'score': score
        })
    return segm_res