post_process.py 2.8 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6
import six
import os
import numpy as np
import cv2


7 8
def get_det_res(bboxes, scores, labels, bbox_nums, image_id,
                label_to_cat_id_map):
    """Convert flat, batched detections into COCO-style bbox results.

    Detections for all images are stored back to back in ``bboxes``,
    ``scores`` and ``labels``; ``bbox_nums[i]`` is the number of those
    entries that belong to the i-th image.

    Args:
        bboxes: Indexable of boxes; each item's ``tolist()`` yields
            ``[xmin, ymin, xmax, ymax]``.
        scores: Per-detection scores (converted with ``float``).
        labels: Per-detection class labels (converted with ``int``);
            entries with a negative label are skipped.
        bbox_nums: Number of detections for each image.
        image_id: Per-image ids, read as ``image_id[i][0]``.
        label_to_cat_id_map (dict): Maps a label to a COCO category id.

    Returns:
        list[dict]: One dict per kept detection with keys 'image_id',
        'category_id', 'bbox' (``[xmin, ymin, width, height]``) and 'score'.
    """
    det_res = []
    k = 0  # cursor into the flat bboxes/scores/labels arrays
    for i, det_nums in enumerate(bbox_nums):
        cur_image_id = int(image_id[i][0])
        for _ in range(det_nums):
            box = bboxes[k]
            score = float(scores[k])
            label = int(labels[k])
            # Advance the cursor before any skip: if ``continue`` ran first,
            # a negative label would stall ``k`` on the same entry and all
            # following detections (this image and later ones) would be
            # misread.
            k += 1
            if label < 0:
                continue
            xmin, ymin, xmax, ymax = box.tolist()
            det_res.append({
                'image_id': cur_image_id,
                'category_id': label_to_cat_id_map[label],
                'bbox': [xmin, ymin, xmax - xmin, ymax - ymin],
                'score': score
            })
    return det_res


35 36 37
def get_seg_res(masks, scores, labels, mask_nums, image_id,
                label_to_cat_id_map):
    """Convert flat, batched instance masks into COCO-style segmentation results.

    Masks for all images are stored back to back in ``masks``, ``scores``
    and ``labels``; ``mask_nums[i]`` is how many of them belong to the i-th
    image.

    Args:
        masks: Indexable of 2-D masks; each is expanded with ``[:, :, None]``
            and cast to a Fortran-ordered uint8 array before RLE encoding.
        scores: Per-instance scores (converted with ``float``).
        labels: Per-instance class labels (converted with ``int``).
        mask_nums: Number of masks for each image.
        image_id: Per-image ids, read as ``image_id[i][0]``.
        label_to_cat_id_map (dict): Maps a label to a COCO category id.

    Returns:
        list[dict]: One dict per mask with keys 'image_id', 'category_id',
        'segmentation' (pycocotools RLE) and 'score'.
    """
    import pycocotools.mask as mask_util

    results = []
    cursor = 0
    for img_idx, count in enumerate(mask_nums):
        img_id = int(image_id[img_idx][0])
        for _ in range(count):
            mask = masks[cursor]
            score = float(scores[cursor])
            label = int(labels[cursor])
            cursor += 1

            category = label_to_cat_id_map[label]
            fortran_mask = np.array(mask[:, :, None], order="F", dtype="uint8")
            rle = mask_util.encode(fortran_mask)[0]
            # On Python 3 the RLE counts are decoded to str so the
            # result can be serialized as text.
            if six.PY3 and 'counts' in rle:
                rle['counts'] = rle['counts'].decode("utf8")

            results.append({
                'image_id': img_id,
                'category_id': category,
                'segmentation': rle,
                'score': score
            })
    return results
G
Guanghua Yu 已提交
63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91


def get_solov2_segm_res(results, image_id, num_id_to_cat_id_map):
    """Convert one image's SOLOv2 output into COCO-style segmentation results.

    Args:
        results (dict): Provides 'segm' (per-instance masks stacked along
            axis 0, or None), 'cate_label' and 'cate_score' (indexed per
            instance).
        image_id: Only ``image_id[0][0]`` is used.
        num_id_to_cat_id_map (dict): Maps ``cate_label + 1`` to a COCO
            category id.

    Returns:
        list[dict] | None: One dict per emitted instance with keys
        'image_id', 'category_id', 'segmentation' (RLE with str counts) and
        'score'; None when 'segm' is None or holds no instances.
    """
    segms = results['segm']
    # Check for None before calling ``.astype``/``.shape`` on it; doing it
    # afterwards (as before) meant the None case raised instead of returning.
    if segms is None:
        return None
    segms = segms.astype(np.uint8)
    lengths = segms.shape[0]
    if lengths == 0:
        return None

    # Imported after the early returns so empty results never need it.
    import pycocotools.mask as mask_util

    clsid_labels = results['cate_label']
    clsid_scores = results['cate_score']
    im_id = int(image_id[0][0])
    segm_res = []
    # NOTE(review): the last entry (index lengths - 1) is never emitted.
    # Presumably it is a placeholder instance appended upstream to avoid
    # empty tensors -- confirm against the model post-processing before
    # changing this bound.
    for i in range(lengths - 1):
        # Labels are 0-based here, while the category map is keyed 1-based.
        clsid = int(clsid_labels[i]) + 1
        catid = num_id_to_cat_id_map[clsid]
        score = float(clsid_scores[i])
        mask = segms[i]
        segm = mask_util.encode(np.array(mask[:, :, np.newaxis], order='F'))[0]
        segm['counts'] = segm['counts'].decode('utf8')
        segm_res.append({
            'image_id': im_id,
            'category_id': catid,
            'segmentation': segm,
            'score': score
        })
    return segm_res