diff --git a/tools/infer/predict.py b/tools/infer/predict.py
index c78cb1b9e2162d15bf062bd207f4c87a0b0880e3..4ef4af5c1032defecac15073c1bcdc96f0a22f24 100644
--- a/tools/infer/predict.py
+++ b/tools/infer/predict.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import utils
 import argparse
 import numpy as np
@@ -24,6 +23,7 @@ from paddle.fluid.core import create_paddle_predictor
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+
 def parse_args():
     def str2bool(v):
         return v.lower() in ("true", "t", "1")
@@ -47,19 +47,18 @@ def parse_args():
 
 def create_predictor(args):
     config = AnalysisConfig(args.model_file, args.params_file)
-
-
     if args.use_gpu:
         config.enable_use_gpu(args.gpu_mem, 0)
     else:
         config.disable_gpu()
 
     config.disable_glog_info()
-    config.switch_ir_optim(args.ir_optim) # default true
+    config.switch_ir_optim(args.ir_optim)  # default true
     if args.use_tensorrt:
         config.enable_tensorrt_engine(
-            precision_mode=AnalysisConfig.Precision.Half if args.use_fp16 else AnalysisConfig.Precision.Float32,
-            max_batch_size=args.batch_size)
+            precision_mode=AnalysisConfig.Precision.Half
+            if args.use_fp16 else AnalysisConfig.Precision.Float32,
+            max_batch_size=args.batch_size)
 
     config.enable_memory_optim()
     # use zero copy
@@ -79,7 +78,7 @@ def create_operators():
     resize_op = utils.ResizeImage(resize_short=256)
     crop_op = utils.CropImage(size=(size, size))
     normalize_op = utils.NormalizeImage(
-            scale=img_scale, mean=img_mean, std=img_std)
+        scale=img_scale, mean=img_mean, std=img_std)
     totensor_op = utils.ToTensor()
 
     return [decode_op, resize_op, crop_op, normalize_op, totensor_op]
@@ -104,38 +103,62 @@ def main():
         assert args.model_name is not None
         assert args.use_tensorrt == True
     # HALF precission predict only work when using tensorrt
-    if args.use_fp16==True:
+    if args.use_fp16 == True:
         assert args.use_tensorrt == True
 
     operators = create_operators()
     predictor = create_predictor(args)
 
     inputs = preprocess(args.image_file, operators)
-    inputs = np.expand_dims(inputs, axis=0).repeat(args.batch_size, axis=0).copy()
+    inputs = np.expand_dims(
+        inputs, axis=0).repeat(
+            args.batch_size, axis=0).copy()
 
     input_names = predictor.get_input_names()
     input_tensor = predictor.get_input_tensor(input_names[0])
-    input_tensor.copy_from_cpu(inputs)
-
+
+    output_names = predictor.get_output_names()
+    output_tensor = predictor.get_output_tensor(output_names[0])
+
+    test_num = 500
+    test_time = 0.0
     if not args.enable_benchmark:
+        inputs = preprocess(args.image_file, operators)
+        inputs = np.expand_dims(
+            inputs, axis=0).repeat(
+                args.batch_size, axis=0).copy()
+        input_tensor.copy_from_cpu(inputs)
+
         predictor.zero_copy_run()
+
+        output = output_tensor.copy_to_cpu()
+        output = output.flatten()
+        cls = np.argmax(output)
+        score = output[cls]
+        logger.info("class: {0}".format(cls))
+        logger.info("score: {0}".format(score))
     else:
-        for i in range(0,1010):
-            if i == 10:
-                start = time.time()
+        for i in range(0, test_num + 10):
+            inputs = np.random.rand(args.batch_size, 3, 224,
+                                    224).astype(np.float32)
+            start_time = time.time()
+            input_tensor.copy_from_cpu(inputs)
+
             predictor.zero_copy_run()
-        end = time.time()
-        fp_message = "FP16" if args.use_fp16 else "FP32"
-        logger.info("{0}\t{1}\tbatch size: {2}\ttime(ms): {3}".format(args.model_name, fp_message, args.batch_size, end-start))
+
+            output = output_tensor.copy_to_cpu()
+            output = output.flatten()
+            if i >= 10:
+                test_time += time.time() - start_time
+            cls = np.argmax(output)
+            score = output[cls]
+            logger.info("class: {0}".format(cls))
+            logger.info("score: {0}".format(score))
 
-    output_names = predictor.get_output_names()
-    output_tensor = predictor.get_output_tensor(output_names[0])
-    output = output_tensor.copy_to_cpu()
-    output = output.flatten()
-    cls = np.argmax(output)
-    score = output[cls]
-    logger.info("class: {0}".format(cls))
-    logger.info("score: {0}".format(score))
+        fp_message = "FP16" if args.use_fp16 else "FP32"
+        logger.info("{0}\t{1}\tbatch size: {2}\ttime(ms): {3}".format(
+            args.model_name, fp_message, args.batch_size, 1000 * test_time /
+            test_num))
 
 
 if __name__ == "__main__":