"""Helpers for collecting inference results and evaluating them with cocoapi."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import json
import logging

from ppdet.py_op.post_process import get_det_res, get_seg_res, mask_post_process

# Module-level logger shared by the evaluation helpers below.
logger = logging.getLogger(__name__)


def json_eval_results(metric, json_directory=None, dataset=None):
    """
    Run cocoapi evaluation on already existing proposal.json, bbox.json
    and mask.json result files.

    Each result file that exists is evaluated with its matching cocoapi
    style ('proposal', 'bbox' or 'segm'); a missing file is only logged.

    Args:
        metric (str): evaluation metric, only 'COCO' is accepted.
        json_directory (str|None): directory that holds the result files.
            When empty or None the bare file names are used, i.e. they are
            resolved against the current working directory.
        dataset: object whose ``get_anno()`` result is forwarded to
            ``cocoapi_eval`` as ``anno_file``. Required despite the ``None``
            default kept for signature compatibility.

    Raises:
        AssertionError: if ``metric`` is not 'COCO', or ``json_directory``
            is given but does not exist.
        ValueError: if ``dataset`` is None.
    """
    assert metric == 'COCO', "Only COCO metric is supported now."
    if dataset is None:
        # Fail early with a clear message instead of an AttributeError on
        # ``None.get_anno()`` further down.
        raise ValueError(
            "dataset is required to locate the annotation file, got None.")

    # Imported lazily so the module can be loaded without the COCO tooling.
    from ppdet.utils.coco_eval import cocoapi_eval

    anno_file = dataset.get_anno()

    result_files = ['proposal.json', 'bbox.json', 'mask.json']
    # Evaluation style for each entry of ``result_files`` (same order).
    eval_styles = ['proposal', 'bbox', 'segm']

    if json_directory:
        assert os.path.exists(
            json_directory), "The json directory:{} does not exist".format(
                json_directory)
        result_files = [
            os.path.join(str(json_directory), name) for name in result_files
        ]

    for json_file, style in zip(result_files, eval_styles):
        if os.path.exists(json_file):
            cocoapi_eval(json_file, style, anno_file=anno_file)
        else:
            logger.info("{} not exists!".format(json_file))


def get_infer_results(outs_res, eval_type, catid, im_info,
                      mask_resolution=None):
    """
    Get result at the stage of inference.
    The output format is dictionary containing bbox or mask result.

    For example, bbox result is a list and each element contains
    image_id, category_id, bbox and score.

    Args:
        outs_res: sequence of per-batch outputs. Each item is indexed with
            the keys 'bbox' and 'bbox_num' (and 'mask' when masks are
            requested), and every value must provide ``.numpy()``.
        eval_type: container of requested result kinds; 'bbox' and 'mask'
            are recognised.
        catid: category id mapping forwarded to ``get_det_res`` and
            ``get_seg_res``.
        im_info: sequence aligned with ``outs_res``; item ``i`` is read as
            ``(im_shape, scale_factor, im_ids)``.
        mask_resolution: forwarded unchanged to ``mask_post_process``.

    Returns:
        dict: contains the key 'bbox' and/or 'mask' depending on
        ``eval_type``, each holding the accumulated result list.

    Raises:
        ValueError: if ``outs_res`` is None or empty.
    """
    if outs_res is None or len(outs_res) == 0:
        raise ValueError(
            'The number of valid detection result is zero. Please use '
            'reasonable model and check input data.')

    infer_res = {}

    if 'bbox' in eval_type:
        box_res = []
        for i, outs in enumerate(outs_res):
            im_ids = im_info[i][2]
            box_res += get_det_res(outs['bbox'].numpy(),
                                   outs['bbox_num'].numpy(), im_ids, catid)
        infer_res['bbox'] = box_res

    if 'mask' in eval_type:
        seg_res = []
        # mask post process
        for i, outs in enumerate(outs_res):
            im_shape, scale_factor, im_ids = (im_info[i][0], im_info[i][1],
                                              im_info[i][2])
            # Only the first entry of scale_factor is passed on.
            mask = mask_post_process(outs['bbox'].numpy(),
                                     outs['bbox_num'].numpy(),
                                     outs['mask'].numpy(), im_shape,
                                     scale_factor[0], mask_resolution)
            seg_res += get_seg_res(mask, outs['bbox_num'].numpy(), im_ids,
                                   catid)
        infer_res['mask'] = seg_res

    return infer_res


def eval_results(res, metric, anno_file):
    """
    Evaluate the inference result.

    For every supported key present in ``res`` the result list is dumped to
    a JSON file in the current working directory ('bbox.json' or
    'mask.json') and then evaluated with ``cocoapi_eval``.

    Args:
        res (dict): inference results; the keys 'bbox' and 'mask' are used.
        metric (str): evaluation metric, only 'COCO' is supported.
        anno_file: annotation file forwarded to ``cocoapi_eval``.

    Returns:
        list: the values returned by ``cocoapi_eval``, bbox first and then
        mask, for the keys present in ``res``.

    Raises:
        NotImplementedError: if ``metric`` is not 'COCO'.
    """
    if metric != 'COCO':
        # NotImplemented is a non-callable constant; the exception class is
        # NotImplementedError.
        raise NotImplementedError("Only COCO metric is supported now.")

    # Imported lazily so the module can be loaded without the COCO tooling.
    from ppdet.utils.coco_eval import cocoapi_eval

    # (key in ``res``, output json file, cocoapi evaluation style)
    targets = (
        ('bbox', 'bbox.json', 'bbox'),
        ('mask', 'mask.json', 'segm'),
    )

    eval_res = []
    for res_key, json_file, style in targets:
        if res_key not in res:
            continue
        with open(json_file, 'w') as f:
            json.dump(res[res_key], f)
        # Logged once the file is closed, so the message is true on disk.
        logger.info('The {} result is saved to {}.'.format(res_key,
                                                           json_file))

        stats = cocoapi_eval(json_file, style, anno_file=anno_file)
        eval_res.append(stats)
        sys.stdout.flush()

    return eval_res