infer.py 10.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

M
Manuel Garcia 已提交
19 20 21
import os
import sys

22 23 24 25
# Put the directory two levels above this file (the PaddleDetection python
# path, per the original note) on sys.path, unless it is already listed.
parent_path = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))
if sys.path.count(parent_path) == 0:
    sys.path.append(parent_path)
26

Q
qingqing01 已提交
27
import glob
28
import numpy as np
29
import six
30
from PIL import Image, ImageOps
31 32 33 34 35 36 37 38

from paddle import fluid

import logging
# Log line layout: "<timestamp>-<LEVEL>: <message>".
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
# Configure the root logger at import time: INFO and above, rendered with FORMAT.
logging.basicConfig(level=logging.INFO, format=FORMAT)
# Module-level logger used by the code in this file.
logger = logging.getLogger(__name__)

K
Kaipeng Deng 已提交
39 40 41 42 43
# Import the ppdet helpers this script relies on. When the import fails and
# the script path contains "static", log a hint about the usual causes and
# exit; otherwise let the ImportError propagate unchanged.
try:
    from ppdet.core.workspace import load_config, merge_config, create

    from ppdet.utils.eval_utils import parse_fetches
    from ppdet.utils.cli import ArgsParser
    from ppdet.utils.check import check_gpu, check_npu, check_version, check_config, enable_static_mode
    from ppdet.utils.visualizer import visualize_results
    import ppdet.utils.checkpoint as checkpoint

    from ppdet.data.reader import create_reader
except ImportError as e:
    if 'static' in sys.argv[0]:
        logger.error("Importing ppdet failed when running static model "
                     "with error: {}\n"
                     "please try:\n"
                     "\t1. run static model under PaddleDetection/static "
                     "directory\n"
                     "\t2. run 'pip uninstall ppdet' to uninstall ppdet "
                     "dynamic version firstly.".format(e))
        sys.exit(-1)
    else:
        # Bare raise keeps the original traceback on both Python 2 and 3.
        raise

62 63 64 65 66 67 68

def get_save_image_name(output_dir, image_path):
    """
    Get save image name from source image path.

    The output directory is created (including missing parents) when it
    does not exist yet. The returned path keeps the base name and extension
    of ``image_path`` and places it inside ``output_dir``.

    Args:
        output_dir (str): directory the result image will be written to.
        image_path (str): path of the source image.

    Returns:
        str: ``output_dir`` joined with the file name of ``image_path``.

    Raises:
        OSError: if ``output_dir`` cannot be created and is not an
            existing directory.
    """
    try:
        os.makedirs(output_dir)
    except OSError:
        # Creating first and checking afterwards avoids the race between an
        # "exists" test and makedirs; an already existing directory is fine.
        if not os.path.isdir(output_dir):
            raise
    return os.path.join(output_dir, os.path.basename(image_path))


def get_test_images(infer_dir, infer_img):
    """
    Get image path list in TEST mode.

    Args:
        infer_dir (str|None): directory searched (non-recursively) for
            files ending in jpg/jpeg/png/bmp, lower or upper case.
        infer_img (str|None): path of a single image; takes priority over
            ``infer_dir`` when both are given.

    Returns:
        list[str]: ``[infer_img]`` when an image file is given, otherwise
            the sorted absolute paths of the images found in ``infer_dir``.

    Raises:
        AssertionError: if neither argument is set, if ``infer_img`` is not
            a file, if ``infer_dir`` is not a directory, or if no image is
            found in ``infer_dir``.
    """
    assert infer_img is not None or infer_dir is not None, \
        "--infer_img or --infer_dir should be set"
    assert infer_img is None or os.path.isfile(infer_img), \
        "{} is not a file".format(infer_img)
    assert infer_dir is None or os.path.isdir(infer_dir), \
        "{} is not a directory".format(infer_dir)

    # infer_img has a higher priority
    if infer_img and os.path.isfile(infer_img):
        return [infer_img]

    infer_dir = os.path.abspath(infer_dir)
    assert os.path.isdir(infer_dir), \
        "infer_dir {} is not a directory".format(infer_dir)

    exts = ['jpg', 'jpeg', 'png', 'bmp']
    exts += [ext.upper() for ext in exts]

    # A set removes duplicates when the same file matches more than one
    # pattern (e.g. on case-insensitive file systems).
    found = set()
    for ext in exts:
        found.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
    # Sort so the inference order is reproducible between runs; plain set
    # iteration order changes with string hash randomisation.
    images = sorted(found)

    assert len(images) > 0, "no image found in {}".format(infer_dir)
    logger.info("Found {} inference images in total.".format(len(images)))

    return images


def main():
    """
    Run inference on the selected images and save the visualized results.

    Reads the module-level ``FLAGS`` namespace (config, opt, infer_dir,
    infer_img, output_dir, draw_threshold, use_vdl, vdl_log_dir) that is
    populated in the ``__main__`` section. The steps are:

    1. load and check the config, pick the execution place;
    2. build the inference program and load ``cfg.weights`` if set;
    3. run the program over every batch from the test reader;
    4. convert the fetched outputs to bbox/mask/segm/landmark results,
       draw them on each source image and save it under
       ``FLAGS.output_dir``; optionally log the original and the drawn
       image to VisualDL.

    Raises:
        AssertionError: if ``cfg.metric`` is not one of COCO, VOC, OID or
            WIDERFACE, or if VisualDL is requested under Python 2.
    """
    cfg = load_config(FLAGS.config)

    merge_config(FLAGS.opt)
    check_config(cfg)
    # check if set use_gpu=True in paddlepaddle cpu version
    check_gpu(cfg.use_gpu)
    # disable npu in config by default and check use_npu
    if 'use_npu' not in cfg:
        cfg.use_npu = False
    check_npu(cfg.use_npu)
    # check if paddlepaddle version is satisfied
    check_version()

    main_arch = cfg.architecture

    dataset = cfg.TestReader['dataset']

    test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
    dataset.set_images(test_images)

    if cfg.use_gpu:
        place = fluid.CUDAPlace(0)
    elif cfg.use_npu:
        place = fluid.NPUPlace(0)
    else:
        place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    model = create(main_arch)

    startup_prog = fluid.Program()
    infer_prog = fluid.Program()
    with fluid.program_guard(infer_prog, startup_prog):
        with fluid.unique_name.guard():
            inputs_def = cfg['TestReader']['inputs_def']
            inputs_def['iterable'] = True
            feed_vars, loader = model.build_inputs(**inputs_def)
            test_fetches = model.test(feed_vars)
    infer_prog = infer_prog.clone(True)

    reader = create_reader(cfg.TestReader, devices_num=1)
    loader.set_sample_list_generator(reader, place)

    exe.run(startup_prog)
    if cfg.weights:
        checkpoint.load_params(exe, infer_prog, cfg.weights)

    # parse infer fetches
    assert cfg.metric in ['COCO', 'VOC', 'OID', 'WIDERFACE'], \
            "unknown metric type {}".format(cfg.metric)
    extra_keys = []
    if cfg['metric'] in ['COCO', 'OID']:
        extra_keys = ['im_info', 'im_id', 'im_shape']
    elif cfg['metric'] in ['VOC', 'WIDERFACE']:
        extra_keys = ['im_id', 'im_shape']
    keys, values, _ = parse_fetches(test_fetches, infer_prog, extra_keys)

    # parse dataset category
    # Only the converters of the configured metric are imported; e.g.
    # mask2out/segm2out exist for COCO only and lmk2out for WIDERFACE only.
    if cfg.metric == 'COCO':
        from ppdet.utils.coco_eval import bbox2out, mask2out, segm2out, get_category_info
    elif cfg.metric == 'OID':
        from ppdet.utils.oid_eval import bbox2out, get_category_info
    elif cfg.metric == "VOC":
        from ppdet.utils.voc_eval import bbox2out, get_category_info
    elif cfg.metric == "WIDERFACE":
        from ppdet.utils.widerface_eval_utils import bbox2out, lmk2out, get_category_info

    anno_file = dataset.get_anno()
    with_background = dataset.with_background
    use_default_label = dataset.use_default_label

    clsid2catid, catid2name = get_category_info(anno_file, with_background,
                                                use_default_label)

    # whether output bbox is normalized in model output layer
    is_bbox_normalized = False
    if hasattr(model, 'is_bbox_normalized') and \
            callable(model.is_bbox_normalized):
        is_bbox_normalized = model.is_bbox_normalized()

    # use VisualDL to log image
    vdl_writer = None
    if FLAGS.use_vdl:
        assert six.PY3, "VisualDL requires Python >= 3.5"
        from visualdl import LogWriter
        vdl_writer = LogWriter(FLAGS.vdl_log_dir)
        vdl_image_step = 0
        vdl_image_frame = 0  # each frame can display ten pictures at most.

    imid2path = dataset.get_imid2path()
    try:
        for iter_id, data in enumerate(loader()):
            outs = exe.run(infer_prog,
                           feed=data,
                           fetch_list=values,
                           return_numpy=False)
            # Each fetched value is kept as (ndarray, sequence lengths).
            res = {
                k: (np.array(v), v.recursive_sequence_lengths())
                for k, v in zip(keys, outs)
            }
            logger.info('Infer iter {}'.format(iter_id))
            if 'TTFNet' in cfg.architecture:
                res['bbox'][1].append([len(res['bbox'][0])])
            if 'CornerNet' in cfg.architecture:
                from ppdet.utils.post_process import corner_post_process
                post_config = getattr(cfg, 'PostProcess', None)
                corner_post_process(res, post_config, cfg.num_classes)

            bbox_results = None
            mask_results = None
            segm_results = None
            lmk_results = None
            if 'bbox' in res:
                bbox_results = bbox2out([res], clsid2catid,
                                        is_bbox_normalized)
            if 'mask' in res:
                mask_results = mask2out([res], clsid2catid,
                                        model.mask_head.resolution)
            if 'segm' in res:
                segm_results = segm2out([res], clsid2catid)
            if 'landmark' in res:
                lmk_results = lmk2out([res], is_bbox_normalized)

            # visualize result
            im_ids = res['im_id'][0]
            for im_id in im_ids:
                image_path = imid2path[int(im_id)]
                # convert() returns an independent in-memory copy, so the
                # source file handle can be released right away.
                with Image.open(image_path) as src_image:
                    image = src_image.convert('RGB')
                image = ImageOps.exif_transpose(image)

                # use VisualDL to log original image
                if vdl_writer is not None:
                    original_image_np = np.array(image)
                    vdl_writer.add_image(
                        "original/frame_{}".format(vdl_image_frame),
                        original_image_np, vdl_image_step)

                image = visualize_results(image,
                                          int(im_id), catid2name,
                                          FLAGS.draw_threshold, bbox_results,
                                          mask_results, segm_results,
                                          lmk_results)

                # use VisualDL to log image with bbox
                if vdl_writer is not None:
                    infer_image_np = np.array(image)
                    vdl_writer.add_image(
                        "bbox/frame_{}".format(vdl_image_frame),
                        infer_image_np, vdl_image_step)
                    vdl_image_step += 1
                    # Start a new frame after every ten logged images.
                    if vdl_image_step % 10 == 0:
                        vdl_image_step = 0
                        vdl_image_frame += 1

                save_name = get_save_image_name(FLAGS.output_dir, image_path)
                logger.info("Detection bbox results save in {}".format(
                    save_name))
                image.save(save_name, quality=95)
    finally:
        # Close the VisualDL writer even if inference fails part-way, so
        # records logged so far are written out.
        if vdl_writer is not None:
            vdl_writer.close()
258 259 260


def str2bool(value):
    """
    Convert a command-line value to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True
    because every non-empty string is truthy. This converter reads the text
    instead.

    Args:
        value (str|bool): text such as "True", "false", "1", "0", "yes",
            "no", "on", "off" (case-insensitive), or an actual bool.

    Returns:
        bool: the parsed value; an empty string maps to False.

    Raises:
        ValueError: if the text is not a recognised boolean spelling;
            argparse reports this as an invalid argument value.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise ValueError("invalid boolean value: {!r}".format(value))


if __name__ == '__main__':
    enable_static_mode()
    parser = ArgsParser()
    parser.add_argument(
        "--infer_dir",
        type=str,
        default=None,
        help="Directory for images to perform inference on.")
    parser.add_argument(
        "--infer_img",
        type=str,
        default=None,
        help="Image path, has higher priority over --infer_dir")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Directory for storing the output visualization files.")
    parser.add_argument(
        "--draw_threshold",
        type=float,
        default=0.5,
        help="Threshold to reserve the result for visualization.")
    parser.add_argument(
        "--use_vdl",
        # str2bool instead of bool, so "--use_vdl False" really disables it.
        type=str2bool,
        default=False,
        help="whether to record the data to VisualDL.")
    parser.add_argument(
        '--vdl_log_dir',
        type=str,
        default="vdl_log_dir/image",
        help='VisualDL logging directory for image.')
    # main() reads these parsed options through the module-level FLAGS name.
    FLAGS = parser.parse_args()
    main()