#Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
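"""Infer image classification results with a pretrained PaddlePaddle model.

The script builds the selected network in test mode, loads the pretrained
weights given by --pretrained_model, and writes the top-k prediction of every
image to --save_json_path as JSON lines. With --save_inference it only exports
an inference model and exits.

Example invocation (model name and paths are illustrative):
    python infer.py --model=ResNet50 \
        --pretrained_model=./pretrained/ResNet50_pretrained \
        --data_dir=./data/ILSVRC2012/val/
"""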

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import sys
import math
import numpy as np
import argparse
import functools
import re
import logging

import paddle
import paddle.fluid as fluid
import reader
import models
import json
from utils import *

parser = argparse.ArgumentParser(description=__doc__)
# yapf: disable
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('data_dir',         str,  "./data/ILSVRC2012/val/", "The ImageNet validation data directory.")
add_arg('use_gpu',          bool, True,                 "Whether to use GPU or not.")
add_arg('class_dim',        int,  1000,                 "Class number.")
parser.add_argument("--pretrained_model", default=None, required=True, type=str, help="The path to load pretrained model")
add_arg('model',            str,  "ResNet50",           "Set the network to use.")
add_arg('save_inference',   bool, False,                "Whether to save inference model or not")
add_arg('resize_short_size',int,  256,                  "Set resize short size")
add_arg('reader_thread',    int,  1,                    "The number of multi thread reader")
add_arg('reader_buf_size',  int,  2048,                 "The buf size of multi thread reader")
parser.add_argument('--image_mean', nargs='+', type=float, default=[0.485, 0.456, 0.406], help="The mean of input image data")
parser.add_argument('--image_std', nargs='+', type=float, default=[0.229, 0.224, 0.225], help="The std of input image data")
parser.add_argument('--image_shape', nargs='+', type=int, default=[3, 224, 224], help="The shape of the input image (CHW)")
add_arg('topk',             int,  1,                    "The number of top-k predictions to report.")
add_arg('class_map_path',   str,  "./utils/tools/readable_label.txt", "The path of the readable label file.")
add_arg('interpolation',    int,  None,                 "The interpolation mode")
add_arg('padding_type',     str,  "SAME",               "Padding type of convolution")
add_arg('use_se',           bool, True,                 "Whether to use Squeeze-and-Excitation module for EfficientNet.")
add_arg('image_path',       str,  None,                 "The path of a single image to infer.")
add_arg('batch_size',       int,  8,                    "The total batch size across all devices.")
add_arg('save_json_path',   str,  "test_res.json",      "The JSON file to save inference results.")
# yapf: enable

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def infer(args):
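    """Run inference with a pretrained classification model.

    Builds the network in test mode, loads the pretrained weights, feeds the
    test reader through a data-parallel compiled program, and writes the
    top-k prediction of every image to args.save_json_path as JSON lines.
    If args.save_inference is set, only an inference model is exported.
    """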
    model_list = [m for m in dir(models) if "__" not in m]
    assert args.model in model_list, \
        "{} is not in the supported model list: {}".format(args.model, model_list)
    assert os.path.isdir(args.pretrained_model), \
        "Please make sure args.pretrained_model points to a directory of pretrained weights."

    assert args.image_shape[1] <= args.resize_short_size, \
        "Please check args.image_shape and args.resize_short_size: the cropped size " \
        "(image_shape[1]) must be smaller than or equal to the resized short side " \
        "(resize_short_size)."

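    # Single-image inference feeds exactly one sample, so it requires a single
    # visible device (one GPU, or CPU_NUM=1 on CPU).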
    if args.image_path:
        assert os.path.isfile(
            args.image_path
        ), "Please check args.image_path: it should be the path of a single image."
        if args.use_gpu:
            assert fluid.core.get_cuda_device_count() == 1, \
                "Please set CUDA_VISIBLE_DEVICES so that only one GPU is visible."
        else:
            assert int(os.environ.get('CPU_NUM', 1)) == 1, \
                "Please set CPU_NUM to 1."

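    # Static-graph input placeholder for a batch of images (NCHW, float32).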
    image = fluid.data(
        name='image', shape=[None] + args.image_shape, dtype='float32')

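    # EfficientNet variants take extra constructor arguments: test mode,
    # convolution padding type, and whether to use Squeeze-and-Excitation.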
    if args.model.startswith('EfficientNet'):
        model = models.__dict__[args.model](is_test=True,
                                            padding_type=args.padding_type,
                                            use_se=args.use_se)
    else:
        model = models.__dict__[args.model]()

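    # GoogLeNet returns the main output plus two auxiliary heads; only the
    # main output is kept. Other models return raw logits, so a softmax is
    # appended to obtain probabilities.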
    if args.model == "GoogLeNet":
        out, _, _ = model.net(input=image, class_dim=args.class_dim)
    else:
        out = model.net(input=image, class_dim=args.class_dim)
        out = fluid.layers.softmax(out)

    test_program = fluid.default_main_program().clone(for_test=True)

    fetch_list = [out.name]
    gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
    place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
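    # Collect all visible devices and replicate the test program across them
    # for data-parallel inference.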
    if args.use_gpu:
        places = fluid.framework.cuda_places()
    else:
        places = fluid.framework.cpu_places()
    compiled_program = fluid.compiler.CompiledProgram(
        test_program).with_data_parallel(places=places)

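    # Load the raw (persistable) parameters. If --save_inference is set,
    # export a deployable inference model ('model' + 'params' files) and exit.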
    fluid.io.load_persistables(exe, args.pretrained_model)
    if args.save_inference:
        fluid.io.save_inference_model(
            dirname=args.model,
            feeded_var_names=['image'],
            main_program=test_program,
            target_vars=out,
            executor=exe,
            model_filename='model',
            params_filename='params')
        logger.info("The inference model for {0} has been saved.".format(args.model))
        exit(0)

    imagenet_reader = reader.ImageNetReader()
    test_reader = imagenet_reader.test(settings=args)
    feeder = fluid.DataFeeder(place=places, feed_list=[image])

    TOPK = args.topk
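    # Optionally load a mapping from numerical labels to readable class names;
    # each line of the file is expected to start with a numerical label
    # followed by comma-separated readable names.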
    if os.path.exists(args.class_map_path):
        logger.info(
            "Found the map between numerical labels and readable class names.")
        with open(args.class_map_path) as f:
            label_dict = {}
            strinfo = re.compile(r"\d+ ")
            for item in f.readlines():
                key = item.split(" ")[0]
                value = [
                    strinfo.sub("", l).replace("\n", "")
                    for l in item.split(", ")
                ]
                label_dict[key] = value

    info = {}
    parallel_data = []
    parallel_id = []
    place_num = paddle.fluid.core.get_cuda_device_count(
    ) if args.use_gpu else int(os.environ.get('CPU_NUM', 1))
    if os.path.exists(args.save_json_path):
        logger.warning("path: {} already exists and will be overwritten\n".format(
            args.save_json_path))
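    # Accumulate one mini-batch per device, then run them together in a single
    # data-parallel step and write the top-k result of every image as one JSON line.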
    with open(args.save_json_path, "w") as fout:
        for batch_id, data in enumerate(test_reader()):
            image_data = [[items[0]] for items in data]
            image_id = [items[1] for items in data]

            parallel_id.append(image_id)
            parallel_data.append(image_data)

            if place_num == len(parallel_data):
                result = exe.run(
                    compiled_program,
                    fetch_list=fetch_list,
                    feed=list(feeder.feed_parallel(parallel_data, place_num)))
                for i, res in enumerate(result[0]):
                    pred_label = np.argsort(res)[::-1][:TOPK]
                    real_id = str(np.array(parallel_id).flatten()[i])
                    _, real_id = os.path.split(real_id)

                    if os.path.exists(args.class_map_path):
                        readable_pred_label = []
                        for label in pred_label:
                            readable_pred_label.append(label_dict[str(label)])

                        info[real_id] = {}
                        info[real_id]['score'], info[real_id]['class'], info[
                            real_id]['class_name'] = str(res[pred_label]), str(
                                pred_label), readable_pred_label
                    else:
                        info[real_id] = {}
                        info[real_id]['score'], info[real_id]['class'] = str(
                            res[pred_label]), str(pred_label)

                    logger.info("{}, {}".format(real_id, info[real_id]))
                    sys.stdout.flush()
                    fout.write(real_id + "\t" + json.dumps(info[real_id]) +
                               "\n")

                parallel_data = []
                parallel_id = []

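    # Clean up the temporary file list that the data reader may have written
    # (e.g. in single-image mode), if it exists.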
    if os.path.exists(".tmp.txt"):
        os.remove(".tmp.txt")


def main():
    args = parser.parse_args()
    print_arguments(args)
    check_gpu()
    check_version()
    infer(args)


if __name__ == '__main__':
    main()