diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 1bb34990f2000e7294f8ecbda7da2fb6f19b3336..72b19b9198163751423f5b34880ddc0acc6fbbdf 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -84,19 +84,29 @@ def parse_args(): parser.add_argument("--enable_mkldnn", type=str2bool, default=False) parser.add_argument("--use_zero_copy_run", type=str2bool, default=False) - + parser.add_argument("--use_pdserving", type=str2bool, default=False) - + return parser.parse_args() def create_predictor(args, mode): + """ + create predictor for inference + :param args: params for prediction engine + :param mode: mode + :return: predictor + """ if mode == "det": model_dir = args.det_model_dir elif mode == 'cls': model_dir = args.cls_model_dir - else: + elif mode == 'rec': model_dir = args.rec_model_dir + else: + raise ValueError( + "'mode' of create_predictor() can only be one of ['det', 'cls', 'rec']" + ) if model_dir is None: logger.info("not find {} model file path {}".format(mode, model_dir)) @@ -144,6 +154,12 @@ def create_predictor(args, mode): def draw_text_det_res(dt_boxes, img_path): + """ + Visualize the results of detection + :param dt_boxes: The boxes predicted by detection model + :param img_path: Image path + :return: Visualized image + """ src_im = cv2.imread(img_path) for box in dt_boxes: box = np.array(box).astype(np.int32).reshape(-1, 2)