# infer.py
import os
import time
import numpy as np
import argparse
import functools
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont

import paddle
import paddle.fluid as fluid
import reader
from mobilenet_ssd import mobile_net
from utility import add_arguments, print_arguments

# Command-line options. `add_arg` is utility.add_arguments bound to `parser`;
# each call passes (name, type, default, help text).
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('dataset',          str,   'pascalvoc',    "coco and pascalvoc.")
add_arg('use_gpu',          bool,  True,      "Whether use GPU.")
add_arg('image_path',       str,   '',        "The image used to inference and visualize.")
add_arg('model_dir',        str,   '',     "The model path.")
add_arg('nms_threshold',    float, 0.45,   "NMS threshold.")
add_arg('confs_threshold',  float, 0.5,    "Confidence threshold to draw bbox.")
add_arg('resize_h',         int,   300,    "The resized image height.")
add_arg('resize_w',         int,   300,    "The resized image width.")
# NOTE(review): the trailing numbers look like ImageNet channel means, where
# 123.68 is conventionally the R mean and 103.94 the B mean -- the B/R
# comments below may be swapped; confirm before switching to them.
add_arg('mean_value_B',     float, 127.5,  "Mean value for B channel which will be subtracted.")  #123.68
add_arg('mean_value_G',     float, 127.5,  "Mean value for G channel which will be subtracted.")  #116.78
add_arg('mean_value_R',     float, 127.5,  "Mean value for R channel which will be subtracted.")  #103.94
# yapf: enable


def infer(args, data_args, image_path, model_dir):
    """Detect objects in one image with MobileNet-SSD and save a visualization.

    Builds the inference network and loads every variable that has a saved
    file in ``model_dir``. It then runs the image at ``image_path`` through
    the network and passes the NMS output to ``draw_bounding_box_on_image``.

    Args:
        args: parsed command-line arguments; ``nms_threshold``, ``use_gpu``
            and ``confs_threshold`` are read here.
        data_args: ``reader.Settings`` instance; ``resize_h``, ``resize_w``,
            ``dataset`` and (for pascalvoc) ``label_list`` are read here.
        image_path: path of the image to run inference on.
        model_dir: directory of saved variables; variables without a
            matching file there are not loaded.

    Raises:
        ValueError: if ``data_args.dataset`` mentions neither coco nor
            pascalvoc.
    """
    image_shape = [3, data_args.resize_h, data_args.resize_w]
    if 'coco' in data_args.dataset:
        num_classes = 91
        # cocoapi is only needed for COCO label names, so import it lazily.
        from pycocotools.coco import COCO
        # NOTE(review): `data_dir` and `label_file` are module-level globals
        # assigned only in this file's `__main__` block, so this branch fails
        # with NameError when `infer` is imported and called elsewhere.
        label_fpath = os.path.join(data_dir, label_file)
        coco = COCO(label_fpath)
        category_ids = coco.getCatIds()
        label_list = {
            item['id']: item['name']
            for item in coco.loadCats(category_ids)
        }
        # A plain string like every other entry: labels are handed to
        # ImageDraw's text() when drawing, which expects text, not a list.
        label_list[0] = 'background'
    elif 'pascalvoc' in data_args.dataset:
        num_classes = 21
        label_list = data_args.label_list
    else:
        # Otherwise num_classes/label_list would be unbound further down.
        raise ValueError("Unsupported dataset [%s]: expected coco or "
                         "pascalvoc." % data_args.dataset)

    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    locs, confs, box, box_var = mobile_net(num_classes, image, image_shape)
    nmsed_out = fluid.layers.detection_output(
        locs, confs, box, box_var, nms_threshold=args.nms_threshold)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    # yapf: disable
    if model_dir:
        # Load only the variables that have a saved file in model_dir.
        def if_exist(var):
            return os.path.exists(os.path.join(model_dir, var.name))
        fluid.io.load_vars(exe, model_dir, predicate=if_exist)
    # yapf: enable
    infer_reader = reader.infer(data_args, image_path)
    feeder = fluid.DataFeeder(place=place, feed_list=[image])

    data = infer_reader()

    # switch network to test mode (i.e. batch norm test mode)
    test_program = fluid.default_main_program().clone(for_test=True)
    nmsed_out_v, = exe.run(test_program,
                           feed=feeder.feed([[data]]),
                           fetch_list=[nmsed_out],
                           return_numpy=False)
    # With return_numpy=False the fetched value is a LoDTensor (per Paddle
    # docs); convert it so the drawing code can iterate over rows.
    nmsed_out_v = np.array(nmsed_out_v)
    draw_bounding_box_on_image(image_path, nmsed_out_v, args.confs_threshold,
                               label_list)


def draw_bounding_box_on_image(image_path, nms_out, confs_threshold,
                               label_list):
    """Draw detections on the image at ``image_path`` and save a copy.

    Args:
        image_path: path of the source image.
        nms_out: detection rows; each full row is read as
            [label, score, xmin, ymin, xmax, ymax]. The box coordinates are
            clipped to [0, 1] and scaled by the image size.
        confs_threshold: rows whose score is below this value are skipped.
        label_list: indexable by integer class id, yielding the label text.

    The annotated image is saved in the current working directory under the
    source file's base name.
    """
    with Image.open(image_path) as image:
        draw = ImageDraw.Draw(image)
        im_width, im_height = image.size

        for dt in nms_out:
            # Paddle's NMS emits a single [-1] row when nothing is detected
            # (per multiclass_nms docs); skip anything that is not a full row.
            if len(dt) < 6 or dt[1] < confs_threshold:
                continue
            category_id = dt[0]
            xmin, ymin, xmax, ymax = clip_bbox(dt[2:])
            (left, right, top, bottom) = (xmin * im_width, xmax * im_width,
                                          ymin * im_height, ymax * im_height)
            draw.line(
                [(left, top), (left, bottom), (right, bottom), (right, top),
                 (left, top)],
                width=4,
                fill='red')
            # The (255, 255, 0) RGB tuple colour only fits RGB images.
            if image.mode == 'RGB':
                draw.text((left, top), label_list[int(category_id)],
                          (255, 255, 0))
        # os.path.basename also handles the platform's path separator.
        image_name = os.path.basename(image_path)
        print("image with bbox drawed saved as {}".format(image_name))
        image.save(image_name)


def clip_bbox(bbox):
    """Clamp the first four coordinates of ``bbox`` into the range [0, 1].

    ``bbox`` is indexed at positions 0..3, read as (xmin, ymin, xmax, ymax).
    Any further entries are ignored, and fewer than four raises IndexError.

    Returns:
        A 4-tuple ``(xmin, ymin, xmax, ymax)`` of the clamped values.
    """
    def _clamp(value):
        return max(min(value, 1.), 0.)

    return tuple(_clamp(bbox[idx]) for idx in range(4))


if __name__ == '__main__':
    args = parser.parse_args()
    print_arguments(args)

    # Kept at module level on purpose: infer() reads `data_dir` and
    # `label_file` as globals when building the COCO label path.
    data_dir = 'data/pascalvoc'
    label_file = 'label_list'

    if not os.path.exists(args.model_dir):
        raise ValueError("The model path [%s] does not exist." %
                         (args.model_dir))
    # Fail fast, before the network is built, if there is no image to run on.
    if not os.path.isfile(args.image_path):
        raise ValueError("The image path [%s] does not exist." %
                         (args.image_path))
    if 'coco' in args.dataset:
        data_dir = 'data/coco'
        label_file = 'annotations/instances_val2014.json'

    data_args = reader.Settings(
        dataset=args.dataset,
        data_dir=data_dir,
        label_file=label_file,
        resize_h=args.resize_h,
        resize_w=args.resize_w,
        mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R],
        apply_distort=False,
        apply_expand=False,
        ap_version='')
    infer(
        args,
        data_args=data_args,
        image_path=args.image_path,
        model_dir=args.model_dir)