infer.py 4.3 KB
Newer Older
D
dengkaipeng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import print_function

import os
import argparse
import numpy as np
L
LielinJiang 已提交
21
from PIL import Image
D
dengkaipeng 已提交
22

23
import paddle
D
dengkaipeng 已提交
24 25
from paddle import fluid
from paddle.fluid.optimizer import Momentum
D
dengkaipeng 已提交
26
from paddle.io import DataLoader
D
dengkaipeng 已提交
27

D
dengkaipeng 已提交
28 29
from modeling import yolov3_darknet53, YoloLoss
from transforms import *
D
dengkaipeng 已提交
30
from utils import print_arguments
D
dengkaipeng 已提交
31 32 33 34 35 36 37
from visualizer import draw_bbox

import logging
# Module-level logger; main() uses it to report where the result image is saved.
logger = logging.getLogger(__name__)

# Per-channel mean and standard deviation that main() subtracts from / divides
# into the RGB input image after it has been scaled to the [0, 1] range.
# NOTE(review): these look like the commonly used ImageNet statistics --
# confirm they match the values used when the weights were trained.
IMAGE_MEAN = [0.485, 0.456, 0.406]
IMAGE_STD = [0.229, 0.224, 0.225]
38
# Forwarded to yolov3_darknet53() as its `num_max_boxes` argument in main().
NUM_MAX_BOXES = 50
D
dengkaipeng 已提交
39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64


def get_save_image_name(output_dir, image_path):
    """
    Build the path under which the visualized result should be stored.

    The result keeps the file name (including extension) of ``image_path``
    and is placed directly inside ``output_dir``. ``output_dir`` is created
    (with any missing parents) when it does not exist yet.

    Args:
        output_dir (str): directory that will hold the saved image.
        image_path (str): path of the source image.

    Returns:
        str: ``output_dir`` joined with the base name of ``image_path``.
    """
    directory_missing = not os.path.exists(output_dir)
    if directory_missing:
        os.makedirs(output_dir)

    # Splitting off the extension and gluing it back on yields the base name
    # unchanged, so the base name can be used directly.
    file_name = os.path.basename(image_path)
    return os.path.join(output_dir, file_name)


def load_labels(label_list, with_background=True):
    """
    Read a label list file into a ``{category id: name}`` mapping.

    Each non-blank line of the file (after stripping surrounding whitespace)
    is one category name. Ids are assigned consecutively in file order,
    starting at ``int(with_background)`` -- i.e. 1 when ``with_background``
    is True, 0 when it is False.

    Args:
        label_list (str): path to the text file with one label per line.
        with_background (bool): when true, numbering starts at 1 instead
            of 0.

    Returns:
        dict: mapping from integer category id to label name.
    """
    first_id = int(with_background)
    with open(label_list) as label_file:
        stripped = [row.strip() for row in label_file]
    names = (entry for entry in stripped if entry)
    return dict(enumerate(names, start=first_id))


def main():
    """
    Run YOLOv3 inference on a single image and save the visualized result.

    Reads its configuration from the module-level ``FLAGS`` namespace
    (``device``, ``dynamic``, ``label_list``, ``weights``, ``infer_image``,
    ``draw_threshold``, ``output_dir``). The image with drawn boxes is
    written into ``FLAGS.output_dir`` under the source image's file name.
    """
    device = paddle.set_device(FLAGS.device)
    # Switch to dygraph (imperative) mode only when requested.
    if FLAGS.dynamic:
        paddle.disable_static(device)

    cat2name = load_labels(FLAGS.label_list, with_background=False)

    # Without user-supplied weights, ask the model builder for pretrained ones.
    model = yolov3_darknet53(
        num_classes=len(cat2name),
        num_max_boxes=NUM_MAX_BOXES,
        model_mode='test',
        pretrained=FLAGS.weights is None)

    model.prepare()

    if FLAGS.weights is not None:
        model.load(FLAGS.weights, reset_optimizer=True)

    # image preprocess
    orig_img = Image.open(FLAGS.infer_image).convert('RGB')
    # Keep the original size: it is passed to the model alongside the
    # resized image.
    w, h = orig_img.size
    img = orig_img.resize((608, 608), Image.BICUBIC)
    # Scale to [0, 1], then normalize per channel.
    img = np.array(img).astype('float32') / 255.0
    img -= np.array(IMAGE_MEAN)
    img /= np.array(IMAGE_STD)
    # Move the last (channel) axis first and prepend a batch axis.
    img = img.transpose((2, 0, 1))[np.newaxis, :]
    img_id = np.array([0]).astype('int64')[np.newaxis, :]
    img_shape = np.array([h, w]).astype('int32')[np.newaxis, :]

    _, bboxes = model.test_batch([img_id, img_shape, img])

    vis_img = draw_bbox(orig_img, cat2name, bboxes, FLAGS.draw_threshold)
    save_name = get_save_image_name(FLAGS.output_dir, FLAGS.infer_image)
    logger.info("Detection bbox results save in {}".format(save_name))
    vis_img.save(save_name, quality=95)


# Script entry point: parse the command line into the module-level FLAGS
# namespace (read by main()), validate the required paths, then run inference.
if __name__ == '__main__':
    # The text is passed as `description` so the usage line keeps the real
    # program name; this script performs inference, not training.
    parser = argparse.ArgumentParser(description="Yolov3 Inference")
    parser.add_argument(
        "--device", type=str, default='gpu', help="device to use, gpu or cpu")
    parser.add_argument(
        "-d", "--dynamic", action='store_true', help="enable dygraph mode")
    parser.add_argument(
        "--label_list",
        type=str,
        default=None,
        help="path to category label list file")
    parser.add_argument(
        "-t",
        "--draw_threshold",
        type=float,
        default=0.5,
        help="threshold to reserve the result for visualization")
    parser.add_argument(
        "-i",
        "--infer_image",
        type=str,
        default=None,
        help="image path for inference")
    parser.add_argument(
        "-o",
        "--output_dir",
        type=str,
        default='output',
        help="directory to save inference result if --visualize is set")
    parser.add_argument(
        "-w",
        "--weights",
        default=None,
        type=str,
        help="path to weights for inference")
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)

    # Validate with parser.error rather than assert: asserts are stripped
    # under `python -O`, and os.path.isfile(None) raises TypeError when the
    # option was not given at all.
    if not (FLAGS.infer_image and os.path.isfile(FLAGS.infer_image)):
        parser.error("infer_image {} not a file".format(FLAGS.infer_image))
    if not (FLAGS.label_list and os.path.isfile(FLAGS.label_list)):
        parser.error("label_list {} not a file".format(FLAGS.label_list))
    main()