From 909c544b4b79ce693428ba92d17068d9d1050ac5 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Wed, 22 Jul 2020 09:59:22 +0000 Subject: [PATCH] add support for infer dir --- tools/infer/infer.py | 40 ++++++++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/tools/infer/infer.py b/tools/infer/infer.py index 0962bcca..83c1fb4b 100644 --- a/tools/infer/infer.py +++ b/tools/infer/infer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import utils import argparse import numpy as np @@ -99,21 +100,40 @@ def postprocess(outputs, topk=5): return zip(index, prob[index]) +def get_image_list(img_file): + imgs_lists = [] + if img_file is None or not os.path.exists(img_file): + raise Exception("not found any img file in {}".format(img_file)) + + img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp'] + if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end: + imgs_lists.append(img_file) + elif os.path.isdir(img_file): + for single_file in os.listdir(img_file): + if single_file.split('.')[-1] in img_end: + imgs_lists.append(os.path.join(img_file, single_file)) + if len(imgs_lists) == 0: + raise Exception("not found any img file in {}".format(img_file)) + return imgs_lists + + def main(): args = parse_args() operators = create_operators() exe, program, feed_names, fetch_names = create_predictor(args) - data = preprocess(args.image_file, operators) - data = np.expand_dims(data, axis=0) - outputs = exe.run(program, - feed={feed_names[0]: data}, - fetch_list=fetch_names, - return_numpy=False) - probs = postprocess(outputs) - - for idx, prob in probs: - print("class id: {:d}, probability: {:.4f}".format(idx, prob)) + image_list = get_image_list(args.image_file) + for idx, filename in enumerate(image_list): + data = preprocess(filename, operators) + data = np.expand_dims(data, axis=0) + outputs = exe.run(program, + feed={feed_names[0]: data}, + fetch_list=fetch_names, + return_numpy=False) + probs = postprocess(outputs) + print("current image: {}".format(filename)) + for idx, prob in probs: + print("\tclass id: {:d}, probability: {:.4f}".format(idx, prob)) if __name__ == "__main__": -- GitLab