未验证 提交 27013344 编写于 作者: mikeshi1980's avatar mikeshi1980 提交者: GitHub

添加对一个目录的文件的支持 (#427)

上级 18d27608
......@@ -30,6 +30,7 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--image_file", type=str)
parser.add_argument("-d", "--image_dir", type=str)
parser.add_argument("-m", "--model_file", type=str)
parser.add_argument("-p", "--params_file", type=str)
parser.add_argument("-b", "--batch_size", type=int, default=1)
......@@ -93,6 +94,8 @@ def preprocess(fname, ops):
def main():
import os
args = parse_args()
if not args.enable_benchmark:
......@@ -102,6 +105,8 @@ def main():
assert args.use_gpu is True
assert args.model_name is not None
assert args.use_tensorrt is True
assert args.image_file is not None
# HALF precission predict only work when using tensorrt
if args.use_fp16 is True:
assert args.use_tensorrt is True
......@@ -118,20 +123,30 @@ def main():
test_num = 500
test_time = 0.0
if not args.enable_benchmark:
inputs = preprocess(args.image_file, operators)
inputs = np.expand_dims(
inputs, axis=0).repeat(
args.batch_size, axis=0).copy()
input_tensor.copy_from_cpu(inputs)
predictor.zero_copy_run()
output = output_tensor.copy_to_cpu()
output = output.flatten()
cls = np.argmax(output)
score = output[cls]
logger.info("class: {0}".format(cls))
logger.info("score: {0}".format(score))
image_files = []
if args.image_file is not None:
image_files = [args.image_file]
elif args.image_dir is not None:
supported_exts = ('.jpg', 'jpeg', '.png', '.gif', '.bmp')
for root, _, files in os.walk(args.image_dir, topdown=False):
image_files += [os.path.join(root, f) for f in files
if os.path.splitext(f)[-1].lower() in supported_exts]
for image_file in image_files:
inputs = preprocess(image_file, operators)
inputs = np.expand_dims(
inputs, axis=0).repeat(
args.batch_size, axis=0).copy()
input_tensor.copy_from_cpu(inputs)
predictor.zero_copy_run()
output = output_tensor.copy_to_cpu()
output = output.flatten()
cls = np.argmax(output)
score = output[cls]
logger.info("image file: {0}".format(image_file))
logger.info("class: {0}".format(cls))
logger.info("score: {0}".format(score))
else:
for i in range(0, test_num + 10):
inputs = np.random.rand(args.batch_size, 3, 224,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册