eval_coco_map.py 4.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
J
jerrywgz 已提交
18 19 20
import os
import time
import numpy as np
J
jerrywgz 已提交
21
from eval_helper import *
J
jerrywgz 已提交
22 23 24
import paddle
import paddle.fluid as fluid
import reader
J
jerrywgz 已提交
25
from utility import print_arguments, parse_args
J
jerrywgz 已提交
26 27 28 29 30
import models.model_builder as model_builder
import models.resnet as resnet
import json
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval, Params
J
jerrywgz 已提交
31
from config import cfg
J
jerrywgz 已提交
32 33


J
jerrywgz 已提交
34
def eval():
    """Run the detector over the COCO validation split and report COCO metrics.

    Builds the Faster R-CNN test graph (plus the mask branch when
    ``cfg.MASK_ON`` is set), optionally loads weights from
    ``cfg.pretrained_model``, and runs inference batch by batch.
    Detections are written to ``detection_bbox_result.json``. When
    ``cfg.MASK_ON`` is set, segmentations are also written to
    ``detection_segms_result.json``. Both files go to the current working
    directory and are scored with ``pycocotools``' ``COCOeval``.

    Raises:
        ValueError: if ``cfg.dataset`` names neither the 2014 nor the 2017
            COCO release.
    """
    test_list = _val_annotation_file(cfg.dataset)

    image_shape = [3, cfg.TEST.max_size, cfg.TEST.max_size]
    # One device per entry in CUDA_VISIBLE_DEVICES; an unset or empty value
    # still counts as a single device ("".split(",") == [""]).
    devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
    devices_num = len(devices.split(","))
    # NOTE(review): evaluation reuses the training per-device batch size
    # (cfg.TRAIN.im_per_batch) while running on a single place below —
    # confirm this is intended.
    total_batch_size = devices_num * cfg.TRAIN.im_per_batch

    cocoGt = COCO(os.path.join(cfg.data_dir, test_list))
    # Model class ids are contiguous and start at 1 (0 is background); map
    # them back to the COCO category ids in the order the API returns them.
    num_id_to_cat_id_map = {
        i + 1: cat_id for i, cat_id in enumerate(cocoGt.getCatIds())
    }

    model = model_builder.FasterRCNN(
        add_conv_body_func=resnet.add_ResNet50_conv4_body,
        add_roi_box_head_func=resnet.add_ResNet_roi_conv5_head,
        use_pyreader=False,
        is_train=False)
    model.build_model(image_shape)
    rpn_rois, confs, locs = model.eval_bbox_out()
    pred_boxes = model.eval()
    # pred_boxes is needed for bbox evaluation in every configuration, so it
    # is always fetched; the mask output is appended only when enabled.
    fetch_list = [rpn_rois, confs, locs, pred_boxes]
    if cfg.MASK_ON:
        fetch_list.append(model.eval_mask_out())

    place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    # yapf: disable
    if cfg.pretrained_model:
        # Load only the variables that have a saved file in the directory.
        def if_exist(var):
            return os.path.exists(os.path.join(cfg.pretrained_model, var.name))
        fluid.io.load_vars(exe, cfg.pretrained_model, predicate=if_exist)
    # yapf: enable

    test_reader = reader.test(total_batch_size)
    feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())
    fetch_names = [v.name for v in fetch_list]

    dts_res = []
    segms_res = []
    for batch_id, batch_data in enumerate(test_reader()):
        start = time.time()
        # Field 1 of each sample is its image info; it is only consumed by
        # segm_results below.
        im_info = [data[1] for data in batch_data]
        # return_numpy=False keeps the results as LoD tensors so the
        # per-image offsets of the NMS output can be read via .lod().
        result = exe.run(fetch_list=fetch_names,
                         feed=feeder.feed(batch_data),
                         return_numpy=False)

        nmsed_out = result[3]
        new_lod = nmsed_out.lod()
        dts_res += get_dt_res(total_batch_size, new_lod[0], nmsed_out,
                              batch_data, num_id_to_cat_id_map)

        if cfg.MASK_ON:
            masks_v = result[4]
            # NOTE(review): a (1, 1) mask tensor presumably marks a batch with
            # no detections — confirm against the model's mask output.
            if np.array(masks_v).shape != (1, 1):
                segms_out = segm_results(nmsed_out, masks_v, im_info)
                segms_res += get_segms_res(total_batch_size, new_lod[0],
                                           segms_out, batch_data,
                                           num_id_to_cat_id_map)
        end = time.time()
        print('batch id: {}, time: {}'.format(batch_id, end - start))

    _coco_evaluate(cocoGt, dts_res, "detection_bbox_result.json", 'bbox',
                   'bbox')
    if cfg.MASK_ON:
        _coco_evaluate(cocoGt, segms_res, "detection_segms_result.json",
                       'segm', 'mask')


def _val_annotation_file(dataset):
    """Return the COCO val annotation path (relative to data_dir) for *dataset*.

    The 2014 check runs first, so a name containing both years resolves
    to 2014.

    Raises:
        ValueError: if *dataset* contains neither '2014' nor '2017'.
    """
    if '2014' in dataset:
        return 'annotations/instances_val2014.json'
    if '2017' in dataset:
        return 'annotations/instances_val2017.json'
    raise ValueError(
        "cannot infer COCO release from dataset {!r}; expected it to contain "
        "'2014' or '2017'".format(dataset))


def _coco_evaluate(coco_gt, results, result_file, iou_type, label):
    """Write *results* to *result_file* as JSON and score them with COCOeval.

    Args:
        coco_gt: ground-truth ``COCO`` object.
        results: list of result dicts built by the eval_helper functions.
        result_file: path the JSON is written to and read back from.
        iou_type: ``'bbox'`` or ``'segm'``, forwarded to ``COCOeval``.
        label: name shown in the progress message.
    """
    with open(result_file, 'w') as outfile:
        json.dump(results, outfile)
    if not results:
        # COCO.loadRes indexes the first result and raises on an empty list.
        print("no {} results to evaluate, skipping coco api".format(label))
        return
    print("start evaluate {} using coco api".format(label))
    coco_dt = coco_gt.loadRes(result_file)
    coco_eval = COCOeval(coco_gt, coco_dt, iou_type)
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()

J
jerrywgz 已提交
130 131

if __name__ == '__main__':
    # Parse the command line, echo the resolved options, then run the
    # COCO evaluation.
    args = parse_args()
    print_arguments(args)
    eval()