From 60711bafe2799f05936eb1685108a093903833ca Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Thu, 22 Oct 2020 18:24:42 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E9=9D=99=E6=80=81=E5=9B=BE?= =?UTF-8?q?=E7=9A=84create=5Fpredictor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/infer/utility.py | 56 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index dab06349..6e49357b 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -14,11 +14,14 @@ import argparse import os +import sys import cv2 import numpy as np import json from PIL import Image, ImageDraw, ImageFont import math +from paddle.fluid.core import AnalysisConfig +from paddle.fluid.core import create_paddle_predictor def parse_args(): @@ -71,6 +74,59 @@ def parse_args(): return parser.parse_args() +def create_predictor(args, mode, logger): + if mode == "det": + model_dir = args.det_model_dir + elif mode == 'cls': + model_dir = args.cls_model_dir + else: + model_dir = args.rec_model_dir + + if model_dir is None: + logger.info("not find {} model file path {}".format(mode, model_dir)) + sys.exit(0) + model_file_path = model_dir + "/__model__" + params_file_path = model_dir + "/__variables__" + if not os.path.exists(model_file_path): + logger.info("not find model file path {}".format(model_file_path)) + sys.exit(0) + if not os.path.exists(params_file_path): + logger.info("not find params file path {}".format(params_file_path)) + sys.exit(0) + + config = AnalysisConfig(model_file_path, params_file_path) + + if args.use_gpu: + config.enable_use_gpu(args.gpu_mem, 0) + else: + config.disable_gpu() + config.set_cpu_math_library_num_threads(6) + if args.enable_mkldnn: + # cache 10 different shapes for mkldnn to avoid memory leak + config.set_mkldnn_cache_capacity(10) + config.enable_mkldnn() + + # config.enable_memory_optim() + config.disable_glog_info() + + if args.use_zero_copy_run: + config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") + config.switch_use_feed_fetch_ops(False) + else: + config.switch_use_feed_fetch_ops(True) + + predictor = create_paddle_predictor(config) + input_names = predictor.get_input_names() + for name in input_names: + input_tensor = predictor.get_input_tensor(name) + output_names = predictor.get_output_names() + output_tensors = [] + for output_name in output_names: + output_tensor = predictor.get_output_tensor(output_name) + output_tensors.append(output_tensor) + return predictor, input_tensor, output_tensors + + def draw_text_det_res(dt_boxes, img_path): src_im = cv2.imread(img_path) for box in dt_boxes: -- GitLab