提交 a958e356 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix benchmark

上级 29e0bd01
...@@ -15,13 +15,10 @@ ...@@ -15,13 +15,10 @@
import argparse import argparse
import utils import utils
import numpy as np import numpy as np
import logging
import time import time
from paddle.fluid.core import AnalysisConfig from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import create_paddle_predictor from paddle.fluid.core import create_paddle_predictor
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def parse_args(): def parse_args():
...@@ -101,7 +98,6 @@ def main(): ...@@ -101,7 +98,6 @@ def main():
else: else:
assert args.use_gpu is True assert args.use_gpu is True
assert args.model_name is not None assert args.model_name is not None
assert args.use_tensorrt is True
# HALF precission predict only work when using tensorrt # HALF precission predict only work when using tensorrt
if args.use_fp16 is True: if args.use_fp16 is True:
assert args.use_tensorrt is True assert args.use_tensorrt is True
...@@ -130,8 +126,9 @@ def main(): ...@@ -130,8 +126,9 @@ def main():
output = output.flatten() output = output.flatten()
cls = np.argmax(output) cls = np.argmax(output)
score = output[cls] score = output[cls]
logger.info("class: {0}".format(cls)) print("Current image file: {}".format(args.image_file))
logger.info("score: {0}".format(score)) print("\ttop-1 class: {0}".format(cls))
print("\ttop-1 score: {0}".format(score))
else: else:
for i in range(0, test_num + 10): for i in range(0, test_num + 10):
inputs = np.random.rand(args.batch_size, 3, 224, inputs = np.random.rand(args.batch_size, 3, 224,
...@@ -145,11 +142,13 @@ def main(): ...@@ -145,11 +142,13 @@ def main():
output = output.flatten() output = output.flatten()
if i >= 10: if i >= 10:
test_time += time.time() - start_time test_time += time.time() - start_time
time.sleep(0.01) # sleep for T4 GPU
fp_message = "FP16" if args.use_fp16 else "FP32" fp_message = "FP16" if args.use_fp16 else "FP32"
logger.info("{0}\t{1}\tbatch size: {2}\ttime(ms): {3}".format( trt_msg = "using tensorrt" if args.use_tensorrt else "not using tensorrt"
args.model_name, fp_message, args.batch_size, 1000 * test_time / print("{0}\t{1}\t{2}\tbatch size: {3}\ttime(ms): {4}".format(
test_num)) args.model_name, trt_msg, fp_message, args.batch_size, 1000 *
test_time / test_num))
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册