# infer.py: single-image MobileNet-SSD inference and bounding-box visualization.
import os
import time
import numpy as np
import argparse
import functools
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont

import paddle
import paddle.fluid as fluid
import reader
from mobilenet_ssd import build_mobilenet_ssd
from utility import add_arguments, print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# Command-line options for single-image inference. Each add_arg call registers
# one option on `parser` via utility.add_arguments (arguments as used here:
# name, type, default, help text).
# NOTE: infer() matches `dataset` by substring ('coco' / 'pascalvoc').
# yapf: disable
add_arg('dataset',          str,   'pascalvoc',    "coco and pascalvoc.")
add_arg('use_gpu',          bool,  True,      "Whether use GPU.")
add_arg('image_path',       str,   '',        "The image used to inference and visualize.")
add_arg('model_dir',        str,   '',     "The model path.")
add_arg('nms_threshold',    float, 0.45,   "NMS threshold.")
add_arg('confs_threshold',  float, 0.5,    "Confidence threshold to draw bbox.")
add_arg('resize_h',         int,   300,    "The resized image height.")
add_arg('resize_w',         int,   300,    "The resized image width.")
add_arg('mean_value_B',     float, 127.5,  "Mean value for B channel which will be subtracted.")  # ImageNet alternative: 103.94
add_arg('mean_value_G',     float, 127.5,  "Mean value for G channel which will be subtracted.")  # ImageNet alternative: 116.78
add_arg('mean_value_R',     float, 127.5,  "Mean value for R channel which will be subtracted.")  # ImageNet alternative: 123.68
# yapf: enable


def infer(args, data_args, image_path, model_dir, data_dir=None,
          label_file=None):
    """Run MobileNet-SSD on one image and save a copy with boxes drawn on it.

    Builds the SSD network and optionally loads parameters from
    ``model_dir``. Runs a test-mode clone of the main program on the sample
    returned by ``reader.infer(data_args, image_path)``. The detections are
    then handed to ``draw_bounding_box_on_image`` together with
    ``args.confs_threshold``.

    Args:
        args: parsed command-line options; ``use_gpu``, ``nms_threshold``
            and ``confs_threshold`` are read here.
        data_args: ``reader.Settings``; ``dataset``, ``resize_h`` and
            ``resize_w`` are read, plus ``label_list`` for pascalvoc.
        image_path: image to run on and annotate.
        model_dir: directory of saved variables. Each variable is loaded
            only if a file with its name exists there. A falsy value skips
            loading entirely.
        data_dir, label_file: COCO only. They give the directory and the
            relative path of the annotation file used to build the
            category-id -> name map. When omitted, the module-level
            ``data_dir`` / ``label_file`` names (set by the ``__main__``
            block) are used, as before.

    Raises:
        ValueError: if ``data_args.dataset`` contains neither 'coco' nor
            'pascalvoc', or if COCO label paths are unavailable.
    """
    image_shape = [3, data_args.resize_h, data_args.resize_w]
    if 'coco' in data_args.dataset:
        num_classes = 91
        label_list = _coco_label_list(data_dir, label_file)
    elif 'pascalvoc' in data_args.dataset:
        num_classes = 21
        label_list = data_args.label_list
    else:
        # Without this, num_classes/label_list stay unbound and the call
        # fails later with a NameError; fail early with a clear message.
        raise ValueError("Unsupported dataset [%s]; expected a coco or "
                         "pascalvoc dataset." % data_args.dataset)

    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    locs, confs, box, box_var = build_mobilenet_ssd(image, num_classes,
                                                    image_shape)
    nmsed_out = fluid.layers.detection_output(
        locs, confs, box, box_var, nms_threshold=args.nms_threshold)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    # yapf: disable
    if model_dir:
        def if_exist(var):
            return os.path.exists(os.path.join(model_dir, var.name))
        fluid.io.load_vars(exe, model_dir, predicate=if_exist)
    # yapf: enable
    infer_reader = reader.infer(data_args, image_path)
    feeder = fluid.DataFeeder(place=place, feed_list=[image])

    data = infer_reader()

    # switch network to test mode (i.e. batch norm test mode)
    test_program = fluid.default_main_program().clone(for_test=True)
    nmsed_out_v, = exe.run(test_program,
                           feed=feeder.feed([[data]]),
                           fetch_list=[nmsed_out],
                           return_numpy=False)
    nmsed_out_v = np.array(nmsed_out_v)
    draw_bounding_box_on_image(image_path, nmsed_out_v, args.confs_threshold,
                               label_list)


def _coco_label_list(data_dir=None, label_file=None):
    """Return a {category_id: name} dict read from a COCO annotation file.

    If an argument is None, it falls back to the module-level ``data_dir`` /
    ``label_file`` names (assigned by the ``__main__`` block). This keeps
    the script working unchanged, while importers can pass paths explicitly.

    Raises:
        ValueError: if either path is still unavailable after the fallback.
    """
    if data_dir is None:
        data_dir = globals().get('data_dir')
    if label_file is None:
        label_file = globals().get('label_file')
    if data_dir is None or label_file is None:
        raise ValueError("COCO inference needs data_dir and label_file; pass "
                         "them to infer() or define them at module level.")
    # Imported here rather than at module level, so pascalvoc runs do not
    # need pycocotools installed.
    from pycocotools.coco import COCO
    coco = COCO(os.path.join(data_dir, label_file))
    label_list = {
        item['id']: item['name']
        for item in coco.loadCats(coco.getCatIds())
    }
    # Id 0 is the background entry. Store a plain string like every other
    # label; the original stored the list ['background'].
    label_list[0] = 'background'
    return label_list


def draw_bounding_box_on_image(image_path, nms_out, confs_threshold,
                               label_list):
    """Draw detection boxes on the image at *image_path* and save the result.

    Each row of ``nms_out`` is read as ``[label, score, xmin, ymin, xmax,
    ymax]``. Box coordinates are clipped to [0, 1] and scaled by the image
    width/height, so they are treated as normalized. Rows whose score is
    below ``confs_threshold`` are skipped. Each box is outlined in red; its
    label, looked up as ``label_list[int(label)]``, is written only on RGB
    images.

    The annotated image is saved in the current working directory under the
    input file's base name. NOTE(review): when run from the image's own
    directory, this overwrites the input file.
    """
    with Image.open(image_path) as image:
        draw = ImageDraw.Draw(image)
        im_width, im_height = image.size

        for dt in nms_out:
            # A usable row has six values. Anything shorter cannot be
            # unpacked, so skip it instead of raising IndexError.
            # NOTE(review): presumably this is the detector's "nothing
            # detected" placeholder (e.g. a lone -1); confirm.
            if np.size(dt) < 6 or dt[1] < confs_threshold:
                continue
            category_id = dt[0]
            xmin, ymin, xmax, ymax = clip_bbox(dt[2:])
            left, right = xmin * im_width, xmax * im_width
            top, bottom = ymin * im_height, ymax * im_height
            draw.line(
                [(left, top), (left, bottom), (right, bottom), (right, top),
                 (left, top)],
                width=4,
                fill='red')
            # The text fill is an RGB tuple, so labels are drawn only on
            # RGB images.
            if image.mode == 'RGB':
                draw.text((left, top), label_list[int(category_id)],
                          (255, 255, 0))

        image_name = os.path.basename(image_path)
        print("image with bbox drawed saved as {}".format(image_name))
        image.save(image_name)


def clip_bbox(bbox):
    """Clamp the first four box coordinates into the unit interval.

    Args:
        bbox: indexable whose items 0..3 are (xmin, ymin, xmax, ymax);
            any further items are ignored.

    Returns:
        Tuple ``(xmin, ymin, xmax, ymax)`` with each value limited to [0, 1].
    """
    def _clamp(value):
        return max(min(value, 1.), 0.)

    return tuple(_clamp(bbox[index]) for index in range(4))


if __name__ == '__main__':
    args = parser.parse_args()
    print_arguments(args)

    # Dataset-specific locations. Keep these as module-level names: infer()
    # may read data_dir / label_file from module scope for COCO.
    data_dir = 'data/pascalvoc'
    label_file = 'label_list'

    if not os.path.exists(args.model_dir):
        raise ValueError("The model path [%s] does not exist." %
                         (args.model_dir))
    # Validate the input image up front (default is ''), instead of letting
    # the reader fail later with a less helpful error.
    if not os.path.exists(args.image_path):
        raise ValueError("The image path [%s] does not exist." %
                         (args.image_path))
    if 'coco' in args.dataset:
        data_dir = 'data/coco'
        label_file = 'annotations/instances_val2014.json'

    data_args = reader.Settings(
        dataset=args.dataset,
        data_dir=data_dir,
        label_file=label_file,
        resize_h=args.resize_h,
        resize_w=args.resize_w,
        # Per-channel means, passed in B, G, R order.
        mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R],
        apply_distort=False,
        apply_expand=False,
        ap_version='')
    infer(
        args,
        data_args=data_args,
        image_path=args.image_path,
        model_dir=args.model_dir)