eval_utils.py 2.7 KB
Newer Older
1 2 3 4
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

5 6
import os

7

8
def json_eval_results(metric, json_directory=None, dataset=None):
    """
    Run COCO API evaluation on result files that already exist on disk.

    Looks for ``proposal.json``, ``bbox.json`` and ``mask.json`` (joined onto
    ``json_directory`` when it is given, otherwise used as bare relative
    paths) and calls ``cocoapi_eval`` on each file that exists, using the
    matching evaluation style (``proposal``, ``bbox``, ``segm``). Files that
    are missing are reported through the logger and skipped.

    Args:
        metric: name of the evaluation metric; only ``'COCO'`` is accepted.
        json_directory: optional directory that holds the result files.
        dataset: object whose ``get_anno()`` result is forwarded to
            ``cocoapi_eval`` as ``anno_file``. It is required in practice
            despite the ``None`` default, because ``get_anno()`` is called
            unconditionally.

    Raises:
        AssertionError: if ``metric`` is not ``'COCO'``, or if
            ``json_directory`` is given but does not exist.
    """
    assert metric == 'COCO'
    import logging
    from ppdet.utils.coco_eval import cocoapi_eval

    # The visible top of this file defines no `logger`, so the "file missing"
    # branch below used to raise NameError. Obtain the logger locally.
    logger = logging.getLogger(__name__)

    anno_file = dataset.get_anno()

    # Result file names and the cocoapi evaluation style each one maps to;
    # the two tuples are index-aligned.
    json_names = ('proposal.json', 'bbox.json', 'mask.json')
    eval_styles = ('proposal', 'bbox', 'segm')

    if json_directory:
        assert os.path.exists(
            json_directory), "The json directory:{} does not exist".format(
                json_directory)
        json_files = [
            os.path.join(str(json_directory), name) for name in json_names
        ]
    else:
        json_files = list(json_names)

    for json_file, style in zip(json_files, eval_styles):
        if os.path.exists(json_file):
            cocoapi_eval(json_file, style, anno_file=anno_file)
        else:
            logger.info("{} not exists!".format(json_file))
F
FDInSky 已提交
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43


def coco_eval_results(outs_res=None,
                      include_mask=False,
                      batch_size=1,
                      dataset=None):
    print("start evaluate bbox using coco api")
    import io
    import six
    import json
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval
    from ppdet.py_op.post_process import get_det_res, get_seg_res
    anno_file = os.path.join(dataset.dataset_dir, dataset.anno_path)
    cocoGt = COCO(anno_file)
44 45 46 47
    catid = {
        i + dataset.with_background: v
        for i, v in enumerate(cocoGt.getCatIds())
    }
F
FDInSky 已提交
48 49 50 51 52

    if outs_res is not None and len(outs_res) > 0:
        det_res = []
        for outs in outs_res:
            det_res += get_det_res(outs['bbox_nums'], outs['bbox'],
53
                                   outs['im_id'], catid, batch_size)
F
FDInSky 已提交
54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78

        with io.open("bbox_eval.json", 'w') as outfile:
            encode_func = unicode if six.PY2 else str
            outfile.write(encode_func(json.dumps(det_res)))

        cocoDt = cocoGt.loadRes("bbox_eval.json")
        cocoEval = COCOeval(cocoGt, cocoDt, 'bbox')
        cocoEval.evaluate()
        cocoEval.accumulate()
        cocoEval.summarize()

    if outs_res is not None and len(outs_res) > 0 and include_mask:
        seg_res = []
        for outs in outs_res:
            seg_res += get_seg_res(outs['bbox_nums'], outs['mask'],
                                   outs['im_id'], catid, batch_size)

        with io.open("mask_eval.json", 'w') as outfile:
            encode_func = unicode if six.PY2 else str
            outfile.write(encode_func(json.dumps(seg_res)))

        cocoSg = cocoGt.loadRes("mask_eval.json")
        cocoEval = COCOeval(cocoGt, cocoSg, 'bbox')
        cocoEval.evaluate()
        cocoEval.accumulate()