infer.py 5.9 KB
Newer Older
X
xiaoting 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 15 16 17 18 19 20
import os
import time
import numpy as np
import argparse
import functools
from PIL import Image
from PIL import ImageDraw
X
Xingyuan Bu 已提交
21
from PIL import ImageFont
22 23 24 25

import paddle
import paddle.fluid as fluid
import reader
26
from mobilenet_ssd import build_mobilenet_ssd
L
LielinJiang 已提交
27
from utility import add_arguments, print_arguments, check_cuda
28 29 30 31 32 33 34 35 36

# Command-line options for the inference script. Each add_arg call appears to
# pass (name, type, default, help text) through to utility.add_arguments.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('dataset',          str,   'pascalvoc',    "coco and pascalvoc.")
add_arg('use_gpu',          bool,  True,      "Whether use GPU.")
add_arg('image_path',       str,   '',        "The image used to inference and visualize.")
add_arg('model_dir',        str,   '',     "The model path.")
add_arg('nms_threshold',    float, 0.45,   "NMS threshold.")
add_arg('confs_threshold',  float, 0.5,    "Confidence threshold to draw bbox.")
add_arg('resize_h',         int,   300,    "The resized image height.")
add_arg('resize_w',         int,   300,    "The resized image width.")
add_arg('mean_value_B',     float, 127.5,  "Mean value for B channel which will be subtracted.")  #123.68
add_arg('mean_value_G',     float, 127.5,  "Mean value for G channel which will be subtracted.")  #116.78
add_arg('mean_value_R',     float, 127.5,  "Mean value for R channel which will be subtracted.")  #103.94
# yapf: enable


def infer(args, data_args, image_path, model_dir):
    """Run MobileNet-SSD detection on one image and save the visualization.

    Builds the detection network, optionally loads saved parameters, runs a
    single forward pass on the sample produced by ``reader.infer`` and hands
    the detections to ``draw_bounding_box_on_image``.

    Args:
        args: Parsed command-line namespace; ``nms_threshold``, ``use_gpu``
            and ``confs_threshold`` are read from it.
        data_args: Settings object; ``resize_h``, ``resize_w``, ``dataset``
            and (for pascalvoc) ``label_list`` are read from it.
        image_path: Path of the image to run detection on.
        model_dir: Directory holding saved parameters. When empty, no
            parameters are loaded.

    Raises:
        ValueError: If ``data_args.dataset`` contains neither 'coco' nor
            'pascalvoc'.
    """
    image_shape = [3, data_args.resize_h, data_args.resize_w]
    if 'coco' in data_args.dataset:
        num_classes = 91
        # cocoapi
        from pycocotools.coco import COCO
        # NOTE(review): `data_dir` and `label_file` are not parameters; they
        # resolve to the module-level names assigned in the `__main__`
        # section, so this branch raises NameError when infer() is called
        # from another module.
        label_fpath = os.path.join(data_dir, label_file)
        coco = COCO(label_fpath)
        category_ids = coco.getCatIds()
        # Map COCO category id -> category name.
        label_list = {
            item['id']: item['name']
            for item in coco.loadCats(category_ids)
        }
        # Keep every label a plain string: the value is passed as the text
        # argument of ImageDraw.text when drawing.
        label_list[0] = 'background'
    elif 'pascalvoc' in data_args.dataset:
        num_classes = 21
        label_list = data_args.label_list
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError("Unsupported dataset [%s], expected coco or "
                         "pascalvoc." % data_args.dataset)

    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    locs, confs, box, box_var = build_mobilenet_ssd(image, num_classes,
                                                    image_shape)
    nmsed_out = fluid.layers.detection_output(
        locs, confs, box, box_var, nms_threshold=args.nms_threshold)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    # yapf: disable
    if model_dir:
        # Only load variables that have a matching file in model_dir.
        def if_exist(var):
            return os.path.exists(os.path.join(model_dir, var.name))
        fluid.io.load_vars(exe, model_dir, predicate=if_exist)
    # yapf: enable
    infer_reader = reader.infer(data_args, image_path)
    feeder = fluid.DataFeeder(place=place, feed_list=[image])

    data = infer_reader()

    # switch network to test mode (i.e. batch norm test mode)
    test_program = fluid.default_main_program().clone(for_test=True)
    # The sample is wrapped twice: a batch holding one sample whose only
    # field is the image.
    nmsed_out_v, = exe.run(test_program,
                           feed=feeder.feed([[data]]),
                           fetch_list=[nmsed_out],
                           return_numpy=False)
    nmsed_out_v = np.array(nmsed_out_v)
    draw_bounding_box_on_image(image_path, nmsed_out_v, args.confs_threshold,
                               label_list)
93 94


X
Xingyuan Bu 已提交
95 96
def draw_bounding_box_on_image(image_path, nms_out, confs_threshold,
                               label_list):
    """Draw detections above a confidence threshold onto an image and save it.

    Args:
        image_path: Path of the image to open and annotate.
        nms_out: Iterable of detection rows. Each row is read as
            ``[category_id, score, xmin, ymin, xmax, ymax]``; the box values
            are clipped to [0, 1] and scaled by the image size.
        confs_threshold: Rows whose score is below this value are skipped.
        label_list: Container indexed by ``int(category_id)`` that yields the
            label text for a detection.

    The annotated image is written to the current working directory under the
    base name of ``image_path``.
    NOTE(review): when ``image_path`` already refers to a file in the current
    working directory, the save overwrites the input image.
    """
    image = Image.open(image_path)
    draw = ImageDraw.Draw(image)
    im_width, im_height = image.size

    for dt in nms_out:
        if dt[1] < confs_threshold:
            continue
        category_id = dt[0]
        xmin, ymin, xmax, ymax = clip_bbox(dt[2:])
        # Scale the clipped box to pixel coordinates.
        (left, right, top, bottom) = (xmin * im_width, xmax * im_width,
                                      ymin * im_height, ymax * im_height)
        # Closed polyline around the box.
        draw.line(
            [(left, top), (left, bottom), (right, bottom), (right, top),
             (left, top)],
            width=4,
            fill='red')
        # The text fill is an RGB triple, so labels are only drawn on RGB
        # images.
        if image.mode == 'RGB':
            draw.text((left, top), label_list[int(category_id)], (255, 255, 0))
    # basename handles the platform's path separators, unlike split('/').
    image_name = os.path.basename(image_path)
    print("image with bbox drawed saved as {}".format(image_name))
    image.save(image_name)


121 122 123 124 125 126 127 128
def clip_bbox(bbox):
    """Clamp the first four entries of ``bbox`` to the range [0, 1].

    Args:
        bbox: Indexable holding ``xmin, ymin, xmax, ymax`` at positions 0-3.

    Returns:
        A tuple ``(xmin, ymin, xmax, ymax)`` of the clamped values.
    """
    # Cap each coordinate at 1 first, then floor it at 0.
    clamped = [max(min(bbox[idx], 1.), 0.) for idx in range(4)]
    return tuple(clamped)


129 130 131 132
if __name__ == '__main__':
    # Script entry point: parse options, pick the dataset paths, run inference.
    args = parser.parse_args()
    print_arguments(args)

    check_cuda(args.use_gpu)

    # `data_dir` and `label_file` stay module-level names because infer()
    # reads them as globals when the dataset is coco.
    if 'coco' in args.dataset:
        data_dir = 'data/coco'
        label_file = 'annotations/instances_val2014.json'
    else:
        data_dir = 'data/pascalvoc'
        label_file = 'label_list'

    if not os.path.exists(args.model_dir):
        raise ValueError("The model path [%s] does not exist." %
                         (args.model_dir))

    # Augmentation is switched off for inference; mean values are passed in
    # B, G, R order.
    settings_kwargs = dict(
        dataset=args.dataset,
        data_dir=data_dir,
        label_file=label_file,
        resize_h=args.resize_h,
        resize_w=args.resize_w,
        mean_value=[args.mean_value_B, args.mean_value_G, args.mean_value_R],
        apply_distort=False,
        apply_expand=False,
        ap_version='')
    data_args = reader.Settings(**settings_kwargs)

    infer(
        args,
        data_args=data_args,
        image_path=args.image_path,
        model_dir=args.model_dir)