post_process.py 2.9 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6
import six
import os
import numpy as np
import cv2


W
wangxinxin08 已提交
7 8 9 10 11 12 13
def get_det_res(bboxes,
                scores,
                labels,
                bbox_nums,
                image_id,
                label_to_cat_id_map,
                bias=0):
    """Convert flat, batched detection outputs into COCO-style result dicts.

    Detections for every image are stored back to back in ``bboxes``,
    ``scores`` and ``labels``; ``bbox_nums[i]`` gives how many of them belong
    to the i-th image, so a single running cursor walks all three arrays.

    Args:
        bboxes: per-detection boxes; each row is unpacked via ``.tolist()``
            as ``(xmin, ymin, xmax, ymax)``.
        scores: per-detection confidence scores.
        labels: per-detection class indices; entries < 0 are skipped.
        bbox_nums: number of detections belonging to each image.
        image_id: per-image ids; ``image_id[i][0]`` is the id of image i.
        label_to_cat_id_map: maps a class index to a COCO category id.
        bias: constant added to both the computed width and height.

    Returns:
        list of dicts with keys ``image_id``, ``category_id``, ``bbox``
        (``[xmin, ymin, w, h]``) and ``score``.
    """
    det_res = []
    k = 0  # cursor into the flat bboxes/scores/labels arrays
    for i, det_nums in enumerate(bbox_nums):
        cur_image_id = int(image_id[i][0])
        for _ in range(det_nums):
            box = bboxes[k]
            score = float(scores[k])
            label = int(labels[k])
            # Advance the cursor BEFORE any early `continue`; otherwise a
            # skipped entry would be re-read on every remaining iteration and
            # all following detections would be lost or misattributed.
            k += 1
            if label < 0:
                continue
            xmin, ymin, xmax, ymax = box.tolist()
            category_id = label_to_cat_id_map[label]
            w = xmax - xmin + bias
            h = ymax - ymin + bias
            bbox = [xmin, ymin, w, h]
            dt_res = {
                'image_id': cur_image_id,
                'category_id': category_id,
                'bbox': bbox,
                'score': score
            }
            det_res.append(dt_res)
    return det_res


40 41 42
def get_seg_res(masks, scores, labels, mask_nums, image_id,
                label_to_cat_id_map):
    """Convert flat, batched mask outputs into COCO-style segmentation dicts.

    Masks for every image are stored back to back in ``masks``, ``scores``
    and ``labels``; ``mask_nums[i]`` gives how many of them belong to the
    i-th image, so a single running cursor walks all three arrays.

    Args:
        masks: per-detection masks, indexed as ``mask[:, :, None]`` (so each
            is expected to be 2-D).
        scores: per-detection confidence scores.
        labels: per-detection class indices; entries < 0 are skipped, the
            same rule ``get_det_res`` applies.
        mask_nums: number of masks belonging to each image.
        image_id: per-image ids; ``image_id[i][0]`` is the id of image i.
        label_to_cat_id_map: maps a class index to a COCO category id.

    Returns:
        list of dicts with keys ``image_id``, ``category_id``,
        ``segmentation`` (RLE produced by ``pycocotools.mask.encode``) and
        ``score``.
    """
    import pycocotools.mask as mask_util
    seg_res = []
    k = 0  # cursor into the flat masks/scores/labels arrays
    for i, det_nums in enumerate(mask_nums):
        cur_image_id = int(image_id[i][0])
        for _ in range(det_nums):
            mask = masks[k]
            score = float(scores[k])
            label = int(labels[k])
            k += 1
            # Negative labels are treated as invalid entries; looking them up
            # would raise (dict map) or silently pick the last category
            # (list map), so skip them.
            if label < 0:
                continue
            cat_id = label_to_cat_id_map[label]
            rle = mask_util.encode(
                np.array(
                    mask[:, :, None], order="F", dtype="uint8"))[0]
            # `counts` comes back as bytes; decode it on Python 3, presumably
            # so the results can be JSON-serialized.
            if six.PY3 and 'counts' in rle:
                rle['counts'] = rle['counts'].decode("utf8")
            sg_res = {
                'image_id': cur_image_id,
                'category_id': cat_id,
                'segmentation': rle,
                'score': score
            }
            seg_res.append(sg_res)
    return seg_res
G
Guanghua Yu 已提交
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96


def get_solov2_segm_res(results, image_id, num_id_to_cat_id_map):
    """Convert SOLOv2 outputs for one image into COCO segmentation dicts.

    Args:
        results: dict with ``'segm'`` (masks stacked along the first axis),
            ``'cate_label'`` and ``'cate_score'`` (one entry per mask).
        image_id: ``image_id[0][0]`` is the id of the image.
        num_id_to_cat_id_map: maps ``int(cate_label) + 1`` to a COCO
            category id.

    Returns:
        list of dicts with keys ``image_id``, ``category_id``,
        ``segmentation`` (RLE with ``counts`` decoded to str) and ``score``,
        or ``None`` when ``results['segm']`` is None or holds no masks.
    """
    raw_segms = results['segm']
    # Check for None before the cast: calling .astype() on None would raise
    # AttributeError instead of returning None.
    if raw_segms is None:
        return None
    segms = raw_segms.astype(np.uint8)
    lengths = segms.shape[0]
    if lengths == 0:
        return None

    # Imported only once there is actual work to do.
    import pycocotools.mask as mask_util
    clsid_labels = results['cate_label']
    clsid_scores = results['cate_score']
    im_id = int(image_id[0][0])
    segm_res = []
    # NOTE(review): the last entry is never emitted (range(lengths - 1));
    # presumably it is a placeholder appended upstream — confirm before
    # changing.
    for i in range(lengths - 1):
        clsid = int(clsid_labels[i]) + 1
        catid = num_id_to_cat_id_map[clsid]
        score = float(clsid_scores[i])
        mask = segms[i]
        segm = mask_util.encode(np.array(mask[:, :, np.newaxis], order='F'))[0]
        segm['counts'] = segm['counts'].decode('utf8')
        coco_res = {
            'image_id': im_id,
            'category_id': catid,
            'segmentation': segm,
            'score': score
        }
        segm_res.append(coco_res)
    return segm_res