From 27013344cb54f93710bd2834069ce49fe211ac0a Mon Sep 17 00:00:00 2001 From: Mike Shi Date: Tue, 1 Dec 2020 01:31:45 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=AF=B9=E4=B8=80=E4=B8=AA?= =?UTF-8?q?=E7=9B=AE=E5=BD=95=E7=9A=84=E6=96=87=E4=BB=B6=E7=9A=84=E6=94=AF?= =?UTF-8?q?=E6=8C=81=20(#427)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/infer/predict.py | 43 ++++++++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/tools/infer/predict.py b/tools/infer/predict.py index 26b7a6cd..17a5157c 100644 --- a/tools/infer/predict.py +++ b/tools/infer/predict.py @@ -30,6 +30,7 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("-i", "--image_file", type=str) + parser.add_argument("-d", "--image_dir", type=str) parser.add_argument("-m", "--model_file", type=str) parser.add_argument("-p", "--params_file", type=str) parser.add_argument("-b", "--batch_size", type=int, default=1) @@ -93,6 +94,8 @@ def preprocess(fname, ops): def main(): + import os + args = parse_args() if not args.enable_benchmark: @@ -102,6 +105,8 @@ def main(): assert args.use_gpu is True assert args.model_name is not None assert args.use_tensorrt is True + assert args.image_file is not None + # HALF precission predict only work when using tensorrt if args.use_fp16 is True: assert args.use_tensorrt is True @@ -118,20 +123,30 @@ def main(): test_num = 500 test_time = 0.0 if not args.enable_benchmark: - inputs = preprocess(args.image_file, operators) - inputs = np.expand_dims( - inputs, axis=0).repeat( - args.batch_size, axis=0).copy() - input_tensor.copy_from_cpu(inputs) - - predictor.zero_copy_run() - - output = output_tensor.copy_to_cpu() - output = output.flatten() - cls = np.argmax(output) - score = output[cls] - logger.info("class: {0}".format(cls)) - logger.info("score: {0}".format(score)) + image_files = [] + if args.image_file is not None: + image_files = [args.image_file] + elif args.image_dir is not None: + supported_exts = ('.jpg', 'jpeg', '.png', '.gif', '.bmp') + for root, _, files in os.walk(args.image_dir, topdown=False): + image_files += [os.path.join(root, f) for f in files + if os.path.splitext(f)[-1].lower() in supported_exts] + for image_file in image_files: + inputs = preprocess(image_file, operators) + inputs = np.expand_dims( + inputs, axis=0).repeat( + args.batch_size, axis=0).copy() + input_tensor.copy_from_cpu(inputs) + + predictor.zero_copy_run() + + output = output_tensor.copy_to_cpu() + output = output.flatten() + cls = np.argmax(output) + score = output[cls] + logger.info("image file: {0}".format(image_file)) + logger.info("class: {0}".format(cls)) + logger.info("score: {0}".format(score)) else: for i in range(0, test_num + 10): inputs = np.random.rand(args.batch_size, 3, 224, -- GitLab