# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

# Put the PaddleDetection root (the parent of the directory holding this
# file) on sys.path. It is appended, so an already-installed package of the
# same name still wins the import lookup.
parent_path = os.path.abspath(os.path.join(__file__, '..', '..'))
if parent_path not in sys.path:
    sys.path.append(parent_path)

import glob
import numpy as np
import six
from PIL import Image, ImageOps

from paddle import fluid

import logging
# Root-logger setup for the script: INFO level, "<time>-<LEVEL>: <msg>".
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(format=FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    from ppdet.core.workspace import load_config, merge_config, create

    from ppdet.utils.eval_utils import parse_fetches
    from ppdet.utils.cli import ArgsParser
    from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_version, check_config, enable_static_mode
    from ppdet.utils.visualizer import visualize_results
    import ppdet.utils.checkpoint as checkpoint

    from ppdet.data.reader import create_reader
except ImportError as e:
    # When launched from the static tree, print an actionable hint (per the
    # message below, an installed dynamic-graph ppdet can get in the way)
    # and exit instead of dumping a bare traceback.
    if 'static' in sys.argv[0]:
        logger.error("Importing ppdet failed when running static model "
                     "with error: {}\n"
                     "please try:\n"
                     "\t1. run static model under PaddleDetection/static "
                     "directory\n"
                     "\t2. run 'pip uninstall ppdet' to uninstall ppdet "
                     "dynamic version firstly.".format(e))
        sys.exit(-1)
    # Bare `raise` re-raises with the original traceback (``raise e`` loses
    # it on Python 2, which this file still supports via `six`).
    raise


def get_save_image_name(output_dir, image_path):
    """
    Get save image name from source image path.

    Ensures ``output_dir`` exists (creating missing parents) and returns the
    path ``output_dir/<basename of image_path>`` to write the result to.

    Args:
        output_dir (str): directory the visualized image is saved into.
        image_path (str): path of the source image; only its basename is used.

    Returns:
        str: destination path for the visualized image.

    Raises:
        OSError: if ``output_dir`` cannot be created, or exists but is not a
            directory.
    """
    if not os.path.isdir(output_dir):
        try:
            os.makedirs(output_dir)
        except OSError as e:
            # Tolerate a concurrent creator (EEXIST for a real directory);
            # re-raise anything else, including a file squatting the path.
            # errno-based so it still works where exist_ok is unavailable.
            import errno
            if e.errno != errno.EEXIST or not os.path.isdir(output_dir):
                raise
    return os.path.join(output_dir, os.path.basename(image_path))


def get_test_images(infer_dir, infer_img):
    """
    Get image path list in TEST mode.

    Args:
        infer_dir (str|None): directory scanned (non-recursively) for files
            ending in .jpg/.jpeg/.png/.bmp, compared case-insensitively.
            Hidden files (leading '.') are ignored.
        infer_img (str|None): a single image path; takes priority over
            ``infer_dir`` when set.

    Returns:
        list[str]: ``[infer_img]`` when ``infer_img`` is set, otherwise the
        matching files under ``os.path.abspath(infer_dir)``, sorted so the
        order is stable from run to run.

    Raises:
        AssertionError: if neither argument is set, a given path has the
            wrong type, or no image is found.
    """
    assert infer_img is not None or infer_dir is not None, \
        "--infer_img or --infer_dir should be set"
    assert infer_img is None or os.path.isfile(infer_img), \
            "{} is not a file".format(infer_img)
    assert infer_dir is None or os.path.isdir(infer_dir), \
            "{} is not a directory".format(infer_dir)

    # infer_img has a higher priority
    if infer_img and os.path.isfile(infer_img):
        return [infer_img]

    infer_dir = os.path.abspath(infer_dir)
    assert os.path.isdir(infer_dir), \
        "infer_dir {} is not a directory".format(infer_dir)

    # listdir + suffix test instead of glob: glob would interpret '[', '*'
    # or '?' inside infer_dir as wildcards, and the old per-case patterns
    # skipped mixed-case extensions such as 'x.Jpg'.
    exts = ('.jpg', '.jpeg', '.png', '.bmp')
    images = []
    for fname in os.listdir(infer_dir):
        # glob's '*' never matched dotfiles; keep excluding them.
        if fname.startswith('.') or not fname.lower().endswith(exts):
            continue
        path = os.path.join(infer_dir, fname)
        # skip directories that merely look like image names
        if os.path.isfile(path):
            images.append(path)
    # deterministic order (the old set() gave arbitrary ordering)
    images.sort()

    assert len(images) > 0, "no image found in {}".format(infer_dir)
    logger.info("Found {} inference images in total.".format(len(images)))

    return images


def main():
    """Run static-graph inference on test images and save visualizations.

    Reads the module-level ``FLAGS`` (parsed in ``__main__``) for the config
    file, ``-o`` overrides, input images, output directory, draw threshold
    and VisualDL options. For every batch the fetched bbox / mask / segm /
    landmark outputs are converted with the metric-specific ``*2out``
    helpers, drawn onto each source image and saved under
    ``FLAGS.output_dir``.
    """
    cfg = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    check_config(cfg)
    # check if set use_gpu=True in paddlepaddle cpu version
    check_gpu(cfg.use_gpu)
    # disable npu in config by default and check use_npu
    if 'use_npu' not in cfg:
        cfg.use_npu = False
    check_npu(cfg.use_npu)
    # disable xpu in config by default and check use_xpu
    if 'use_xpu' not in cfg:
        cfg.use_xpu = False
    check_xpu(cfg.use_xpu)
    # check if paddlepaddle version is satisfied
    check_version()

    main_arch = cfg.architecture

    dataset = cfg.TestReader['dataset']

    test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
    dataset.set_images(test_images)

    if cfg.use_gpu:
        place = fluid.CUDAPlace(0)
    elif cfg.use_npu:
        place = fluid.NPUPlace(0)
    elif cfg.use_xpu:
        place = fluid.XPUPlace(0)
    else:
        place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    model = create(main_arch)

    startup_prog = fluid.Program()
    infer_prog = fluid.Program()
    with fluid.program_guard(infer_prog, startup_prog):
        with fluid.unique_name.guard():
            inputs_def = cfg['TestReader']['inputs_def']
            inputs_def['iterable'] = True
            feed_vars, loader = model.build_inputs(**inputs_def)
            test_fetches = model.test(feed_vars)
    # clone with for_test=True for inference
    infer_prog = infer_prog.clone(True)

    reader = create_reader(cfg.TestReader, devices_num=1)
    loader.set_sample_list_generator(reader, place)

    exe.run(startup_prog)
    if cfg.weights:
        checkpoint.load_params(exe, infer_prog, cfg.weights)

    # parse infer fetches
    assert cfg.metric in ['COCO', 'VOC', 'OID', 'WIDERFACE'], \
            "unknown metric type {}".format(cfg.metric)
    extra_keys = []
    if cfg.metric in ['COCO', 'OID']:
        extra_keys = ['im_info', 'im_id', 'im_shape']
    elif cfg.metric in ['VOC', 'WIDERFACE']:
        extra_keys = ['im_id', 'im_shape']
    keys, values, _ = parse_fetches(test_fetches, infer_prog, extra_keys)

    # parse dataset category; each metric provides its own converters
    if cfg.metric == 'COCO':
        from ppdet.utils.coco_eval import bbox2out, mask2out, segm2out, get_category_info
    elif cfg.metric == 'OID':
        from ppdet.utils.oid_eval import bbox2out, get_category_info
    elif cfg.metric == "VOC":
        from ppdet.utils.voc_eval import bbox2out, get_category_info
    elif cfg.metric == "WIDERFACE":
        from ppdet.utils.widerface_eval_utils import bbox2out, lmk2out, get_category_info

    anno_file = dataset.get_anno()
    with_background = dataset.with_background
    use_default_label = dataset.use_default_label

    clsid2catid, catid2name = get_category_info(anno_file, with_background,
                                                use_default_label)

    # whether output bbox is normalized in model output layer
    is_bbox_normalized = False
    if hasattr(model, 'is_bbox_normalized') and \
            callable(model.is_bbox_normalized):
        is_bbox_normalized = model.is_bbox_normalized()

    # Architecture-specific post-processing is resolved once here rather
    # than re-imported / re-looked-up for every batch.
    is_ttfnet = 'TTFNet' in main_arch
    is_cornernet = 'CornerNet' in main_arch
    if is_cornernet:
        from ppdet.utils.post_process import corner_post_process
        post_config = getattr(cfg, 'PostProcess', None)

    # use VisualDL to log image
    vdl_writer = None
    if FLAGS.use_vdl:
        assert six.PY3, "VisualDL requires Python >= 3.5"
        from visualdl import LogWriter
        vdl_writer = LogWriter(FLAGS.vdl_log_dir)
        vdl_image_step = 0
        vdl_image_frame = 0  # each frame can display ten pictures at most.

    imid2path = dataset.get_imid2path()
    try:
        for iter_id, data in enumerate(loader()):
            outs = exe.run(infer_prog,
                           feed=data,
                           fetch_list=values,
                           return_numpy=False)
            res = {
                k: (np.array(v), v.recursive_sequence_lengths())
                for k, v in zip(keys, outs)
            }
            logger.info('Infer iter {}'.format(iter_id))
            if is_ttfnet:
                res['bbox'][1].append([len(res['bbox'][0])])
            if is_cornernet:
                corner_post_process(res, post_config, cfg.num_classes)

            bbox_results = None
            mask_results = None
            segm_results = None
            lmk_results = None
            if 'bbox' in res:
                bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized)
            if 'mask' in res:
                mask_results = mask2out([res], clsid2catid,
                                        model.mask_head.resolution)
            if 'segm' in res:
                segm_results = segm2out([res], clsid2catid)
            if 'landmark' in res:
                lmk_results = lmk2out([res], is_bbox_normalized)

            # visualize result
            im_ids = res['im_id'][0]
            for im_id in im_ids:
                image_path = imid2path[int(im_id)]
                # `with` releases the source file handle once converted
                with Image.open(image_path) as src_image:
                    image = src_image.convert('RGB')
                image = ImageOps.exif_transpose(image)

                # use VisualDL to log original image
                if FLAGS.use_vdl:
                    original_image_np = np.array(image)
                    vdl_writer.add_image(
                        "original/frame_{}".format(vdl_image_frame),
                        original_image_np, vdl_image_step)

                image = visualize_results(image,
                                          int(im_id), catid2name,
                                          FLAGS.draw_threshold, bbox_results,
                                          mask_results, segm_results, lmk_results)

                # use VisualDL to log image with bbox
                if FLAGS.use_vdl:
                    infer_image_np = np.array(image)
                    vdl_writer.add_image("bbox/frame_{}".format(vdl_image_frame),
                                         infer_image_np, vdl_image_step)
                    vdl_image_step += 1
                    if vdl_image_step % 10 == 0:
                        vdl_image_step = 0
                        vdl_image_frame += 1

                save_name = get_save_image_name(FLAGS.output_dir, image_path)
                logger.info("Detection bbox results save in {}".format(save_name))
                image.save(save_name, quality=95)
    finally:
        # Flush queued VisualDL records; without this the last images may
        # never reach the log directory.
        if vdl_writer is not None:
            vdl_writer.close()


def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1``, ``no``.

    ``argparse`` with ``type=bool`` calls ``bool(string)``, which is True for
    every non-empty string, so ``--use_vdl False`` used to enable VisualDL.
    This converter keeps the ``--use_vdl <value>`` syntax and honours the
    false spellings.

    Raises:
        ValueError: for an unrecognised spelling; argparse reports it as an
            invalid argument value.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise ValueError("invalid boolean value: {!r}".format(value))


if __name__ == '__main__':
    enable_static_mode()
    parser = ArgsParser()
    parser.add_argument(
        "--infer_dir",
        type=str,
        default=None,
        help="Directory for images to perform inference on.")
    parser.add_argument(
        "--infer_img",
        type=str,
        default=None,
        help="Image path, has higher priority over --infer_dir")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Directory for storing the output visualization files.")
    parser.add_argument(
        "--draw_threshold",
        type=float,
        default=0.5,
        help="Threshold to reserve the result for visualization.")
    parser.add_argument(
        "--use_vdl",
        # type=bool would turn "--use_vdl False" into True
        type=str2bool,
        default=False,
        help="whether to record the data to VisualDL.")
    parser.add_argument(
        '--vdl_log_dir',
        type=str,
        default="vdl_log_dir/image",
        help='VisualDL logging directory for image.')
    FLAGS = parser.parse_args()
    main()