# eval_coco_map.py: evaluate RCNN detections (and optionally masks) with the COCO API.
#  Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import numpy as np
from eval_helper import *
import paddle
import paddle.fluid as fluid
import reader
from utility import print_arguments, parse_args
import models.model_builder as model_builder
import models.resnet as resnet
import json
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval, Params
from config import cfg
from roidbs import DatasetPath


def _coco_evaluate(coco_gt, results, result_file, iou_type, name):
    """Dump ``results`` to ``result_file`` and print the COCO summary.

    Args:
        coco_gt: ground-truth ``COCO`` object the results are scored against.
        results: list of COCO-format result dicts to evaluate.
        result_file: path the results are written to (as JSON) and then
            reloaded from via ``coco_gt.loadRes``.
        iou_type: ``COCOeval`` IoU type, e.g. ``'bbox'`` or ``'segm'``.
        name: label used in the progress message (e.g. ``'bbox'``, ``'mask'``).
    """
    if not results:
        # COCO.loadRes indexes the first annotation, so an empty result list
        # would crash it; report and skip instead.
        print("no {} results to evaluate, skipping coco api".format(name))
        return
    with open(result_file, 'w') as outfile:
        json.dump(results, outfile)
    print("start evaluate {} using coco api".format(name))
    coco_dt = coco_gt.loadRes(result_file)
    coco_eval = COCOeval(coco_gt, coco_dt, iou_type)
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()


def eval():
    """Run the RCNN model over the validation set and report COCO metrics.

    Builds the inference program and runs the startup program. When
    ``cfg.pretrained_model`` is set, it loads every variable that exists
    there as a file. It then runs each batch from ``reader.test`` and
    writes the detections to ``detection_bbox_result.json``. When
    ``cfg.MASK_ON`` is set, it also writes segmentations to
    ``detection_segms_result.json``. Finally it prints the pycocotools
    summaries for both.

    NOTE: the name shadows the ``eval`` builtin; it is kept for callers.

    Raises:
        RuntimeError: if the test reader yields no batches.
    """
    data_path = DatasetPath('val')
    test_list = data_path.get_file_list()

    image_shape = [3, cfg.TEST.max_size, cfg.TEST.max_size]
    # One evaluation step feeds cfg.TRAIN.im_per_batch images per visible
    # device (an unset CUDA_VISIBLE_DEVICES counts as one device).
    devices = os.getenv("CUDA_VISIBLE_DEVICES") or ""
    devices_num = len(devices.split(","))
    total_batch_size = devices_num * cfg.TRAIN.im_per_batch

    cocoGt = COCO(test_list)
    # Maps the model's 1-based contiguous class index to the COCO category id.
    num_id_to_cat_id_map = {i + 1: v for i, v in enumerate(cocoGt.getCatIds())}

    model = model_builder.RCNN(
        add_conv_body_func=resnet.add_ResNet50_conv4_body,
        add_roi_box_head_func=resnet.add_ResNet_roi_conv5_head,
        use_pyreader=False,
        is_train=False)
    model.build_model(image_shape)
    pred_boxes = model.eval_bbox_out()
    if cfg.MASK_ON:
        masks = model.eval_mask_out()
    place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    # yapf: disable
    if cfg.pretrained_model:
        # Only load variables that actually have a file in the checkpoint dir.
        def if_exist(var):
            return os.path.exists(os.path.join(cfg.pretrained_model, var.name))
        fluid.io.load_vars(exe, cfg.pretrained_model, predicate=if_exist)
    # yapf: enable
    test_reader = reader.test(total_batch_size)
    feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())

    dts_res = []
    segms_res = []
    fetch_list = [pred_boxes, masks] if cfg.MASK_ON else [pred_boxes]
    # Loop-invariant: resolve the fetch names once rather than per batch.
    fetch_names = [v.name for v in fetch_list]

    num_batches = 0
    eval_start = time.time()
    for batch_id, batch_data in enumerate(test_reader()):
        start = time.time()
        # data[1] of each sample is passed to segm_results as im_info.
        im_info = [data[1] for data in batch_data]
        result = exe.run(fetch_list=fetch_names,
                         feed=feeder.feed(batch_data),
                         return_numpy=False)

        # Keep the LoD tensor (return_numpy=False): its LoD splits the
        # boxes per image.
        pred_boxes_v = result[0]
        new_lod = pred_boxes_v.lod()
        dts_res += get_dt_res(total_batch_size, new_lod[0], pred_boxes_v,
                              batch_data, num_id_to_cat_id_map)

        if cfg.MASK_ON:
            masks_v = result[1]
            # NOTE(review): a (1, 1) mask output appears to be a placeholder
            # (presumably no detections in this batch); TODO confirm.
            if np.array(masks_v).shape != (1, 1):
                segms_out = segm_results(pred_boxes_v, masks_v, im_info)
                segms_res += get_segms_res(total_batch_size, new_lod[0],
                                           segms_out, batch_data,
                                           num_id_to_cat_id_map)
        end = time.time()
        num_batches += 1
        print('batch id: {}, time: {}'.format(batch_id, end - start))
    eval_end = time.time()
    total_time = eval_end - eval_start
    if num_batches == 0:
        # Previously this surfaced as an UnboundLocalError on batch_id.
        raise RuntimeError(
            'test reader produced no batches; nothing to evaluate')
    print('average time of eval is: {}'.format(total_time / num_batches))

    _coco_evaluate(cocoGt, dts_res, "detection_bbox_result.json", 'bbox',
                   'bbox')
    if cfg.MASK_ON:
        _coco_evaluate(cocoGt, segms_res, "detection_segms_result.json",
                       'segm', 'mask')


if __name__ == '__main__':
    # Script entry point: parse and echo the command-line options, then
    # run the evaluation.
    cli_args = parse_args()
    print_arguments(cli_args)
    eval()