未验证 提交 27013344 编写于 作者: mikeshi1980's avatar mikeshi1980 提交者: GitHub

添加对一个目录的文件的支持 (#427)

上级 18d27608
...@@ -30,6 +30,7 @@ def parse_args(): ...@@ -30,6 +30,7 @@ def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-i", "--image_file", type=str) parser.add_argument("-i", "--image_file", type=str)
parser.add_argument("-d", "--image_dir", type=str)
parser.add_argument("-m", "--model_file", type=str) parser.add_argument("-m", "--model_file", type=str)
parser.add_argument("-p", "--params_file", type=str) parser.add_argument("-p", "--params_file", type=str)
parser.add_argument("-b", "--batch_size", type=int, default=1) parser.add_argument("-b", "--batch_size", type=int, default=1)
...@@ -93,6 +94,8 @@ def preprocess(fname, ops): ...@@ -93,6 +94,8 @@ def preprocess(fname, ops):
def main(): def main():
import os
args = parse_args() args = parse_args()
if not args.enable_benchmark: if not args.enable_benchmark:
...@@ -102,6 +105,8 @@ def main(): ...@@ -102,6 +105,8 @@ def main():
assert args.use_gpu is True assert args.use_gpu is True
assert args.model_name is not None assert args.model_name is not None
assert args.use_tensorrt is True assert args.use_tensorrt is True
assert args.image_file is not None
# HALF precission predict only work when using tensorrt # HALF precission predict only work when using tensorrt
if args.use_fp16 is True: if args.use_fp16 is True:
assert args.use_tensorrt is True assert args.use_tensorrt is True
...@@ -118,20 +123,30 @@ def main(): ...@@ -118,20 +123,30 @@ def main():
test_num = 500 test_num = 500
test_time = 0.0 test_time = 0.0
if not args.enable_benchmark: if not args.enable_benchmark:
inputs = preprocess(args.image_file, operators) image_files = []
inputs = np.expand_dims( if args.image_file is not None:
inputs, axis=0).repeat( image_files = [args.image_file]
args.batch_size, axis=0).copy() elif args.image_dir is not None:
input_tensor.copy_from_cpu(inputs) supported_exts = ('.jpg', 'jpeg', '.png', '.gif', '.bmp')
for root, _, files in os.walk(args.image_dir, topdown=False):
predictor.zero_copy_run() image_files += [os.path.join(root, f) for f in files
if os.path.splitext(f)[-1].lower() in supported_exts]
output = output_tensor.copy_to_cpu() for image_file in image_files:
output = output.flatten() inputs = preprocess(image_file, operators)
cls = np.argmax(output) inputs = np.expand_dims(
score = output[cls] inputs, axis=0).repeat(
logger.info("class: {0}".format(cls)) args.batch_size, axis=0).copy()
logger.info("score: {0}".format(score)) input_tensor.copy_from_cpu(inputs)
predictor.zero_copy_run()
output = output_tensor.copy_to_cpu()
output = output.flatten()
cls = np.argmax(output)
score = output[cls]
logger.info("image file: {0}".format(image_file))
logger.info("class: {0}".format(cls))
logger.info("score: {0}".format(score))
else: else:
for i in range(0, test_num + 10): for i in range(0, test_num + 10):
inputs = np.random.rand(args.batch_size, 3, 224, inputs = np.random.rand(args.batch_size, 3, 224,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册