infer.py 10.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

M
Manuel Garcia 已提交
19 20 21
import os
import sys

22 23 24 25
# Make the parent of this script's directory importable (presumably the
# PaddleDetection root that holds the `ppdet` package -- confirm layout).
parent_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if parent_path not in sys.path:
    sys.path.append(parent_path)
26

Q
qingqing01 已提交
27
import glob
28
import numpy as np
29
import six
30
from PIL import Image, ImageOps
31 32 33 34 35 36 37 38

from paddle import fluid

import logging
# Root logging setup for the script: INFO level, records rendered as
# "<time>-<level>: <message>". `logger` is the module-level logger used below.
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(format=FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

K
Kaipeng Deng 已提交
39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
# Import the ppdet helpers this script needs. When the import fails and the
# script path contains "static", log a hint about the usual causes and exit
# with status -1; otherwise let the ImportError propagate unchanged.
try:
    from ppdet.core.workspace import load_config, merge_config, create

    from ppdet.utils.eval_utils import parse_fetches
    from ppdet.utils.cli import ArgsParser
    from ppdet.utils.check import check_gpu, check_version, check_config, enable_static_mode
    from ppdet.utils.visualizer import visualize_results
    import ppdet.utils.checkpoint as checkpoint

    from ppdet.data.reader import create_reader
except ImportError as e:
    if 'static' in sys.argv[0]:
        logger.error("Importing ppdet failed when running static model "
                     "with error: {}\n"
                     "please try:\n"
                     "\t1. run static model under PaddleDetection/static "
                     "directory\n"
                     "\t2. run 'pip uninstall ppdet' to uninstall ppdet "
                     "dynamic version firstly.".format(e))
        sys.exit(-1)
    # Bare `raise` re-raises the active exception with its original traceback
    # (on Python 2, `raise e` would discard it).
    raise

62 63 64 65 66 67 68

def get_save_image_name(output_dir, image_path):
    """
    Get save image name from source image path.

    Creates ``output_dir`` (including parents) when it does not exist yet and
    returns ``output_dir`` joined with the file name of ``image_path``
    (base name and extension are kept unchanged).

    Args:
        output_dir (str): directory the result image will be written to.
        image_path (str): path of the source image.

    Returns:
        str: path of the output image inside ``output_dir``.

    Raises:
        OSError: if ``output_dir`` cannot be created or exists but is not a
            directory.
    """
    # EAFP instead of exists()+makedirs(): avoids failing when the directory
    # appears between the check and the creation.
    try:
        os.makedirs(output_dir)
    except OSError:
        if not os.path.isdir(output_dir):
            raise
    return os.path.join(output_dir, os.path.basename(image_path))


def get_test_images(infer_dir, infer_img):
    """
    Get image path list in TEST mode.

    Args:
        infer_dir (str|None): directory searched (non-recursively) for files
            ending in jpg/jpeg/png/bmp, lower- or upper-case.
        infer_img (str|None): path of a single image; takes priority over
            ``infer_dir`` when given.

    Returns:
        list[str]: ``[infer_img]`` when ``infer_img`` is given, otherwise the
        sorted absolute paths of the images found in ``infer_dir``.

    Raises:
        AssertionError: if neither argument is set, ``infer_img`` is not a
            file, ``infer_dir`` is not a directory, or no image is found.
    """
    assert infer_img is not None or infer_dir is not None, \
        "--infer_img or --infer_dir should be set"
    assert infer_img is None or os.path.isfile(infer_img), \
        "{} is not a file".format(infer_img)
    assert infer_dir is None or os.path.isdir(infer_dir), \
        "{} is not a directory".format(infer_dir)

    # infer_img has a higher priority
    if infer_img and os.path.isfile(infer_img):
        return [infer_img]

    infer_dir = os.path.abspath(infer_dir)
    exts = ['jpg', 'jpeg', 'png', 'bmp']
    exts += [ext.upper() for ext in exts]

    # A set removes duplicates when both the lower- and upper-case pattern
    # match the same file (e.g. on case-insensitive file systems).
    found = set()
    for ext in exts:
        found.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
    # Sort so the processing order is reproducible between runs.
    images = sorted(found)

    assert len(images) > 0, "no image found in {}".format(infer_dir)
    logger.info("Found {} inference images in total.".format(len(images)))

    return images


def _load_rgb_image(image_path):
    """Open ``image_path``, convert it to RGB and apply its EXIF orientation.

    The source file handle is closed before returning; the returned image is
    the converted copy.
    """
    with Image.open(image_path) as src:
        image = src.convert('RGB')
    return ImageOps.exif_transpose(image)


def main():
    """
    Run inference on the selected images and save the visualized results.

    Reads the module-level ``FLAGS`` (config, opt, infer_dir, infer_img,
    output_dir, draw_threshold, use_vdl, vdl_log_dir). Steps:

    1. load/merge/check the config and collect the test images;
    2. build the inference program for ``cfg.architecture`` and load
       ``cfg.weights`` when set;
    3. for every batch from the loader, run the program, convert the outputs
       with the metric-specific ``*2out`` helpers, draw them on each source
       image and save it under ``FLAGS.output_dir``;
    4. optionally log the original and the drawn image to VisualDL.

    Raises:
        AssertionError: if ``cfg.metric`` is not one of COCO/VOC/OID/WIDERFACE,
            or VisualDL is requested on Python 2.
    """
    cfg = load_config(FLAGS.config)

    merge_config(FLAGS.opt)
    check_config(cfg)
    # check if set use_gpu=True in paddlepaddle cpu version
    check_gpu(cfg.use_gpu)
    # check if paddlepaddle version is satisfied
    check_version()

    main_arch = cfg.architecture

    dataset = cfg.TestReader['dataset']

    test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
    dataset.set_images(test_images)

    place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    model = create(main_arch)

    startup_prog = fluid.Program()
    infer_prog = fluid.Program()
    with fluid.program_guard(infer_prog, startup_prog):
        with fluid.unique_name.guard():
            inputs_def = cfg['TestReader']['inputs_def']
            inputs_def['iterable'] = True
            feed_vars, loader = model.build_inputs(**inputs_def)
            test_fetches = model.test(feed_vars)
    infer_prog = infer_prog.clone(True)

    reader = create_reader(cfg.TestReader, devices_num=1)
    loader.set_sample_list_generator(reader, place)

    exe.run(startup_prog)
    if cfg.weights:
        checkpoint.load_params(exe, infer_prog, cfg.weights)

    # parse infer fetches
    assert cfg.metric in ['COCO', 'VOC', 'OID', 'WIDERFACE'], \
            "unknown metric type {}".format(cfg.metric)
    extra_keys = []
    if cfg['metric'] in ['COCO', 'OID']:
        extra_keys = ['im_info', 'im_id', 'im_shape']
    if cfg['metric'] == 'VOC' or cfg['metric'] == 'WIDERFACE':
        extra_keys = ['im_id', 'im_shape']
    keys, values, _ = parse_fetches(test_fetches, infer_prog, extra_keys)

    # parse dataset category
    # NOTE: mask2out/segm2out are only imported for COCO and lmk2out only for
    # WIDERFACE; the matching result keys are expected only with those metrics.
    if cfg.metric == 'COCO':
        from ppdet.utils.coco_eval import bbox2out, mask2out, segm2out, get_category_info
    elif cfg.metric == 'OID':
        from ppdet.utils.oid_eval import bbox2out, get_category_info
    elif cfg.metric == "VOC":
        from ppdet.utils.voc_eval import bbox2out, get_category_info
    elif cfg.metric == "WIDERFACE":
        from ppdet.utils.widerface_eval_utils import bbox2out, lmk2out, get_category_info

    anno_file = dataset.get_anno()
    with_background = dataset.with_background
    use_default_label = dataset.use_default_label

    clsid2catid, catid2name = get_category_info(anno_file, with_background,
                                                use_default_label)

    # whether output bbox is normalized in model output layer
    is_bbox_normalized = False
    if hasattr(model, 'is_bbox_normalized') and \
            callable(model.is_bbox_normalized):
        is_bbox_normalized = model.is_bbox_normalized()

    # use VisualDL to log image
    vdl_writer = None
    if FLAGS.use_vdl:
        assert six.PY3, "VisualDL requires Python >= 3.5"
        from visualdl import LogWriter
        vdl_writer = LogWriter(FLAGS.vdl_log_dir)
        vdl_image_step = 0
        vdl_image_frame = 0  # each frame can display ten pictures at most.

    imid2path = dataset.get_imid2path()
    try:
        for iter_id, data in enumerate(loader()):
            outs = exe.run(infer_prog,
                           feed=data,
                           fetch_list=values,
                           return_numpy=False)
            # each fetched value -> (ndarray, recursive sequence lengths)
            res = {
                k: (np.array(v), v.recursive_sequence_lengths())
                for k, v in zip(keys, outs)
            }
            logger.info('Infer iter {}'.format(iter_id))
            if 'TTFNet' in cfg.architecture:
                res['bbox'][1].append([len(res['bbox'][0])])
            if 'CornerNet' in cfg.architecture:
                from ppdet.utils.post_process import corner_post_process
                post_config = getattr(cfg, 'PostProcess', None)
                corner_post_process(res, post_config, cfg.num_classes)

            bbox_results = None
            mask_results = None
            segm_results = None
            lmk_results = None
            if 'bbox' in res:
                bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized)
            if 'mask' in res:
                mask_results = mask2out([res], clsid2catid,
                                        model.mask_head.resolution)
            if 'segm' in res:
                segm_results = segm2out([res], clsid2catid)
            if 'landmark' in res:
                lmk_results = lmk2out([res], is_bbox_normalized)

            # visualize result
            im_ids = res['im_id'][0]
            for im_id in im_ids:
                image_path = imid2path[int(im_id)]
                image = _load_rgb_image(image_path)

                # use VisualDL to log original image
                if FLAGS.use_vdl:
                    original_image_np = np.array(image)
                    vdl_writer.add_image(
                        "original/frame_{}".format(vdl_image_frame),
                        original_image_np, vdl_image_step)

                image = visualize_results(image,
                                          int(im_id), catid2name,
                                          FLAGS.draw_threshold, bbox_results,
                                          mask_results, segm_results,
                                          lmk_results)

                # use VisualDL to log image with bbox
                if FLAGS.use_vdl:
                    infer_image_np = np.array(image)
                    vdl_writer.add_image(
                        "bbox/frame_{}".format(vdl_image_frame),
                        infer_image_np, vdl_image_step)
                    # start a new frame after every ten logged images
                    vdl_image_step += 1
                    if vdl_image_step % 10 == 0:
                        vdl_image_step = 0
                        vdl_image_frame += 1

                save_name = get_save_image_name(FLAGS.output_dir, image_path)
                logger.info("Detection bbox results save in {}".format(
                    save_name))
                image.save(save_name, quality=95)
    finally:
        # Close the VisualDL writer even when inference fails part-way.
        if vdl_writer is not None:
            vdl_writer.close()
249 250 251


def str2bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` treats every non-empty string (including "False") as True,
    so boolean flags need an explicit converter.

    Args:
        value (str|bool): token such as "true"/"false", "yes"/"no", "1"/"0"
            (case-insensitive), or an actual bool.

    Returns:
        bool: the parsed value.

    Raises:
        ValueError: if the token is not a recognized boolean spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ValueError("expected a boolean value, got {!r}".format(value))


if __name__ == '__main__':
    enable_static_mode()
    # ArgsParser comes from ppdet.utils.cli; FLAGS.config and FLAGS.opt read
    # in main() are presumably registered by it -- confirm there.
    parser = ArgsParser()
    parser.add_argument(
        "--infer_dir",
        type=str,
        default=None,
        help="Directory for images to perform inference on.")
    parser.add_argument(
        "--infer_img",
        type=str,
        default=None,
        help="Image path, has higher priority over --infer_dir")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Directory for storing the output visualization files.")
    parser.add_argument(
        "--draw_threshold",
        type=float,
        default=0.5,
        help="Threshold to reserve the result for visualization.")
    parser.add_argument(
        "--use_vdl",
        type=str2bool,
        default=False,
        help="whether to record the data to VisualDL.")
    parser.add_argument(
        '--vdl_log_dir',
        type=str,
        default="vdl_log_dir/image",
        help='VisualDL logging directory for image.')
    # main() reads these options through the module-level FLAGS.
    FLAGS = parser.parse_args()
    main()