infer.py 2.5 KB
Newer Older
J
jerrywgz 已提交
1 2 3
import os
import time
import numpy as np
J
jerrywgz 已提交
4
from eval_helper import *
J
jerrywgz 已提交
5 6 7 8 9 10 11 12 13 14
import paddle
import paddle.fluid as fluid
import reader
from utility import print_arguments, parse_args
import models.model_builder as model_builder
import models.resnet as resnet
import json
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval, Params
from config import cfg
15
from roidbs import DatasetPath
J
jerrywgz 已提交
16 17 18 19


def infer():
    """Run the detection model on one sample and draw the predictions.

    Steps, in order:

    1. Load the COCO annotation file returned by ``DatasetPath('val')`` and
       build two lookup tables from its categories: a 1-based contiguous
       index -> COCO category id map, and a category id -> name map.
    2. Build an ``RCNN`` model in test mode with
       ``resnet.add_ResNet50_conv4_body`` and
       ``resnet.add_ResNet_roi_conv5_head``; fetch box predictions, plus
       mask predictions when ``cfg.MASK_ON`` is set.
    3. If ``cfg.pretrained_model`` is set, load every variable that has a
       file of the same name in that directory.
    4. Take the first item yielded by ``reader.infer()``, run it through the
       executor and hand the outputs to the drawing helpers together with
       the image at ``cfg.image_path``/``cfg.image_name``.

    ``segm_results``, ``draw_mask_on_image`` and
    ``draw_bounding_box_on_image`` are not defined in this file; presumably
    they come from the ``eval_helper`` wildcard import -- verify.

    Returns:
        None. The visible results are whatever the drawing helpers produce.
    """
    data_path = DatasetPath('val')
    test_list = data_path.get_file_list()

    cocoGt = COCO(test_list)
    # Query the category ids once and reuse them for both lookup tables.
    category_ids = cocoGt.getCatIds()
    # Contiguous 1-based index -> COCO category id.
    num_id_to_cat_id_map = {i + 1: v for i, v in enumerate(category_ids)}
    label_list = {
        item['id']: item['name']
        for item in cocoGt.loadCats(category_ids)
    }
    # NOTE(review): this entry is a one-element list while every other value
    # is a plain string. Kept as-is because the consumer is not visible
    # here -- confirm whether 'background' (a str) was intended.
    label_list[0] = ['background']
    image_shape = [3, cfg.TEST.max_size, cfg.TEST.max_size]

    model = model_builder.RCNN(
        add_conv_body_func=resnet.add_ResNet50_conv4_body,
        add_roi_box_head_func=resnet.add_ResNet_roi_conv5_head,
        use_pyreader=False,
        is_train=False)
    model.build_model(image_shape)
    pred_boxes = model.eval_bbox_out()
    if cfg.MASK_ON:
        masks = model.eval_mask_out()

    place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    # yapf: disable
    if cfg.pretrained_model:
        def if_exist(var):
            # Only load variables that have a matching file on disk.
            return os.path.exists(os.path.join(cfg.pretrained_model, var.name))
        fluid.io.load_vars(exe, cfg.pretrained_model, predicate=if_exist)
    # yapf: enable
    infer_reader = reader.infer()
    feeder = fluid.DataFeeder(place=place, feed_list=model.feeds())

    if cfg.MASK_ON:
        fetch_list = [pred_boxes, masks]
    else:
        fetch_list = [pred_boxes]

    # Only the first item produced by the reader is processed.
    data = next(infer_reader())
    # Second field of the first sample; passed to segm_results as im_info.
    im_info = [data[0][1]]
    # return_numpy=False keeps the raw fetched objects rather than converting
    # them to numpy arrays.
    result = exe.run(fetch_list=[v.name for v in fetch_list],
                     feed=feeder.feed(data),
                     return_numpy=False)
    pred_boxes_v = result[0]
    if cfg.MASK_ON:
        masks_v = result[1]
    nmsed_out = pred_boxes_v

    path = os.path.join(cfg.image_path, cfg.image_name)
    # `image` stays None unless masks are drawn first; in that case the
    # value returned by draw_mask_on_image is passed on to the box drawer.
    image = None
    if cfg.MASK_ON:
        segms_out = segm_results(nmsed_out, masks_v, im_info)
        image = draw_mask_on_image(path, segms_out, cfg.draw_threshold)

    draw_bounding_box_on_image(path, nmsed_out, cfg.draw_threshold, label_list,
                               num_id_to_cat_id_map, image)
J
jerrywgz 已提交
78 79 80 81 82 83


# Script entry point: parse the command-line arguments, echo them, then run
# a single inference pass.
if __name__ == '__main__':
    # NOTE(review): infer() reads its settings from `cfg`, not from `args`;
    # presumably parse_args() also updates `cfg` -- confirm in utility.py.
    args = parse_args()
    print_arguments(args)
    infer()